Implements the Replicate API backend (FLUX schnell / FLUX dev) per ImaGen issue #3: - internal/backend/replicate.go — Backend adapter. Supports model refs as "owner/name" (uses /v1/models/{owner}/{name}/predictions) and "owner/name:hash" (uses /v1/predictions with explicit version). Polls /v1/predictions/{id} every 500ms with model-aware timeout (60s schnell, 120s dev). Resilience: 401 names api_token_env, 429 with exp backoff up to 3 retries (honours Retry-After), 5xx retries once, image download retries once on transient failure. - internal/backend/replicate_pricing.go — hardcoded per-image USD rates for known FLUX models, snapshotted from replicate.com/pricing with a refresh TODO. - internal/backend/replicate_test.go — mocked-HTTP unit tests covering happy path (model + version-pinned), 401, 429 retry policy, failed prediction, poll timeout, image-download retry, ctx cancel, BackendOpts passthrough, default_steps, aspect-ratio reduction, sha256 prompt hash. - internal/usage/usage.go — Supabase REST sink + read-side query for mai.imagen_usage. Adapter writes are best-effort: failures warn but the image still lands. - cmd/imagen/usage.go — `imagen usage [--since DATE] [--raw]` reads the table and prints a tab-aligned grouped or raw table with totals. - cmd/imagen/backends.go — instances of type=replicate now report "ok" or "not configured (set REPLICATE_API_TOKEN)" depending on env. - internal/config/config.go — sample adds flux-schnell-replicate + flux-dev-replicate; default_backend stays flux-schnell-local. - Supabase migration mai.imagen_usage (id, created_at, backend, model, seed, prompt_hash, latency_ms, cost_usd_estimate, caller) + indexes on (created_at DESC) and (caller). The raw prompt is never stored. Caller identity resolves from MAI_FROM_ID, then the tmux pane's @mai-name option, mirroring the maimcp identity logic. Prompt hash is sha256 of the user-facing prompt; raw prompt never reaches the table.
568 lines
16 KiB
Go
568 lines
16 KiB
Go
package backend
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"maps"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// ReplicateType is the backend type name that Replicate instances are
// registered under.
const (
	ReplicateType = "replicate"
)
|
|
|
|
// Replicate is the Replicate API adapter. It speaks the public REST API
|
|
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
|
|
// to submit, then polls /v1/predictions/{id} until the prediction
|
|
// settles, then downloads the produced image.
|
|
//
|
|
// Concurrency: a single Replicate is safe to share across goroutines.
|
|
type Replicate struct {
|
|
instance string
|
|
|
|
apiBase string
|
|
apiToken string
|
|
tokenEnv string
|
|
model string // "owner/name" or "owner/name:version-hash"
|
|
owner string
|
|
name string
|
|
version string // optional; empty means "use the model-based predictions endpoint"
|
|
defaultSteps int
|
|
defaultAspect string
|
|
|
|
httpClient *http.Client
|
|
pollInterval time.Duration
|
|
pollTimeout time.Duration
|
|
|
|
// Hooks for tests; production paths use the package-level defaults.
|
|
randSeed func() int64
|
|
initialBackoff time.Duration
|
|
|
|
// Sink is where successful generations are recorded for cost-tracking.
|
|
// nil means "do not record". The framework wires this up in the CLI;
|
|
// adapter tests inject a fake.
|
|
Sink UsageSink
|
|
}
|
|
|
|
// UsageSink writes one row per successful generation. Implementations
|
|
// should treat write failures as warnings, not errors — the image has
|
|
// already landed on disk; failing the call would lose the artefact.
|
|
type UsageSink interface {
|
|
Record(ctx context.Context, row UsageRow) error
|
|
}
|
|
|
|
// UsageRow is one cost-tracking record destined for mai.imagen_usage.
// The prompt text is deliberately absent; only its sha256 digest is kept.
type UsageRow struct {
	// Backend is the adapter instance name that served the request.
	Backend string
	// Model is the model reference the instance is configured with.
	Model string
	// Seed is the seed used for the generation; nil when not recorded.
	Seed *int64
	// PromptHash is the hex sha256 digest of the prompt.
	PromptHash string
	// LatencyMs is the measured wall-clock duration in milliseconds.
	LatencyMs int
	// CostUSDEstimate is the estimated price in USD; nil when no rate is known.
	CostUSDEstimate *float64
	// Caller identifies who the usage is attributed to.
	Caller string
}
|
|
|
|
// NewReplicate is the registry constructor. cfg is the adapter's slice
|
|
// of imagen.yaml.
|
|
func NewReplicate(name string, cfg map[string]any) (Backend, error) {
|
|
if name == "" {
|
|
return nil, fmt.Errorf("replicate: empty instance name")
|
|
}
|
|
tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN")
|
|
model := getString(cfg, "model", "")
|
|
if model == "" {
|
|
return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name)
|
|
}
|
|
owner, modelName, version, err := parseModelRef(model)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("replicate[%s]: %w", name, err)
|
|
}
|
|
apiBase := strings.TrimRight(getString(cfg, "api_base", "https://api.replicate.com"), "/")
|
|
pollTimeout := timeoutForModel(modelName)
|
|
r := &Replicate{
|
|
instance: name,
|
|
apiBase: apiBase,
|
|
tokenEnv: tokenEnv,
|
|
apiToken: os.Getenv(tokenEnv),
|
|
model: model,
|
|
owner: owner,
|
|
name: modelName,
|
|
version: version,
|
|
defaultSteps: getInt(cfg, "default_steps", 0),
|
|
defaultAspect: getString(cfg, "default_aspect_ratio", "1:1"),
|
|
httpClient: &http.Client{Timeout: 30 * time.Second},
|
|
pollInterval: 500 * time.Millisecond,
|
|
pollTimeout: pollTimeout,
|
|
randSeed: cryptoSeed,
|
|
initialBackoff: time.Second,
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
// Name returns the instance name from imagen.yaml.
|
|
func (r *Replicate) Name() string { return r.instance }
|
|
|
|
// Generate submits one prediction and returns the resulting PNG.
|
|
func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) {
|
|
if r.apiToken == "" {
|
|
return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv)
|
|
}
|
|
|
|
width := orDefaultInt(req.Width, 1024)
|
|
height := orDefaultInt(req.Height, 1024)
|
|
aspect := computeAspectRatio(width, height, r.defaultAspect)
|
|
|
|
steps := orDefaultInt(req.Steps, r.defaultSteps)
|
|
|
|
seed := req.Seed
|
|
if seed == 0 {
|
|
seed = r.randSeed()
|
|
}
|
|
|
|
input := map[string]any{
|
|
"prompt": req.Prompt,
|
|
"aspect_ratio": aspect,
|
|
"num_outputs": 1,
|
|
"output_format": "png",
|
|
"seed": seed,
|
|
}
|
|
if steps > 0 {
|
|
input["num_inference_steps"] = steps
|
|
}
|
|
if req.NegativePrompt != "" {
|
|
input["negative_prompt"] = req.NegativePrompt
|
|
}
|
|
maps.Copy(input, req.BackendOpts)
|
|
|
|
start := time.Now()
|
|
pred, err := r.submitPrediction(ctx, input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
final, err := r.waitForCompletion(ctx, pred.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
|
|
}
|
|
imgURL, err := pickFirstOutputURL(final.Output)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
|
|
}
|
|
imgBytes, mime, err := r.fetchImage(ctx, imgURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
latencyMs := int(time.Since(start).Milliseconds())
|
|
|
|
predictTime := final.Metrics.PredictTime
|
|
costEst, costKnown := replicatePerImageUSD(r.model)
|
|
|
|
meta := map[string]any{
|
|
"backend": r.instance,
|
|
"backend_type": ReplicateType,
|
|
"model": r.model,
|
|
"model_version": final.Version,
|
|
"prediction_id": pred.ID,
|
|
"seed": seed,
|
|
"width": width,
|
|
"height": height,
|
|
"aspect_ratio": aspect,
|
|
"latency_ms": int64(latencyMs),
|
|
}
|
|
if predictTime > 0 {
|
|
meta["predict_time_seconds"] = predictTime
|
|
}
|
|
if costKnown {
|
|
meta["cost_usd_estimate"] = costEst
|
|
}
|
|
|
|
if r.Sink != nil {
|
|
row := UsageRow{
|
|
Backend: r.instance,
|
|
Model: r.model,
|
|
PromptHash: hashPrompt(req.Prompt),
|
|
LatencyMs: latencyMs,
|
|
Caller: ResolveCaller(),
|
|
}
|
|
row.Seed = new(int64)
|
|
*row.Seed = seed
|
|
if costKnown {
|
|
c := costEst
|
|
row.CostUSDEstimate = &c
|
|
}
|
|
if err := r.Sink.Record(ctx, row); err != nil {
|
|
fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err)
|
|
}
|
|
}
|
|
|
|
return &Result{
|
|
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
|
|
MimeType: mime,
|
|
Metadata: meta,
|
|
}, nil
|
|
}
|
|
|
|
// replicatePrediction mirrors the JSON document returned by the prediction
// endpoints (/v1/predictions and /v1/models/{owner}/{name}/predictions).
// Only the fields this adapter reads are declared; everything else in the
// response is ignored when decoding.
type replicatePrediction struct {
	// ID identifies the prediction and is used to build the poll URL.
	ID string `json:"id"`
	// Status is the lifecycle state, e.g. "starting" or "succeeded".
	Status string `json:"status"`
	// Version is the model version reported by the API.
	Version string `json:"version"`
	// Error is kept raw and only used to build an error message.
	Error json.RawMessage `json:"error"`
	// Output is kept raw because it may be a string or an array of strings.
	Output json.RawMessage `json:"output"`
	// Metrics carries timing information reported by the API.
	Metrics struct {
		PredictTime float64 `json:"predict_time"`
	} `json:"metrics"`
}
|
|
|
|
// submitPrediction creates a prediction and returns it. Picks the
|
|
// model-based endpoint when no version was given (recommended for
|
|
// Replicate's official models), and the legacy /v1/predictions otherwise.
|
|
func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) {
|
|
var (
|
|
reqURL string
|
|
body []byte
|
|
err error
|
|
)
|
|
if r.version == "" {
|
|
reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions", r.apiBase, r.owner, r.name)
|
|
body, err = json.Marshal(map[string]any{"input": input})
|
|
} else {
|
|
reqURL = r.apiBase + "/v1/predictions"
|
|
body, err = json.Marshal(map[string]any{
|
|
"version": r.version,
|
|
"input": input,
|
|
})
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("replicate: marshal prediction body: %w", err)
|
|
}
|
|
|
|
respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var pred replicatePrediction
|
|
if err := json.Unmarshal(respBody, &pred); err != nil {
|
|
return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody))
|
|
}
|
|
if pred.ID == "" {
|
|
return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody))
|
|
}
|
|
return &pred, nil
|
|
}
|
|
|
|
// waitForCompletion polls /v1/predictions/{id} until the status is a
|
|
// terminal value (succeeded, failed, canceled) or the timeout fires.
|
|
func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) {
|
|
deadline := time.Now().Add(r.pollTimeout)
|
|
getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id)
|
|
start := time.Now()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
default:
|
|
}
|
|
if time.Now().After(deadline) {
|
|
return nil, fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond))
|
|
}
|
|
body, err := r.doWithRetry(ctx, http.MethodGet, getURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var pred replicatePrediction
|
|
if err := json.Unmarshal(body, &pred); err != nil {
|
|
return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body))
|
|
}
|
|
switch pred.Status {
|
|
case "succeeded":
|
|
return &pred, nil
|
|
case "failed":
|
|
return nil, fmt.Errorf("status=failed: %s", snip(pred.Error))
|
|
case "canceled":
|
|
return nil, fmt.Errorf("status=canceled")
|
|
case "starting", "processing", "":
|
|
default:
|
|
return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body))
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-time.After(r.pollInterval):
|
|
}
|
|
}
|
|
}
|
|
|
|
// doWithRetry executes a Replicate API request and applies the resilience
|
|
// policy: 401 surfaces a clean message naming the env var; 429 retries
|
|
// with exponential backoff up to three times; 5xx retries once; other
|
|
// errors surface unchanged.
|
|
func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
|
|
const max429Retries = 3
|
|
backoff := r.initialBackoff
|
|
if backoff <= 0 {
|
|
backoff = time.Second
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; ; attempt++ {
|
|
req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Authorization", "Token "+r.apiToken)
|
|
if body != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
resp, err := r.httpClient.Do(req)
|
|
if err != nil {
|
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, err
|
|
}
|
|
if attempt >= 1 {
|
|
return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err)
|
|
}
|
|
lastErr = err
|
|
if !sleepCtx(ctx, backoff) {
|
|
return nil, ctx.Err()
|
|
}
|
|
continue
|
|
}
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
_ = resp.Body.Close()
|
|
|
|
switch {
|
|
case resp.StatusCode >= 200 && resp.StatusCode < 300:
|
|
return respBody, nil
|
|
case resp.StatusCode == http.StatusUnauthorized:
|
|
return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, resp.StatusCode, r.tokenEnv)
|
|
case resp.StatusCode == http.StatusTooManyRequests:
|
|
if attempt >= max429Retries {
|
|
return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody))
|
|
}
|
|
wait := backoffFor429(resp.Header.Get("Retry-After"), backoff)
|
|
lastErr = fmt.Errorf("429 (retry %d): %s", attempt+1, snip(respBody))
|
|
if !sleepCtx(ctx, wait) {
|
|
return nil, ctx.Err()
|
|
}
|
|
backoff *= 2
|
|
continue
|
|
case resp.StatusCode >= 500:
|
|
if attempt >= 1 {
|
|
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
|
|
}
|
|
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(respBody))
|
|
if !sleepCtx(ctx, backoff) {
|
|
return nil, ctx.Err()
|
|
}
|
|
continue
|
|
default:
|
|
_ = lastErr
|
|
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
|
|
}
|
|
}
|
|
}
|
|
|
|
// fetchImage downloads the rendered image from the Replicate-provided
|
|
// CDN URL. One retry on a generic network error.
|
|
func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) {
|
|
var lastErr error
|
|
for attempt := range 2 {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
resp, err := r.httpClient.Do(req)
|
|
if err != nil {
|
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, "", err
|
|
}
|
|
lastErr = err
|
|
continue
|
|
}
|
|
body, readErr := io.ReadAll(resp.Body)
|
|
_ = resp.Body.Close()
|
|
if readErr != nil {
|
|
lastErr = readErr
|
|
continue
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
if resp.StatusCode >= 500 && attempt == 0 {
|
|
lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body))
|
|
continue
|
|
}
|
|
return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body))
|
|
}
|
|
mime := resp.Header.Get("Content-Type")
|
|
if mime == "" {
|
|
mime = "image/png"
|
|
}
|
|
return body, mime, nil
|
|
}
|
|
return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr)
|
|
}
|
|
|
|
// parseModelRef splits a model reference of the form "owner/name" or
// "owner/name:version-hash" into its parts. version is empty when the
// reference is not pinned.
//
// The reference is rejected when the owner or name is empty, when the name
// itself contains a "/" (it is used as a single URL path segment), or when
// a ":" is present but nothing follows it. Only the first ":" separates
// the version; later ones are kept as part of the version.
func parseModelRef(ref string) (owner, name, version string, err error) {
	invalid := func() (string, string, string, error) {
		return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
	}

	path, ver, pinned := strings.Cut(ref, ":")
	if pinned && ver == "" {
		// "owner/name:" is a typo, not a request for the unpinned model.
		return invalid()
	}

	own, model, ok := strings.Cut(path, "/")
	if !ok || own == "" || model == "" || strings.Contains(model, "/") {
		return invalid()
	}
	return own, model, ver, nil
}
|
|
|
|
// timeoutForModel returns how long to wait for a prediction of the named
// model. The name is compared case-insensitively. "flux-schnell" gets the
// short timeout; every other model, including FLUX dev, gets the longer
// one to be safe.
func timeoutForModel(name string) time.Duration {
	const (
		schnellTimeout = 60 * time.Second
		defaultTimeout = 120 * time.Second
	)
	if strings.ToLower(name) == "flux-schnell" {
		return schnellTimeout
	}
	return defaultTimeout
}
|
|
|
|
// pickFirstOutputURL extracts the first image URL from the output field.
|
|
// Replicate returns either a single string or an array of strings for image
|
|
// models — we accept both.
|
|
func pickFirstOutputURL(raw json.RawMessage) (string, error) {
|
|
if len(raw) == 0 {
|
|
return "", fmt.Errorf("output is empty")
|
|
}
|
|
var s string
|
|
if err := json.Unmarshal(raw, &s); err == nil && s != "" {
|
|
return s, nil
|
|
}
|
|
var arr []string
|
|
if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 && arr[0] != "" {
|
|
return arr[0], nil
|
|
}
|
|
return "", fmt.Errorf("output is not a string or non-empty string array (got: %s)", snip(raw))
|
|
}
|
|
|
|
// computeAspectRatio reduces width:height to a Replicate-supported aspect
|
|
// ratio when the reduction lands on one of the canonical values; otherwise
|
|
// returns fallback.
|
|
func computeAspectRatio(w, h int, fallback string) string {
|
|
if w <= 0 || h <= 0 {
|
|
return fallback
|
|
}
|
|
g := gcd(w, h)
|
|
a, b := w/g, h/g
|
|
s := fmt.Sprintf("%d:%d", a, b)
|
|
if isReplicateAspectRatio(s) {
|
|
return s
|
|
}
|
|
return fallback
|
|
}
|
|
|
|
// isReplicateAspectRatio reports whether s is exactly one of the aspect
// ratio strings this adapter treats as supported by Replicate.
func isReplicateAspectRatio(s string) bool {
	supported := [...]string{
		"1:1", "16:9", "21:9", "3:2", "2:3", "4:5",
		"5:4", "3:4", "4:3", "9:16", "9:21",
	}
	for _, ratio := range supported {
		if s == ratio {
			return true
		}
	}
	return false
}
|
|
|
|
// gcd returns the greatest common divisor of a and b using Euclid's
// algorithm. The result is made non-negative; gcd(0, 0) is 0.
func gcd(a, b int) int {
	for ; b != 0; a, b = b, a%b {
	}
	if a >= 0 {
		return a
	}
	return -a
}
|
|
|
|
// hashPrompt returns the lowercase hex sha256 digest of p. Only this
// digest is written to the cost-tracking table, never the prompt itself.
func hashPrompt(p string) string {
	hasher := sha256.New()
	hasher.Write([]byte(p)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
// ResolveCaller returns the identity that a cost-tracking row is
// attributed to. Sources are tried in order, mirroring the maimcp identity
// logic:
//
//  1. the MAI_FROM_ID environment variable, trimmed;
//  2. the @mai-name option of the tmux pane named by TMUX_PANE, obtained
//     by running tmux;
//  3. the literal "unknown".
//
// A source that is unset, blank, or (for tmux) fails to run is skipped.
func ResolveCaller() string {
	const fallback = "unknown"

	if id := strings.TrimSpace(os.Getenv("MAI_FROM_ID")); id != "" {
		return id
	}

	pane := os.Getenv("TMUX_PANE")
	if pane == "" {
		return fallback
	}
	out, err := exec.Command("tmux", "display-message", "-p", "-t", pane, "#{@mai-name}").Output()
	if err != nil {
		return fallback
	}
	if paneName := strings.TrimSpace(string(out)); paneName != "" {
		return paneName
	}
	return fallback
}
|
|
|
|
// backoffFor429 chooses how long to wait before retrying after a 429.
//
// retryAfter is the raw Retry-After header value. Both forms defined for
// that header are understood: a non-negative whole number of seconds, and
// an HTTP-date (the wait is then the time remaining until that date). The
// resulting wait is capped at 30 seconds.
//
// fallback is returned when the header is empty, cannot be interpreted in
// either form (values such as "1m" or "-3" are not accepted), overflows,
// or yields a wait that is not positive.
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
	const maxWait = 30 * time.Second

	value := strings.TrimSpace(retryAfter)
	if value == "" {
		return fallback
	}

	var wait time.Duration
	isSeconds := strings.IndexFunc(value, func(c rune) bool { return c < '0' || c > '9' }) < 0
	if isSeconds {
		// Only digits reach ParseDuration, so unit suffixes, signs and
		// fractions in the header cannot change the meaning of the value.
		d, err := time.ParseDuration(value + "s")
		if err != nil {
			return fallback
		}
		wait = d
	} else {
		at, err := http.ParseTime(value)
		if err != nil {
			return fallback
		}
		wait = time.Until(at)
	}

	if wait <= 0 {
		return fallback
	}
	if wait > maxWait {
		return maxWait
	}
	return wait
}
|
|
|
|
// sleepCtx waits for d or until ctx is done, whichever comes first. It
// returns true when the full wait elapsed and false when ctx ended it
// early. A non-positive d returns true immediately without consulting ctx.
func sleepCtx(ctx context.Context, d time.Duration) bool {
	if d <= 0 {
		return true
	}
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}
|
|
|
|
// bytesReader wraps b in a reader suitable as an HTTP request body. A nil
// slice yields a nil io.Reader (an untyped nil, so the request is built
// without a body); a non-nil slice, even an empty one, yields a reader.
func bytesReader(b []byte) io.Reader {
	var body io.Reader
	if b != nil {
		body = bytes.NewReader(b)
	}
	return body
}
|
|
|
|
// shortPath returns u from its first "/v1/" onwards, dropping the scheme
// and host so error messages don't leak the API base. A URL without
// "/v1/" is returned unchanged.
func shortPath(u string) string {
	const marker = "/v1/"
	if _, tail, found := strings.Cut(u, marker); found {
		return marker + tail
	}
	return u
}
|
|
|
|
func init() {
|
|
Register(ReplicateType, NewReplicate)
|
|
}
|