diff --git a/cmd/imagen/backends.go b/cmd/imagen/backends.go index c8d01ad..97bd1d3 100644 --- a/cmd/imagen/backends.go +++ b/cmd/imagen/backends.go @@ -10,6 +10,27 @@ import ( "mgit.msbls.de/m/ImaGen/internal/config" ) +// instanceStatus checks adapter-specific preconditions (e.g. the +// Replicate API token env var being set) and returns a short +// user-facing status string. +func instanceStatus(spec config.BackendSpec) string { + if !backend.Default.Has(spec.Type) { + return fmt.Sprintf("type %q not compiled in", spec.Type) + } + switch spec.Type { + case backend.ReplicateType: + envName, _ := spec.Raw["api_token_env"].(string) + if envName == "" { + envName = "REPLICATE_API_TOKEN" + } + if os.Getenv(envName) == "" { + return fmt.Sprintf("not configured (set %s)", envName) + } + return "ok" + } + return "registered" +} + func runBackends(args []string) error { fs := flag.NewFlagSet("backends", flag.ContinueOnError) var configPath string @@ -27,10 +48,7 @@ func runBackends(args []string) error { fmt.Fprintln(tw, "INSTANCE\tTYPE\tSTATUS") if cfg != nil { for name, spec := range cfg.Backends { - status := "registered" - if !backend.Default.Has(spec.Type) { - status = fmt.Sprintf("type %q not compiled in", spec.Type) - } + status := instanceStatus(spec) marker := "" if name == cfg.DefaultBackend { marker = " (default)" diff --git a/cmd/imagen/generate.go b/cmd/imagen/generate.go index 19b596c..00ba642 100644 --- a/cmd/imagen/generate.go +++ b/cmd/imagen/generate.go @@ -13,6 +13,7 @@ import ( "mgit.msbls.de/m/ImaGen/internal/output" "mgit.msbls.de/m/ImaGen/internal/preview" "mgit.msbls.de/m/ImaGen/internal/prompt" + "mgit.msbls.de/m/ImaGen/internal/usage" ) func runGenerate(ctx context.Context, args []string) error { @@ -81,6 +82,7 @@ func runGenerate(ctx context.Context, args []string) error { if err != nil { return err } + attachUsageSink(be) finalPrompt, err := prompt.Apply(rawPrompt, style) if err != nil { @@ -219,6 +221,21 @@ func parseSize(s string) (int, int, 
error) { return w, h, nil } +// attachUsageSink wires a Supabase cost-tracking sink into the backend +// when it accepts one and the env is configured. Adapters that record +// usage expose a public Sink field of type backend.UsageSink. +func attachUsageSink(be backend.Backend) { + r, ok := be.(*backend.Replicate) + if !ok { + return + } + sink, ok := usage.NewSupabaseSinkFromEnv() + if !ok { + return + } + r.Sink = sink +} + func buildBackend(cfg *config.Config, name string) (backend.Backend, error) { if cfg != nil { spec, ok := cfg.Backends[name] diff --git a/cmd/imagen/main.go b/cmd/imagen/main.go index d62c794..e5d227a 100644 --- a/cmd/imagen/main.go +++ b/cmd/imagen/main.go @@ -14,7 +14,7 @@ import ( _ "mgit.msbls.de/m/ImaGen/internal/backend" ) -const usage = `imagen — model-agnostic image generation +const helpText = `imagen — model-agnostic image generation Usage: imagen generate [flags] generate one image @@ -22,6 +22,7 @@ Usage: imagen config init print a sample imagen.yaml on stdout imagen config validate validate the active config imagen serve [--addr :8080] (stub) start the HTTP server + imagen usage [--since DATE] show cost-tracking rows imagen version print version imagen help show this help @@ -33,7 +34,7 @@ var Version = "dev" func main() { if len(os.Args) < 2 { - fmt.Fprint(os.Stderr, usage) + fmt.Fprint(os.Stderr, helpText) os.Exit(2) } ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) @@ -50,12 +51,14 @@ func main() { err = runConfig(args) case "serve": err = runServe(args) + case "usage": + err = runUsage(ctx, args) case "version", "-v", "--version": fmt.Println(Version) case "help", "-h", "--help": - fmt.Print(usage) + fmt.Print(helpText) default: - fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], usage) + fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], helpText) os.Exit(2) } if err != nil { diff --git a/cmd/imagen/usage.go b/cmd/imagen/usage.go new file 
mode 100644 index 0000000..8dfa0c4 --- /dev/null +++ b/cmd/imagen/usage.go @@ -0,0 +1,189 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "sort" + "strings" + "text/tabwriter" + "time" + + "mgit.msbls.de/m/ImaGen/internal/usage" +) + +// runUsage handles `imagen usage [--since DATE]`. Reads mai.imagen_usage +// via Supabase REST and prints a tab-aligned table grouped by week + +// backend + model + caller, with totals at the bottom. +func runUsage(ctx context.Context, args []string) error { + fs := flag.NewFlagSet("usage", flag.ContinueOnError) + var ( + since string + raw bool + ) + fs.StringVar(&since, "since", "", "ISO date (YYYY-MM-DD) — only rows on/after this UTC date") + fs.BoolVar(&raw, "raw", false, "print one line per row instead of grouped") + fs.Usage = func() { + fmt.Fprintln(fs.Output(), "Usage: imagen usage [--since YYYY-MM-DD] [--raw]") + fs.PrintDefaults() + } + if err := fs.Parse(args); err != nil { + return err + } + + var sinceT time.Time + if since != "" { + t, err := time.Parse("2006-01-02", since) + if err != nil { + return userErr("--since must be YYYY-MM-DD: %v", err) + } + sinceT = t + } + + sink, ok := usage.NewSupabaseSinkFromEnv() + if !ok { + return userErr("SUPABASE_URL and SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) must be set to read mai.imagen_usage") + } + rows, err := sink.Query(ctx, sinceT) + if err != nil { + return err + } + + if raw { + printRawRows(rows) + return nil + } + printGroupedRows(rows) + return nil +} + +func printRawRows(rows []usage.Row) { + tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "TIME\tBACKEND\tMODEL\tCALLER\tLATENCY_MS\tCOST_USD") + var totalCost float64 + for _, r := range rows { + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n", + r.CreatedAt.Local().Format("2006-01-02 15:04"), + r.Backend, + r.Model, + derefString(r.Caller), + intOrDash(r.LatencyMs), + costOrDash(r.CostUSDEstimate), + ) + if r.CostUSDEstimate != nil { + totalCost += *r.CostUSDEstimate + } + } + 
fmt.Fprintf(tw, "\t\t\t\t%d rows\t%.4f USD\n", len(rows), totalCost) + _ = tw.Flush() +} + +type group struct { + week string + backend string + model string + caller string + count int + cost float64 + costSet bool +} + +type groupKey struct { + week, backend, model, caller string +} + +func printGroupedRows(rows []usage.Row) { + groups := map[groupKey]*group{} + for _, r := range rows { + caller := derefString(r.Caller) + k := groupKey{ + week: weekStart(r.CreatedAt).Format("2006-01-02"), + backend: r.Backend, + model: r.Model, + caller: caller, + } + g, ok := groups[k] + if !ok { + g = &group{week: k.week, backend: r.Backend, model: r.Model, caller: caller} + groups[k] = g + } + g.count++ + if r.CostUSDEstimate != nil { + g.cost += *r.CostUSDEstimate + g.costSet = true + } + } + + keys := make([]groupKey, 0, len(groups)) + for k := range groups { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + if keys[i].week != keys[j].week { + return keys[i].week > keys[j].week // newest first + } + if keys[i].backend != keys[j].backend { + return keys[i].backend < keys[j].backend + } + if keys[i].model != keys[j].model { + return keys[i].model < keys[j].model + } + return keys[i].caller < keys[j].caller + }) + + tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "WEEK_OF\tBACKEND\tMODEL\tCALLER\tCOUNT\tCOST_USD") + var totalCount int + var totalCost float64 + for _, k := range keys { + g := groups[k] + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%s\n", + g.week, g.backend, g.model, g.caller, g.count, costStr(g.cost, g.costSet), + ) + totalCount += g.count + totalCost += g.cost + } + fmt.Fprintf(tw, "\t\t\tTOTAL\t%d\t%.4f USD\n", totalCount, totalCost) + _ = tw.Flush() +} + +// weekStart returns the Monday of the week containing t (UTC). 
// weekStart returns midnight UTC on the Monday of the week that contains
// t. The input is converted to UTC first, so a local time that is already
// Monday may still belong to the previous UTC week.
func weekStart(t time.Time) time.Time {
	utc := t.UTC()
	// time.Weekday numbers Sunday as 0; shifting by 6 modulo 7 yields the
	// number of whole days elapsed since Monday (Mon=0 … Sun=6).
	sinceMonday := (int(utc.Weekday()) + 6) % 7
	year, month, day := utc.Date()
	// time.Date normalises a zero or negative day into the previous month.
	return time.Date(year, month, day-sinceMonday, 0, 0, 0, 0, time.UTC)
}
// ReplicateType is the type-name adapters register under for Replicate
// instances. The package init function registers NewReplicate under this
// name.
const ReplicateType = "replicate"

// Replicate is the Replicate API adapter. It speaks the public REST API
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
// to submit, then polls /v1/predictions/{id} until the prediction
// settles, then downloads the produced image.
//
// Concurrency: a single Replicate is safe to share across goroutines.
type Replicate struct {
	// instance is the instance name from imagen.yaml, returned by Name.
	instance string

	// apiBase is the API root with any trailing "/" trimmed.
	apiBase string
	// apiToken is read from the tokenEnv environment variable once, at
	// construction time; tokenEnv is kept so errors can name the variable.
	apiToken string
	tokenEnv string
	model    string // "owner/name" or "owner/name:version-hash"
	owner    string
	name     string
	version  string // optional; empty means "use the model-based predictions endpoint"
	// defaultSteps is used when the request leaves Steps unset; a value
	// of 0 means the steps input is not sent at all.
	defaultSteps int
	// defaultAspect is the aspect ratio sent when the requested size does
	// not reduce to one of the supported ratios.
	defaultAspect string

	httpClient *http.Client
	// pollInterval is the pause between status polls; pollTimeout bounds
	// the total time spent waiting for a prediction to settle.
	pollInterval time.Duration
	pollTimeout  time.Duration

	// Hooks for tests; production paths use the package-level defaults.
	randSeed       func() int64
	initialBackoff time.Duration

	// Sink is where successful generations are recorded for cost-tracking.
	// nil means "do not record". The framework wires this up in the CLI;
	// adapter tests inject a fake.
	Sink UsageSink
}

// UsageSink writes one row per successful generation. Implementations
// should treat write failures as warnings, not errors — the image has
// already landed on disk; failing the call would lose the artefact.
type UsageSink interface {
	// Record stores row. The adapter logs a returned error to stderr and
	// carries on; it never fails the generation because of it.
	Record(ctx context.Context, row UsageRow) error
}
Note the +// prompt itself is intentionally NOT included — only the sha256 hash. +type UsageRow struct { + Backend string + Model string + Seed *int64 + PromptHash string + LatencyMs int + CostUSDEstimate *float64 + Caller string +} + +// NewReplicate is the registry constructor. cfg is the adapter's slice +// of imagen.yaml. +func NewReplicate(name string, cfg map[string]any) (Backend, error) { + if name == "" { + return nil, fmt.Errorf("replicate: empty instance name") + } + tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN") + model := getString(cfg, "model", "") + if model == "" { + return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name) + } + owner, modelName, version, err := parseModelRef(model) + if err != nil { + return nil, fmt.Errorf("replicate[%s]: %w", name, err) + } + apiBase := strings.TrimRight(getString(cfg, "api_base", "https://api.replicate.com"), "/") + pollTimeout := timeoutForModel(modelName) + r := &Replicate{ + instance: name, + apiBase: apiBase, + tokenEnv: tokenEnv, + apiToken: os.Getenv(tokenEnv), + model: model, + owner: owner, + name: modelName, + version: version, + defaultSteps: getInt(cfg, "default_steps", 0), + defaultAspect: getString(cfg, "default_aspect_ratio", "1:1"), + httpClient: &http.Client{Timeout: 30 * time.Second}, + pollInterval: 500 * time.Millisecond, + pollTimeout: pollTimeout, + randSeed: cryptoSeed, + initialBackoff: time.Second, + } + return r, nil +} + +// Name returns the instance name from imagen.yaml. +func (r *Replicate) Name() string { return r.instance } + +// Generate submits one prediction and returns the resulting PNG. 
+func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) { + if r.apiToken == "" { + return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv) + } + + width := orDefaultInt(req.Width, 1024) + height := orDefaultInt(req.Height, 1024) + aspect := computeAspectRatio(width, height, r.defaultAspect) + + steps := orDefaultInt(req.Steps, r.defaultSteps) + + seed := req.Seed + if seed == 0 { + seed = r.randSeed() + } + + input := map[string]any{ + "prompt": req.Prompt, + "aspect_ratio": aspect, + "num_outputs": 1, + "output_format": "png", + "seed": seed, + } + if steps > 0 { + input["num_inference_steps"] = steps + } + if req.NegativePrompt != "" { + input["negative_prompt"] = req.NegativePrompt + } + maps.Copy(input, req.BackendOpts) + + start := time.Now() + pred, err := r.submitPrediction(ctx, input) + if err != nil { + return nil, err + } + final, err := r.waitForCompletion(ctx, pred.ID) + if err != nil { + return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err) + } + imgURL, err := pickFirstOutputURL(final.Output) + if err != nil { + return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err) + } + imgBytes, mime, err := r.fetchImage(ctx, imgURL) + if err != nil { + return nil, err + } + latencyMs := int(time.Since(start).Milliseconds()) + + predictTime := final.Metrics.PredictTime + costEst, costKnown := replicatePerImageUSD(r.model) + + meta := map[string]any{ + "backend": r.instance, + "backend_type": ReplicateType, + "model": r.model, + "model_version": final.Version, + "prediction_id": pred.ID, + "seed": seed, + "width": width, + "height": height, + "aspect_ratio": aspect, + "latency_ms": int64(latencyMs), + } + if predictTime > 0 { + meta["predict_time_seconds"] = predictTime + } + if costKnown { + meta["cost_usd_estimate"] = costEst + } + + if r.Sink != nil { + row := UsageRow{ + Backend: r.instance, + Model: r.model, + 
PromptHash: hashPrompt(req.Prompt), + LatencyMs: latencyMs, + Caller: ResolveCaller(), + } + row.Seed = new(int64) + *row.Seed = seed + if costKnown { + c := costEst + row.CostUSDEstimate = &c + } + if err := r.Sink.Record(ctx, row); err != nil { + fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err) + } + } + + return &Result{ + ImageReader: io.NopCloser(bytes.NewReader(imgBytes)), + MimeType: mime, + Metadata: meta, + }, nil +} + +// replicatePrediction is what the REST API returns under /v1/predictions +// and /v1/models/{owner}/{name}/predictions. Only the fields we use are +// declared. +type replicatePrediction struct { + ID string `json:"id"` + Status string `json:"status"` + Version string `json:"version"` + Error json.RawMessage `json:"error"` + Output json.RawMessage `json:"output"` + Metrics struct { + PredictTime float64 `json:"predict_time"` + } `json:"metrics"` +} + +// submitPrediction creates a prediction and returns it. Picks the +// model-based endpoint when no version was given (recommended for +// Replicate's official models), and the legacy /v1/predictions otherwise. 
+func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) { + var ( + reqURL string + body []byte + err error + ) + if r.version == "" { + reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions", r.apiBase, r.owner, r.name) + body, err = json.Marshal(map[string]any{"input": input}) + } else { + reqURL = r.apiBase + "/v1/predictions" + body, err = json.Marshal(map[string]any{ + "version": r.version, + "input": input, + }) + } + if err != nil { + return nil, fmt.Errorf("replicate: marshal prediction body: %w", err) + } + + respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body) + if err != nil { + return nil, err + } + var pred replicatePrediction + if err := json.Unmarshal(respBody, &pred); err != nil { + return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody)) + } + if pred.ID == "" { + return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody)) + } + return &pred, nil +} + +// waitForCompletion polls /v1/predictions/{id} until the status is a +// terminal value (succeeded, failed, canceled) or the timeout fires. 
+func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) { + deadline := time.Now().Add(r.pollTimeout) + getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id) + start := time.Now() + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if time.Now().After(deadline) { + return nil, fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond)) + } + body, err := r.doWithRetry(ctx, http.MethodGet, getURL, nil) + if err != nil { + return nil, err + } + var pred replicatePrediction + if err := json.Unmarshal(body, &pred); err != nil { + return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body)) + } + switch pred.Status { + case "succeeded": + return &pred, nil + case "failed": + return nil, fmt.Errorf("status=failed: %s", snip(pred.Error)) + case "canceled": + return nil, fmt.Errorf("status=canceled") + case "starting", "processing", "": + default: + return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body)) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(r.pollInterval): + } + } +} + +// doWithRetry executes a Replicate API request and applies the resilience +// policy: 401 surfaces a clean message naming the env var; 429 retries +// with exponential backoff up to three times; 5xx retries once; other +// errors surface unchanged. 
+func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) { + const max429Retries = 3 + backoff := r.initialBackoff + if backoff <= 0 { + backoff = time.Second + } + + var lastErr error + for attempt := 0; ; attempt++ { + req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Token "+r.apiToken) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := r.httpClient.Do(req) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + if attempt >= 1 { + return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err) + } + lastErr = err + if !sleepCtx(ctx, backoff) { + return nil, ctx.Err() + } + continue + } + respBody, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + switch { + case resp.StatusCode >= 200 && resp.StatusCode < 300: + return respBody, nil + case resp.StatusCode == http.StatusUnauthorized: + return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, resp.StatusCode, r.tokenEnv) + case resp.StatusCode == http.StatusTooManyRequests: + if attempt >= max429Retries { + return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody)) + } + wait := backoffFor429(resp.Header.Get("Retry-After"), backoff) + lastErr = fmt.Errorf("429 (retry %d): %s", attempt+1, snip(respBody)) + if !sleepCtx(ctx, wait) { + return nil, ctx.Err() + } + backoff *= 2 + continue + case resp.StatusCode >= 500: + if attempt >= 1 { + return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody)) + } + lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(respBody)) + if !sleepCtx(ctx, backoff) { + return nil, ctx.Err() + } + continue + default: + _ = 
lastErr + return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody)) + } + } +} + +// fetchImage downloads the rendered image from the Replicate-provided +// CDN URL. One retry on a generic network error. +func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) { + var lastErr error + for attempt := range 2 { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil) + if err != nil { + return nil, "", err + } + resp, err := r.httpClient.Do(req) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, "", err + } + lastErr = err + continue + } + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if readErr != nil { + lastErr = readErr + continue + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if resp.StatusCode >= 500 && attempt == 0 { + lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body)) + continue + } + return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body)) + } + mime := resp.Header.Get("Content-Type") + if mime == "" { + mime = "image/png" + } + return body, mime, nil + } + return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr) +} + +// parseModelRef accepts "owner/name" or "owner/name:version-hash". +func parseModelRef(ref string) (owner, name, version string, err error) { + rest := ref + if i := strings.IndexByte(rest, ':'); i >= 0 { + version = rest[i+1:] + rest = rest[:i] + } + parts := strings.SplitN(rest, "/", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref) + } + return parts[0], parts[1], version, nil +} + +// timeoutForModel picks the polling timeout. FLUX dev takes notably longer +// than schnell; everything else gets the dev timeout for safety. 
// computeAspectRatio maps a pixel size to one of the aspect-ratio strings
// Replicate accepts, or returns fallback when the size matches none of
// them or when either dimension is not positive.
//
// The size is reduced to lowest terms before matching. Replicate spells
// its ultrawide ratios unreduced ("21:9", "9:21"), so the reduced forms
// 7:3 and 3:7 are translated explicitly; without that step those two
// ratios could never be selected.
func computeAspectRatio(w, h int, fallback string) string {
	if w <= 0 || h <= 0 {
		return fallback
	}
	g := gcd(w, h)
	a, b := w/g, h/g
	switch {
	case a == 7 && b == 3:
		return "21:9"
	case a == 3 && b == 7:
		return "9:21"
	}
	if s := fmt.Sprintf("%d:%d", a, b); isReplicateAspectRatio(s) {
		return s
	}
	return fallback
}

// isReplicateAspectRatio reports whether s is one of the aspect-ratio
// strings this adapter sends to Replicate.
func isReplicateAspectRatio(s string) bool {
	switch s {
	case "1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21":
		return true
	}
	return false
}

// gcd returns the non-negative greatest common divisor of a and b, using
// Euclid's algorithm. gcd(0, 0) is 0.
func gcd(a, b int) int {
	for b != 0 {
		a, b = b, a%b
	}
	if a < 0 {
		return -a
	}
	return a
}
// backoffFor429 decides how long to wait before retrying a 429 response.
//
// retryAfter is the raw Retry-After header value. Both forms the HTTP
// specification allows are understood: a number of seconds ("5") and an
// HTTP-date ("Wed, 21 Oct 2015 07:28:00 GMT"). Surrounding whitespace is
// ignored. The result is capped at 30 seconds.
//
// fallback is returned when the header is empty, cannot be parsed, or
// does not describe a positive wait (zero, negative, or a past date).
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
	const maxRetryAfter = 30 * time.Second

	value := strings.TrimSpace(retryAfter)
	if value == "" {
		return fallback
	}

	var wait time.Duration
	if strings.Trim(value, "0123456789.") == "" {
		// Purely numeric: safe to read as seconds. Checking the characters
		// first matters because appending "s" to a value such as "10m"
		// would otherwise parse as 10 milliseconds.
		d, err := time.ParseDuration(value + "s")
		if err != nil {
			return fallback
		}
		wait = d
	} else {
		when, err := http.ParseTime(value)
		if err != nil {
			return fallback
		}
		wait = time.Until(when)
	}

	if wait <= 0 {
		return fallback
	}
	return min(wait, maxRetryAfter)
}
// replicatePerImageUSD looks up the flat per-image price, in US dollars,
// for a Replicate model reference. Any ":version" suffix and letter case
// are ignored. The second result is false for models that are not in the
// table, in which case the price is 0 and must not be used.
func replicatePerImageUSD(model string) (float64, bool) {
	var usd float64
	switch normalisePricingKey(model) {
	case "black-forest-labs/flux-schnell":
		usd = 0.003
	case "black-forest-labs/flux-dev":
		usd = 0.025
	case "black-forest-labs/flux-pro":
		usd = 0.055
	case "black-forest-labs/flux-1.1-pro":
		usd = 0.040
	default:
		return 0, false
	}
	return usd, true
}

// normalisePricingKey turns a model reference into its pricing-table key:
// everything from the first ":" onwards is dropped and the remainder is
// lowercased, so "Owner/Name:hash" becomes "owner/name".
func normalisePricingKey(model string) string {
	ownerAndName, _, _ := strings.Cut(model, ":")
	return strings.ToLower(ownerAndName)
}
If the request matches the + // model-based path, modelEndpointHits increments; for /v1/predictions, + // versionEndpointHits increments. + createStatus int + createBody []byte + createCalls atomic.Int32 + + // Auth-fail policy: if true, every request to the prediction endpoints + // returns 401 with a stock body. + auth401 bool + + // 429-then-OK policy: first N create calls return 429. + create429Until int32 + retryAfter string + + // Sequence of /predictions/{id} responses, walked in order; once + // exhausted, the last entry is returned indefinitely. + pollResponses []string + pollIdx atomic.Int32 + pollCalls atomic.Int32 + + // Image download policy. + imageStatus int + imageBody []byte + image5xxFirst int32 + imageCalls atomic.Int32 + + server *httptest.Server +} + +func (f *fakeReplicate) handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"): + f.handleCreate(w, r) + case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions": + f.handleCreate(w, r) + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"): + f.handlePoll(w, r) + case r.Method == http.MethodGet && r.URL.Path == "/img": + f.handleImage(w, r) + default: + f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path) + http.NotFound(w, r) + } + }) +} + +func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) { + n := f.createCalls.Add(1) + if f.auth401 { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"detail":"Invalid token"}`)) + return + } + if n <= f.create429Until { + if f.retryAfter != "" { + w.Header().Set("Retry-After", f.retryAfter) + } + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"detail":"too many requests"}`)) + return + } + w.WriteHeader(f.createStatus) + _, _ = w.Write(f.createBody) +} + 
// handlePoll serves the next scripted poll response. Every call bumps
// pollCalls; once the scripted responses are exhausted the last one is
// repeated, and a 500 is returned when none are configured at all.
func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) {
	f.pollCalls.Add(1)
	idx := int(f.pollIdx.Add(1)) - 1
	if idx >= len(f.pollResponses) {
		// Clamp to the final response so a terminal status keeps being served.
		idx = len(f.pollResponses) - 1
	}
	if idx < 0 {
		http.Error(w, "no poll response configured", http.StatusInternalServerError)
		return
	}
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte(f.pollResponses[idx]))
}

// handleImage serves the image download. The first image5xxFirst requests
// are answered with 502; later ones return imageBody as image/png with the
// configured imageStatus.
func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) {
	n := f.imageCalls.Add(1)
	if n <= f.image5xxFirst {
		w.WriteHeader(http.StatusBadGateway)
		_, _ = w.Write([]byte("upstream unavailable"))
		return
	}
	w.Header().Set("Content-Type", "image/png")
	w.WriteHeader(f.imageStatus)
	_, _ = w.Write(f.imageBody)
}

// start launches the httptest server and registers its shutdown with the
// test's cleanup.
func (f *fakeReplicate) start() {
	f.server = httptest.NewServer(f.handler())
	f.t.Cleanup(f.server.Close)
}

// imageURL returns the URL of the fake image endpoint on the running server.
func (f *fakeReplicate) imageURL() string { return f.server.URL + "/img" }

// newFakeReplicate builds and starts a fake server preloaded with a
// happy-path script: create returns "starting", polls go
// starting → processing → succeeded, and the output points at imageURL.
func newFakeReplicate(t *testing.T) *fakeReplicate {
	t.Helper()
	f := &fakeReplicate{
		t:            t,
		createStatus: http.StatusCreated,
		imageStatus:  http.StatusOK,
		imageBody:    mustPNG(t, 8, 8),
	}
	f.start()
	// Default happy-path responses now that the server URL is known.
	f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`)
	f.pollResponses = []string{
		`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
		`{"id":"pred-abc","status":"processing","version":"v1","output":null}`,
		fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()),
	}
	return f
}

// newReplicate constructs a Replicate adapter pointed at the fake server and
// overrides its token, poll/backoff timings and random seed so tests run
// quickly and deterministically.
func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate {
	t.Helper()
	be, err := NewReplicate("flux-test", map[string]any{
		"api_token_env": "TEST_REPLICATE_TOKEN",
		"model":         model,
		"api_base":      f.server.URL,
	})
	if err != nil {
		t.Fatalf("NewReplicate: %v", err)
	}
	r := be.(*Replicate)
	r.apiToken = "fake-token"
	r.pollInterval = time.Millisecond
	r.pollTimeout = 5 * time.Second
	r.initialBackoff = time.Millisecond
	r.randSeed = func() int64 { return 42 }
	return r
}

// TestReplicateConstructorRejectsBadInputs checks that NewReplicate errors on
// an empty instance name, a missing model and a model without a slash.
func TestReplicateConstructorRejectsBadInputs(t *testing.T) {
	if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil {
		t.Errorf("expected error for empty instance name")
	}
	if _, err := NewReplicate("x", map[string]any{}); err == nil {
		t.Errorf("expected error for missing model")
	}
	if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil {
		t.Errorf("expected error for malformed model")
	}
}

// TestReplicateMissingTokenSurfacesEnvName checks that Generate with an empty
// token fails with an error naming the configured env var.
func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) {
	f := newFakeReplicate(t)
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")
	r.apiToken = "" // simulate the env var being unset
	_, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err == nil {
		t.Fatal("expected error when token is missing")
	}
	if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") {
		t.Errorf("error should name the env var: %v", err)
	}
}

// TestReplicateHappyPathSchnellUsesModelEndpoint runs a full generate against
// the default script and asserts the image bytes, the result metadata, the
// number of polls and the single usage row handed to the sink.
func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) {
	f := newFakeReplicate(t)
	sink := &recordingSink{}
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")
	r.Sink = sink

	res, err := r.Generate(context.Background(), Request{
		Prompt: "a tiny dragon",
		Width:  1024, Height: 1024,
		Seed: 0,
	})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	defer res.ImageReader.Close()
	body, _ := io.ReadAll(res.ImageReader)
	if !bytes.Equal(body, f.imageBody) {
		t.Errorf("image body did not round-trip")
	}

	if mime := res.MimeType; mime != "image/png" {
		t.Errorf("mime = %q", mime)
	}
	if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" {
		t.Errorf("metadata model = %v", got)
	}
	if got, _ := res.Metadata["model_version"].(string); got != "v1" {
		t.Errorf("metadata model_version = %v", got)
	}
	if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 {
		t.Errorf("metadata predict_time_seconds = %v", got)
	}
	if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 {
		t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok)
	}
	if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" {
		t.Errorf("aspect_ratio = %q", got)
	}

	if got := f.pollCalls.Load(); got < 3 {
		t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got)
	}

	if len(sink.rows) != 1 {
		t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
	}
	row := sink.rows[0]
	if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" {
		t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model)
	}
	// The row must carry a hash of the prompt, never the prompt text itself.
	if row.PromptHash == "" || row.PromptHash == "a tiny dragon" {
		t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash)
	}
	if len(row.PromptHash) != 64 {
		t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash))
	}
	if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 {
		t.Errorf("sink row cost = %v", row.CostUSDEstimate)
	}
	if row.LatencyMs <= 0 {
		t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs)
	}
}

// TestReplicateVersionPinUsesPredictionsEndpoint checks that an
// "owner/name:version" model is created via POST /v1/predictions with the
// version in the request body. The fake from newFakeReplicate is used only
// for its image URL; create and poll go to a dedicated server.
func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) {
	f := newFakeReplicate(t)

	var sentBody []byte
	var sentPath string
	mu := sync.Mutex{}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
			mu.Lock()
			sentBody, _ = io.ReadAll(r.Body)
			sentPath = r.URL.Path
			mu.Unlock()
			w.WriteHeader(http.StatusCreated)
			_, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`))
		case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
			_, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL())
		default:
			f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
		}
	}))
	t.Cleanup(srv.Close)

	be, err := NewReplicate("flux-test", map[string]any{
		"model":    "black-forest-labs/flux-dev:abc123",
		"api_base": srv.URL,
	})
	if err != nil {
		t.Fatalf("NewReplicate: %v", err)
	}
	r := be.(*Replicate)
	r.apiToken = "fake"
	r.pollInterval = time.Millisecond

	res, err := r.Generate(context.Background(), Request{Prompt: "x"})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	res.ImageReader.Close()

	mu.Lock()
	defer mu.Unlock()
	if sentPath != "/v1/predictions" {
		t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath)
	}
	var body map[string]any
	if err := json.Unmarshal(sentBody, &body); err != nil {
		t.Fatalf("unmarshal sent body: %v", err)
	}
	if body["version"] != "abc123" {
		t.Errorf("expected version=abc123 in body, got %v", body["version"])
	}
}

// TestReplicate401SurfacesEnvHint checks that with the fake's auth401 flag
// set, Generate fails with an error mentioning both the env var and "401".
func TestReplicate401SurfacesEnvHint(t *testing.T) {
	f := newFakeReplicate(t)
	f.auth401 = true
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")

	_, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err == nil {
		t.Fatal("expected 401 to surface as error")
	}
	msg := err.Error()
	if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") {
		t.Errorf("error should name env var: %v", err)
	}
	if !strings.Contains(msg, "401") {
		t.Errorf("error should mention 401: %v", err)
	}
}

// TestReplicate429RetriesThenSucceeds checks that two 429 responses on create
// are retried and the third create call succeeds.
func TestReplicate429RetriesThenSucceeds(t *testing.T) {
	f := newFakeReplicate(t)
	f.create429Until = 2 // first two calls 429 then succeed
	f.retryAfter = ""    // force the adapter's exp-backoff path
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")
	r.pollInterval = time.Millisecond

	// NOTE(review): the original comment here said "Squash the backoff so the
	// test is fast", but this line only replaces the HTTP client. The short
	// backoff comes from initialBackoff, set in newReplicate. Confirm intent.
	r.httpClient = &http.Client{Timeout: 5 * time.Second}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	res, err := r.Generate(ctx, Request{Prompt: "p"})
	if err != nil {
		t.Fatalf("Generate after 429s: %v", err)
	}
	res.ImageReader.Close()
	if got := f.createCalls.Load(); got != 3 {
		t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got)
	}
}

// TestReplicate429GivesUpAfterMaxRetries checks that sustained 429s end in an
// error mentioning 429 after exactly four create attempts.
func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) {
	f := newFakeReplicate(t)
	f.create429Until = 99 // every call 429
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	_, err := r.Generate(ctx, Request{Prompt: "p"})
	if err == nil {
		t.Fatal("expected error after sustained 429s")
	}
	if !strings.Contains(err.Error(), "429") {
		t.Errorf("expected 429 in error: %v", err)
	}
	// max429Retries=3 → 1 initial + 3 retries = 4 total
	if got := f.createCalls.Load(); got != 4 {
		t.Errorf("expected 4 create calls (1+3 retries), got %d", got)
	}
}

// TestReplicateFailedPredictionSurfacesError checks that a "failed" poll
// status yields an error containing both "failed" and the API's message.
func TestReplicateFailedPredictionSurfacesError(t *testing.T) {
	f := newFakeReplicate(t)
	f.pollResponses = []string{
		`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
		`{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`,
	}
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")
	_, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err == nil {
		t.Fatal("expected error for failed prediction")
	}
	if !strings.Contains(err.Error(), "failed") {
		t.Errorf("error should mention failure: %v", err)
	}
	if !strings.Contains(err.Error(), "NSFW") {
		t.Errorf("error should include the API error message: %v", err)
	}
}

// TestReplicatePollTimeoutSurfacesPartialLatency checks that a prediction
// stuck in "processing" hits pollTimeout and that the error reports how long
// the adapter waited.
func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) {
	f := newFakeReplicate(t)
	// Always return processing → adapter times out.
	f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")
	r.pollInterval = 5 * time.Millisecond
	r.pollTimeout = 30 * time.Millisecond

	_, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err == nil {
		t.Fatal("expected timeout error")
	}
	if !strings.Contains(err.Error(), "did not complete") {
		t.Errorf("expected 'did not complete' in error, got %v", err)
	}
	if !strings.Contains(err.Error(), "waited") {
		t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err)
	}
}

// TestReplicateImageDownloadRetriesOnce5xx checks that a single 5xx on the
// image download is retried and the second fetch succeeds.
func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) {
	f := newFakeReplicate(t)
	f.image5xxFirst = 1 // first download 502, second OK
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")

	res, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err != nil {
		t.Fatalf("Generate (download retry): %v", err)
	}
	res.ImageReader.Close()
	if got := f.imageCalls.Load(); got != 2 {
		t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got)
	}
}

// TestReplicateImageDownload5xxGivesUpAfterRetry checks that persistent 5xx
// responses on the image download fail after exactly two fetches.
func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) {
	f := newFakeReplicate(t)
	f.image5xxFirst = 99 // every download fails
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")

	_, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err == nil {
		t.Fatal("expected error after sustained image-download 5xx")
	}
	if got := f.imageCalls.Load(); got != 2 {
		t.Errorf("expected 2 image fetches (no further retries), got %d", got)
	}
}

// TestReplicateContextCancelStopsPolling checks that a context deadline
// shorter than pollTimeout aborts polling with a deadline-exceeded error.
func TestReplicateContextCancelStopsPolling(t *testing.T) {
	f := newFakeReplicate(t)
	f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")
	r.pollInterval = 5 * time.Millisecond
	r.pollTimeout = 5 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
	defer cancel()
	_, err := r.Generate(ctx, Request{Prompt: "p"})
	if err == nil {
		t.Fatal("expected ctx error")
	}
	// Accept either a wrapped sentinel or an error that only carries its text.
	if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") {
		t.Errorf("expected deadline exceeded, got %v", err)
	}
}

// TestReplicateBackendOptsMergedIntoInput checks that Request.BackendOpts
// entries and the prompt appear in the "input" object of the create request.
func TestReplicateBackendOptsMergedIntoInput(t *testing.T) {
	var captured map[string]any
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
			body, _ := io.ReadAll(r.Body)
			var top struct {
				Input map[string]any `json:"input"`
			}
			_ = json.Unmarshal(body, &top)
			captured = top.Input
			w.WriteHeader(http.StatusCreated)
			_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
		}
	}))
	t.Cleanup(srv.Close)

	be, err := NewReplicate("flux-test", map[string]any{
		"model":    "black-forest-labs/flux-schnell",
		"api_base": srv.URL,
	})
	if err != nil {
		t.Fatalf("NewReplicate: %v", err)
	}
	r := be.(*Replicate)
	r.apiToken = "fake"
	r.pollInterval = time.Millisecond

	// We expect this to fail at "pickFirstOutputURL" (empty output) but the
	// captured input should still have been recorded.
	_, _ = r.Generate(context.Background(), Request{
		Prompt: "p",
		Width:  1024, Height: 1024,
		BackendOpts: map[string]any{
			"output_quality": 90,
			"go_fast":        true,
		},
	})
	if captured == nil {
		t.Fatal("create endpoint not hit")
	}
	// JSON numbers decode into float64 when the target is map[string]any.
	if captured["output_quality"] != float64(90) {
		t.Errorf("output_quality not threaded: %v", captured["output_quality"])
	}
	if captured["go_fast"] != true {
		t.Errorf("go_fast not threaded: %v", captured["go_fast"])
	}
	if captured["prompt"] != "p" {
		t.Errorf("prompt not threaded: %v", captured["prompt"])
	}
}

// TestReplicateOutputArrayAccepted checks that an "output" given as a JSON
// array of URLs is accepted.
func TestReplicateOutputArrayAccepted(t *testing.T) {
	f := newFakeReplicate(t)
	f.pollResponses = []string{
		fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()),
	}
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")
	res, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err != nil {
		t.Fatalf("Generate (output as array): %v", err)
	}
	res.ImageReader.Close()
}

// TestReplicateUnknownModelLeavesCostUnsetButGenerates checks that a model
// with no known price still generates, with no cost in the metadata and a nil
// cost in the usage row.
func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) {
	f := newFakeReplicate(t)
	r := newReplicate(t, f, "stability-ai/sdxl")
	sink := &recordingSink{}
	r.Sink = sink

	res, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err != nil {
		t.Fatalf("Generate (unknown model): %v", err)
	}
	res.ImageReader.Close()
	if _, present := res.Metadata["cost_usd_estimate"]; present {
		t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"])
	}
	if len(sink.rows) != 1 {
		t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
	}
	if sink.rows[0].CostUSDEstimate != nil {
		t.Errorf("expected nil cost in sink row for unknown model")
	}
}

// TestReplicateSinkFailureIsWarningNotError checks that an error returned by
// the usage sink does not make Generate fail.
func TestReplicateSinkFailureIsWarningNotError(t *testing.T) {
	f := newFakeReplicate(t)
	r := newReplicate(t, f, "black-forest-labs/flux-schnell")
	r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") })

	res, err := r.Generate(context.Background(), Request{Prompt: "p"})
	if err != nil {
		t.Fatalf("sink failure should not fail Generate: %v", err)
	}
	res.ImageReader.Close()
}

// TestReplicateDefaultStepsApplied checks that the "default_steps" config
// value is sent as num_inference_steps in the create request's input.
func TestReplicateDefaultStepsApplied(t *testing.T) {
	var captured map[string]any
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
			body, _ := io.ReadAll(r.Body)
			var top struct {
				Input map[string]any `json:"input"`
			}
			_ = json.Unmarshal(body, &top)
			captured = top.Input
			w.WriteHeader(http.StatusCreated)
			_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
		}
	}))
	t.Cleanup(srv.Close)

	be, err := NewReplicate("flux-test", map[string]any{
		"model":         "black-forest-labs/flux-dev",
		"api_base":      srv.URL,
		"default_steps": 28,
	})
	if err != nil {
		t.Fatalf("NewReplicate: %v", err)
	}
	r := be.(*Replicate)
	r.apiToken = "fake"
	r.pollInterval = time.Millisecond
	// The Generate error is ignored on purpose: the fake returns an empty
	// output, and only the captured create request matters here.
	_, _ = r.Generate(context.Background(), Request{Prompt: "p"})
	if captured == nil {
		t.Fatal("create endpoint not hit")
	}
	if captured["num_inference_steps"] != float64(28) {
		t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"])
	}
}

// TestComputeAspectRatio table-tests computeAspectRatio, including the
// fallback for an unrecognised ratio and for a zero dimension.
func TestComputeAspectRatio(t *testing.T) {
	cases := []struct {
		w, h     int
		fallback string
		want     string
	}{
		{1024, 1024, "1:1", "1:1"},
		{1920, 1080, "1:1", "16:9"},
		{2560, 1440, "1:1", "16:9"},
		{1024, 768, "1:1", "4:3"},
		{1024, 1280, "1:1", "4:5"},
		{1000, 1234, "1:1", "1:1"}, // weird ratio falls back
		{0, 1024, "1:1", "1:1"},
	}
	for _, c := range cases {
		got := computeAspectRatio(c.w, c.h, c.fallback)
		if got != c.want {
			t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want)
		}
	}
}

// TestParseModelRef checks parsing of "owner/name" and "owner/name:version"
// references and rejection of a reference without a slash.
func TestParseModelRef(t *testing.T) {
	owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell")
	if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" {
		t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err)
	}
	owner, name, ver, err = parseModelRef("owner/name:hash123")
	if err != nil || owner != "owner" || name != "name" || ver != "hash123" {
		t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err)
	}
	if _, _, _, err := parseModelRef("noslash"); err == nil {
		t.Errorf("expected error for malformed ref")
	}
}

// TestHashPromptStable checks that hashPrompt is deterministic, differs for
// different inputs and yields 64 characters.
func TestHashPromptStable(t *testing.T) {
	a := hashPrompt("hello")
	b := hashPrompt("hello")
	c := hashPrompt("hello!")
	if a != b {
		t.Errorf("hashPrompt should be deterministic")
	}
	if a == c {
		t.Errorf("different prompts should hash differently")
	}
	if len(a) != 64 {
		t.Errorf("sha256 hex should be 64 chars, got %d", len(a))
	}
}

// TestReplicatePricingKnownModels checks the per-image price lookup for known
// models, for a version-pinned reference and for an unknown model.
func TestReplicatePricingKnownModels(t *testing.T) {
	if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 {
		t.Errorf("schnell rate = %v (ok=%v)", v, ok)
	}
	if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 {
		t.Errorf("dev rate = %v (ok=%v)", v, ok)
	}
	if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 {
		t.Errorf("versioned ref should resolve to base price: %v %v", v, ok)
	}
	if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok {
		t.Errorf("unknown model should report ok=false")
	}
}

// TestReplicateTypeIsRegistered checks that the replicate backend type is
// present in the Default registry.
func TestReplicateTypeIsRegistered(t *testing.T) {
	if !Default.Has(ReplicateType) {
		t.Errorf("replicate type not registered in Default")
	}
}

// recordingSink captures rows for assertion.
+type recordingSink struct { + mu sync.Mutex + rows []UsageRow +} + +func (s *recordingSink) Record(_ context.Context, row UsageRow) error { + s.mu.Lock() + defer s.mu.Unlock() + s.rows = append(s.rows, row) + return nil +} + +type sinkFunc func(context.Context, UsageRow) error + +func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) } diff --git a/internal/config/config.go b/internal/config/config.go index 15c08f6..5c86149 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -131,11 +131,19 @@ backends: mock: type: mock + flux-schnell-replicate: + type: replicate + api_token_env: REPLICATE_API_TOKEN + model: black-forest-labs/flux-schnell + default_steps: 4 + default_aspect_ratio: "1:1" + flux-dev-replicate: type: replicate api_token_env: REPLICATE_API_TOKEN model: black-forest-labs/flux-dev default_steps: 28 + default_aspect_ratio: "1:1" dalle3: type: openai diff --git a/internal/usage/usage.go b/internal/usage/usage.go new file mode 100644 index 0000000..5f43030 --- /dev/null +++ b/internal/usage/usage.go @@ -0,0 +1,160 @@ +// Package usage records per-call cost-tracking rows for the imagen CLI +// to mai.imagen_usage on Supabase. The writer is best-effort by design — +// the calling adapter logs failures and proceeds, because the image +// itself has already landed on disk by the time we record. +package usage + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" + + "mgit.msbls.de/m/ImaGen/internal/backend" +) + +// Default REST schema is the mai schema where mai.imagen_usage lives. +const supabaseSchema = "mai" + +// SupabaseSink writes rows via PostgREST. It uses Accept-Profile/ +// Content-Profile headers to target the mai schema instead of public. +type SupabaseSink struct { + URL string // SUPABASE_URL — e.g. 
https://msup.msbls.de + APIKey string // SUPABASE_SERVICE_KEY + HTTP *http.Client +} + +// NewSupabaseSinkFromEnv reads SUPABASE_URL and SUPABASE_SERVICE_KEY +// (falling back to MAI_SUPABASE_KEY) and returns a sink ready to use. +// Returns nil + ok=false if the env vars are not configured — the CLI +// uses that to skip cost-tracking gracefully. +func NewSupabaseSinkFromEnv() (*SupabaseSink, bool) { + u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/") + if u == "" { + return nil, false + } + key := os.Getenv("SUPABASE_SERVICE_KEY") + if key == "" { + key = os.Getenv("MAI_SUPABASE_KEY") + } + if key == "" { + return nil, false + } + return &SupabaseSink{ + URL: u, + APIKey: key, + HTTP: &http.Client{Timeout: 10 * time.Second}, + }, true +} + +type supabaseRow struct { + Backend string `json:"backend"` + Model string `json:"model"` + Seed *int64 `json:"seed,omitempty"` + PromptHash string `json:"prompt_hash"` + LatencyMs int `json:"latency_ms"` + CostUSDEstimate *float64 `json:"cost_usd_estimate,omitempty"` + Caller string `json:"caller,omitempty"` +} + +// Record inserts one row into mai.imagen_usage. 
+func (s *SupabaseSink) Record(ctx context.Context, row backend.UsageRow) error { + body, err := json.Marshal(supabaseRow{ + Backend: row.Backend, + Model: row.Model, + Seed: row.Seed, + PromptHash: row.PromptHash, + LatencyMs: row.LatencyMs, + CostUSDEstimate: row.CostUSDEstimate, + Caller: row.Caller, + }) + if err != nil { + return fmt.Errorf("usage: marshal: %w", err) + } + + endpoint := s.URL + "/rest/v1/imagen_usage" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return err + } + req.Header.Set("apikey", s.APIKey) + req.Header.Set("Authorization", "Bearer "+s.APIKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Profile", supabaseSchema) + req.Header.Set("Content-Profile", supabaseSchema) + req.Header.Set("Prefer", "return=minimal") + + resp, err := s.HTTP.Do(req) + if err != nil { + return fmt.Errorf("usage: POST: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("usage: POST %d: %s", resp.StatusCode, snip(respBody)) + } + return nil +} + +// Row is the read-side row shape (only the fields the CLI needs). +type Row struct { + CreatedAt time.Time `json:"created_at"` + Backend string `json:"backend"` + Model string `json:"model"` + Seed *int64 `json:"seed"` + PromptHash string `json:"prompt_hash"` + LatencyMs *int `json:"latency_ms"` + CostUSDEstimate *float64 `json:"cost_usd_estimate"` + Caller *string `json:"caller"` +} + +// Query returns rows from mai.imagen_usage filtered by created_at >= since. +// Pass zero time to fetch the full table (capped server-side by PostgREST +// — we set a hard 5000-row limit here too). 
+func (s *SupabaseSink) Query(ctx context.Context, since time.Time) ([]Row, error) { + q := url.Values{} + q.Set("select", "created_at,backend,model,seed,prompt_hash,latency_ms,cost_usd_estimate,caller") + q.Set("order", "created_at.desc") + q.Set("limit", "5000") + if !since.IsZero() { + q.Set("created_at", "gte."+since.UTC().Format(time.RFC3339)) + } + endpoint := s.URL + "/rest/v1/imagen_usage?" + q.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("apikey", s.APIKey) + req.Header.Set("Authorization", "Bearer "+s.APIKey) + req.Header.Set("Accept-Profile", supabaseSchema) + + resp, err := s.HTTP.Do(req) + if err != nil { + return nil, fmt.Errorf("usage: GET: %w", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("usage: GET %d: %s", resp.StatusCode, snip(body)) + } + var rows []Row + if err := json.Unmarshal(body, &rows); err != nil { + return nil, fmt.Errorf("usage: parse rows: %w (body: %s)", err, snip(body)) + } + return rows, nil +} + +func snip(b []byte) string { + const max = 500 + s := strings.TrimSpace(string(b)) + if len(s) > max { + s = s[:max] + "..." + } + return s +}