Merge mai/hermes/issue-3-imagen-3: Replicate API backend + cost-tracking + usage CLI (#3)

This commit is contained in:
mAi
2026-05-08 17:32:09 +02:00
10 changed files with 1710 additions and 8 deletions

View File

@@ -10,6 +10,27 @@ import (
"mgit.msbls.de/m/ImaGen/internal/config" "mgit.msbls.de/m/ImaGen/internal/config"
) )
// instanceStatus checks adapter-specific preconditions (e.g. the
// Replicate API token env var being set) and returns a short
// user-facing status string.
func instanceStatus(spec config.BackendSpec) string {
if !backend.Default.Has(spec.Type) {
return fmt.Sprintf("type %q not compiled in", spec.Type)
}
switch spec.Type {
case backend.ReplicateType:
envName, _ := spec.Raw["api_token_env"].(string)
if envName == "" {
envName = "REPLICATE_API_TOKEN"
}
if os.Getenv(envName) == "" {
return fmt.Sprintf("not configured (set %s)", envName)
}
return "ok"
}
return "registered"
}
func runBackends(args []string) error { func runBackends(args []string) error {
fs := flag.NewFlagSet("backends", flag.ContinueOnError) fs := flag.NewFlagSet("backends", flag.ContinueOnError)
var configPath string var configPath string
@@ -27,10 +48,7 @@ func runBackends(args []string) error {
fmt.Fprintln(tw, "INSTANCE\tTYPE\tSTATUS") fmt.Fprintln(tw, "INSTANCE\tTYPE\tSTATUS")
if cfg != nil { if cfg != nil {
for name, spec := range cfg.Backends { for name, spec := range cfg.Backends {
status := "registered" status := instanceStatus(spec)
if !backend.Default.Has(spec.Type) {
status = fmt.Sprintf("type %q not compiled in", spec.Type)
}
marker := "" marker := ""
if name == cfg.DefaultBackend { if name == cfg.DefaultBackend {
marker = " (default)" marker = " (default)"

View File

@@ -13,6 +13,7 @@ import (
"mgit.msbls.de/m/ImaGen/internal/output" "mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/preview" "mgit.msbls.de/m/ImaGen/internal/preview"
"mgit.msbls.de/m/ImaGen/internal/prompt" "mgit.msbls.de/m/ImaGen/internal/prompt"
"mgit.msbls.de/m/ImaGen/internal/usage"
) )
func runGenerate(ctx context.Context, args []string) error { func runGenerate(ctx context.Context, args []string) error {
@@ -81,6 +82,7 @@ func runGenerate(ctx context.Context, args []string) error {
if err != nil { if err != nil {
return err return err
} }
attachUsageSink(be)
finalPrompt, err := prompt.Apply(rawPrompt, style) finalPrompt, err := prompt.Apply(rawPrompt, style)
if err != nil { if err != nil {
@@ -219,6 +221,21 @@ func parseSize(s string) (int, int, error) {
return w, h, nil return w, h, nil
} }
// attachUsageSink wires a Supabase cost-tracking sink into the backend
// when it accepts one and the env is configured. Adapters that record
// usage expose a public Sink field of type backend.UsageSink.
func attachUsageSink(be backend.Backend) {
r, ok := be.(*backend.Replicate)
if !ok {
return
}
sink, ok := usage.NewSupabaseSinkFromEnv()
if !ok {
return
}
r.Sink = sink
}
func buildBackend(cfg *config.Config, name string) (backend.Backend, error) { func buildBackend(cfg *config.Config, name string) (backend.Backend, error) {
if cfg != nil { if cfg != nil {
spec, ok := cfg.Backends[name] spec, ok := cfg.Backends[name]

View File

@@ -14,7 +14,7 @@ import (
_ "mgit.msbls.de/m/ImaGen/internal/backend" _ "mgit.msbls.de/m/ImaGen/internal/backend"
) )
const usage = `imagen — model-agnostic image generation const helpText = `imagen — model-agnostic image generation
Usage: Usage:
imagen generate <prompt> [flags] generate one image imagen generate <prompt> [flags] generate one image
@@ -22,6 +22,7 @@ Usage:
imagen config init print a sample imagen.yaml on stdout imagen config init print a sample imagen.yaml on stdout
imagen config validate validate the active config imagen config validate validate the active config
imagen serve [--addr :8080] (stub) start the HTTP server imagen serve [--addr :8080] (stub) start the HTTP server
imagen usage [--since DATE] show cost-tracking rows
imagen version print version imagen version print version
imagen help show this help imagen help show this help
@@ -33,7 +34,7 @@ var Version = "dev"
func main() { func main() {
if len(os.Args) < 2 { if len(os.Args) < 2 {
fmt.Fprint(os.Stderr, usage) fmt.Fprint(os.Stderr, helpText)
os.Exit(2) os.Exit(2)
} }
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
@@ -50,12 +51,14 @@ func main() {
err = runConfig(args) err = runConfig(args)
case "serve": case "serve":
err = runServe(args) err = runServe(args)
case "usage":
err = runUsage(ctx, args)
case "version", "-v", "--version": case "version", "-v", "--version":
fmt.Println(Version) fmt.Println(Version)
case "help", "-h", "--help": case "help", "-h", "--help":
fmt.Print(usage) fmt.Print(helpText)
default: default:
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], usage) fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], helpText)
os.Exit(2) os.Exit(2)
} }
if err != nil { if err != nil {

189
cmd/imagen/usage.go Normal file
View File

@@ -0,0 +1,189 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"sort"
"strings"
"text/tabwriter"
"time"
"mgit.msbls.de/m/ImaGen/internal/usage"
)
// runUsage handles `imagen usage [--since DATE]`. Reads mai.imagen_usage
// via Supabase REST and prints a tab-aligned table grouped by week +
// backend + model + caller, with totals at the bottom.
func runUsage(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("usage", flag.ContinueOnError)
var (
since string
raw bool
)
fs.StringVar(&since, "since", "", "ISO date (YYYY-MM-DD) — only rows on/after this UTC date")
fs.BoolVar(&raw, "raw", false, "print one line per row instead of grouped")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), "Usage: imagen usage [--since YYYY-MM-DD] [--raw]")
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
return err
}
var sinceT time.Time
if since != "" {
t, err := time.Parse("2006-01-02", since)
if err != nil {
return userErr("--since must be YYYY-MM-DD: %v", err)
}
sinceT = t
}
sink, ok := usage.NewSupabaseSinkFromEnv()
if !ok {
return userErr("SUPABASE_URL and SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) must be set to read mai.imagen_usage")
}
rows, err := sink.Query(ctx, sinceT)
if err != nil {
return err
}
if raw {
printRawRows(rows)
return nil
}
printGroupedRows(rows)
return nil
}
func printRawRows(rows []usage.Row) {
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "TIME\tBACKEND\tMODEL\tCALLER\tLATENCY_MS\tCOST_USD")
var totalCost float64
for _, r := range rows {
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n",
r.CreatedAt.Local().Format("2006-01-02 15:04"),
r.Backend,
r.Model,
derefString(r.Caller),
intOrDash(r.LatencyMs),
costOrDash(r.CostUSDEstimate),
)
if r.CostUSDEstimate != nil {
totalCost += *r.CostUSDEstimate
}
}
fmt.Fprintf(tw, "\t\t\t\t%d rows\t%.4f USD\n", len(rows), totalCost)
_ = tw.Flush()
}
type group struct {
week string
backend string
model string
caller string
count int
cost float64
costSet bool
}
type groupKey struct {
week, backend, model, caller string
}
func printGroupedRows(rows []usage.Row) {
groups := map[groupKey]*group{}
for _, r := range rows {
caller := derefString(r.Caller)
k := groupKey{
week: weekStart(r.CreatedAt).Format("2006-01-02"),
backend: r.Backend,
model: r.Model,
caller: caller,
}
g, ok := groups[k]
if !ok {
g = &group{week: k.week, backend: r.Backend, model: r.Model, caller: caller}
groups[k] = g
}
g.count++
if r.CostUSDEstimate != nil {
g.cost += *r.CostUSDEstimate
g.costSet = true
}
}
keys := make([]groupKey, 0, len(groups))
for k := range groups {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
if keys[i].week != keys[j].week {
return keys[i].week > keys[j].week // newest first
}
if keys[i].backend != keys[j].backend {
return keys[i].backend < keys[j].backend
}
if keys[i].model != keys[j].model {
return keys[i].model < keys[j].model
}
return keys[i].caller < keys[j].caller
})
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "WEEK_OF\tBACKEND\tMODEL\tCALLER\tCOUNT\tCOST_USD")
var totalCount int
var totalCost float64
for _, k := range keys {
g := groups[k]
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%s\n",
g.week, g.backend, g.model, g.caller, g.count, costStr(g.cost, g.costSet),
)
totalCount += g.count
totalCost += g.cost
}
fmt.Fprintf(tw, "\t\t\tTOTAL\t%d\t%.4f USD\n", totalCount, totalCost)
_ = tw.Flush()
}
// weekStart returns the Monday of the week containing t (UTC).
func weekStart(t time.Time) time.Time {
t = t.UTC()
wd := int(t.Weekday())
if wd == 0 {
wd = 7 // shift Sunday to end-of-week
}
delta := time.Duration(wd-1) * -24 * time.Hour
d := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
return d.Add(delta)
}
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}
func intOrDash(p *int) string {
if p == nil {
return "-"
}
return fmt.Sprintf("%d", *p)
}
func costOrDash(p *float64) string {
if p == nil {
return "-"
}
return fmt.Sprintf("%.4f", *p)
}
func costStr(v float64, set bool) string {
if !set {
return "-"
}
return strings.TrimSpace(fmt.Sprintf("%.4f", v))
}

View File

@@ -91,3 +91,26 @@ API-backed adapters read tokens from env vars referenced by the config
export REPLICATE_API_TOKEN=... export REPLICATE_API_TOKEN=...
imagen generate "a cat" --backend flux-dev-replicate imagen generate "a cat" --backend flux-dev-replicate
``` ```
## Cost-tracking (Replicate)
Successful generations through the Replicate adapter write one row to
`mai.imagen_usage` on Supabase: backend, model, latency, per-image cost
estimate, prompt sha256 hash (never the prompt itself), and the caller
identity (resolved from `MAI_FROM_ID` or the tmux pane's `@mai-name`).
The writer is best-effort. If `SUPABASE_URL` / `SUPABASE_SERVICE_KEY` are
unset, or the database write fails, the image still lands and the CLI
prints a warning to stderr.
Inspect spend:
```sh
imagen usage # all rows, grouped by week + backend + model + caller
imagen usage --since 2026-05-01 # only rows on/after a UTC date
imagen usage --since 2026-05-01 --raw
```
Per-model rates live in `internal/backend/replicate_pricing.go` — they
are snapshotted from <https://replicate.com/pricing> and refreshed on a
quarterly cadence.

View File

@@ -0,0 +1,567 @@
package backend
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"time"
)
// ReplicateType is the type-name adapters register under for Replicate
// instances.
const ReplicateType = "replicate"
// Replicate is the Replicate API adapter. It speaks the public REST API
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
// to submit, then polls /v1/predictions/{id} until the prediction
// settles, then downloads the produced image.
//
// Concurrency: a single Replicate is safe to share across goroutines.
type Replicate struct {
instance string
apiBase string
apiToken string
tokenEnv string
model string // "owner/name" or "owner/name:version-hash"
owner string
name string
version string // optional; empty means "use the model-based predictions endpoint"
defaultSteps int
defaultAspect string
httpClient *http.Client
pollInterval time.Duration
pollTimeout time.Duration
// Hooks for tests; production paths use the package-level defaults.
randSeed func() int64
initialBackoff time.Duration
// Sink is where successful generations are recorded for cost-tracking.
// nil means "do not record". The framework wires this up in the CLI;
// adapter tests inject a fake.
Sink UsageSink
}
// UsageSink writes one row per successful generation. Implementations
// should treat write failures as warnings, not errors — the image has
// already landed on disk; failing the call would lose the artefact.
type UsageSink interface {
Record(ctx context.Context, row UsageRow) error
}
// UsageRow is the cost-tracking row stored in mai.imagen_usage. Note the
// prompt itself is intentionally NOT included — only the sha256 hash.
type UsageRow struct {
Backend string
Model string
Seed *int64
PromptHash string
LatencyMs int
CostUSDEstimate *float64
Caller string
}
// NewReplicate is the registry constructor. cfg is the adapter's slice
// of imagen.yaml.
func NewReplicate(name string, cfg map[string]any) (Backend, error) {
if name == "" {
return nil, fmt.Errorf("replicate: empty instance name")
}
tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN")
model := getString(cfg, "model", "")
if model == "" {
return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name)
}
owner, modelName, version, err := parseModelRef(model)
if err != nil {
return nil, fmt.Errorf("replicate[%s]: %w", name, err)
}
apiBase := strings.TrimRight(getString(cfg, "api_base", "https://api.replicate.com"), "/")
pollTimeout := timeoutForModel(modelName)
r := &Replicate{
instance: name,
apiBase: apiBase,
tokenEnv: tokenEnv,
apiToken: os.Getenv(tokenEnv),
model: model,
owner: owner,
name: modelName,
version: version,
defaultSteps: getInt(cfg, "default_steps", 0),
defaultAspect: getString(cfg, "default_aspect_ratio", "1:1"),
httpClient: &http.Client{Timeout: 30 * time.Second},
pollInterval: 500 * time.Millisecond,
pollTimeout: pollTimeout,
randSeed: cryptoSeed,
initialBackoff: time.Second,
}
return r, nil
}
// Name returns the instance name from imagen.yaml.
func (r *Replicate) Name() string { return r.instance }
// Generate submits one prediction and returns the resulting PNG.
func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) {
if r.apiToken == "" {
return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv)
}
width := orDefaultInt(req.Width, 1024)
height := orDefaultInt(req.Height, 1024)
aspect := computeAspectRatio(width, height, r.defaultAspect)
steps := orDefaultInt(req.Steps, r.defaultSteps)
seed := req.Seed
if seed == 0 {
seed = r.randSeed()
}
input := map[string]any{
"prompt": req.Prompt,
"aspect_ratio": aspect,
"num_outputs": 1,
"output_format": "png",
"seed": seed,
}
if steps > 0 {
input["num_inference_steps"] = steps
}
if req.NegativePrompt != "" {
input["negative_prompt"] = req.NegativePrompt
}
maps.Copy(input, req.BackendOpts)
start := time.Now()
pred, err := r.submitPrediction(ctx, input)
if err != nil {
return nil, err
}
final, err := r.waitForCompletion(ctx, pred.ID)
if err != nil {
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
}
imgURL, err := pickFirstOutputURL(final.Output)
if err != nil {
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
}
imgBytes, mime, err := r.fetchImage(ctx, imgURL)
if err != nil {
return nil, err
}
latencyMs := int(time.Since(start).Milliseconds())
predictTime := final.Metrics.PredictTime
costEst, costKnown := replicatePerImageUSD(r.model)
meta := map[string]any{
"backend": r.instance,
"backend_type": ReplicateType,
"model": r.model,
"model_version": final.Version,
"prediction_id": pred.ID,
"seed": seed,
"width": width,
"height": height,
"aspect_ratio": aspect,
"latency_ms": int64(latencyMs),
}
if predictTime > 0 {
meta["predict_time_seconds"] = predictTime
}
if costKnown {
meta["cost_usd_estimate"] = costEst
}
if r.Sink != nil {
row := UsageRow{
Backend: r.instance,
Model: r.model,
PromptHash: hashPrompt(req.Prompt),
LatencyMs: latencyMs,
Caller: ResolveCaller(),
}
row.Seed = new(int64)
*row.Seed = seed
if costKnown {
c := costEst
row.CostUSDEstimate = &c
}
if err := r.Sink.Record(ctx, row); err != nil {
fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err)
}
}
return &Result{
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
MimeType: mime,
Metadata: meta,
}, nil
}
// replicatePrediction is what the REST API returns under /v1/predictions
// and /v1/models/{owner}/{name}/predictions. Only the fields we use are
// declared.
type replicatePrediction struct {
ID string `json:"id"`
Status string `json:"status"`
Version string `json:"version"`
Error json.RawMessage `json:"error"`
Output json.RawMessage `json:"output"`
Metrics struct {
PredictTime float64 `json:"predict_time"`
} `json:"metrics"`
}
// submitPrediction creates a prediction and returns it. Picks the
// model-based endpoint when no version was given (recommended for
// Replicate's official models), and the legacy /v1/predictions otherwise.
func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) {
var (
reqURL string
body []byte
err error
)
if r.version == "" {
reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions", r.apiBase, r.owner, r.name)
body, err = json.Marshal(map[string]any{"input": input})
} else {
reqURL = r.apiBase + "/v1/predictions"
body, err = json.Marshal(map[string]any{
"version": r.version,
"input": input,
})
}
if err != nil {
return nil, fmt.Errorf("replicate: marshal prediction body: %w", err)
}
respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body)
if err != nil {
return nil, err
}
var pred replicatePrediction
if err := json.Unmarshal(respBody, &pred); err != nil {
return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody))
}
if pred.ID == "" {
return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody))
}
return &pred, nil
}
// waitForCompletion polls /v1/predictions/{id} until the status is a
// terminal value (succeeded, failed, canceled) or the timeout fires.
func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) {
deadline := time.Now().Add(r.pollTimeout)
getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id)
start := time.Now()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if time.Now().After(deadline) {
return nil, fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond))
}
body, err := r.doWithRetry(ctx, http.MethodGet, getURL, nil)
if err != nil {
return nil, err
}
var pred replicatePrediction
if err := json.Unmarshal(body, &pred); err != nil {
return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body))
}
switch pred.Status {
case "succeeded":
return &pred, nil
case "failed":
return nil, fmt.Errorf("status=failed: %s", snip(pred.Error))
case "canceled":
return nil, fmt.Errorf("status=canceled")
case "starting", "processing", "":
default:
return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(r.pollInterval):
}
}
}
// doWithRetry executes a Replicate API request and applies the resilience
// policy: 401 surfaces a clean message naming the env var; 429 retries
// with exponential backoff up to three times; 5xx retries once; other
// errors surface unchanged.
func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
const max429Retries = 3
backoff := r.initialBackoff
if backoff <= 0 {
backoff = time.Second
}
var lastErr error
for attempt := 0; ; attempt++ {
req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Token "+r.apiToken)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := r.httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
if attempt >= 1 {
return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err)
}
lastErr = err
if !sleepCtx(ctx, backoff) {
return nil, ctx.Err()
}
continue
}
respBody, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
switch {
case resp.StatusCode >= 200 && resp.StatusCode < 300:
return respBody, nil
case resp.StatusCode == http.StatusUnauthorized:
return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, resp.StatusCode, r.tokenEnv)
case resp.StatusCode == http.StatusTooManyRequests:
if attempt >= max429Retries {
return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody))
}
wait := backoffFor429(resp.Header.Get("Retry-After"), backoff)
lastErr = fmt.Errorf("429 (retry %d): %s", attempt+1, snip(respBody))
if !sleepCtx(ctx, wait) {
return nil, ctx.Err()
}
backoff *= 2
continue
case resp.StatusCode >= 500:
if attempt >= 1 {
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
}
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(respBody))
if !sleepCtx(ctx, backoff) {
return nil, ctx.Err()
}
continue
default:
_ = lastErr
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
}
}
}
// fetchImage downloads the rendered image from the Replicate-provided
// CDN URL. One retry on a generic network error.
func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) {
var lastErr error
for attempt := range 2 {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil)
if err != nil {
return nil, "", err
}
resp, err := r.httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, "", err
}
lastErr = err
continue
}
body, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
lastErr = readErr
continue
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if resp.StatusCode >= 500 && attempt == 0 {
lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body))
continue
}
return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body))
}
mime := resp.Header.Get("Content-Type")
if mime == "" {
mime = "image/png"
}
return body, mime, nil
}
return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr)
}
// parseModelRef accepts "owner/name" or "owner/name:version-hash".
func parseModelRef(ref string) (owner, name, version string, err error) {
rest := ref
if i := strings.IndexByte(rest, ':'); i >= 0 {
version = rest[i+1:]
rest = rest[:i]
}
parts := strings.SplitN(rest, "/", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
}
return parts[0], parts[1], version, nil
}
// timeoutForModel picks the polling timeout. FLUX dev takes notably longer
// than schnell; everything else gets the dev timeout for safety.
func timeoutForModel(name string) time.Duration {
switch strings.ToLower(name) {
case "flux-schnell":
return 60 * time.Second
default:
return 120 * time.Second
}
}
// pickFirstOutputURL extracts the first image URL from the output field.
// Replicate returns either a single string or an array of strings for image
// models — we accept both.
func pickFirstOutputURL(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", fmt.Errorf("output is empty")
}
var s string
if err := json.Unmarshal(raw, &s); err == nil && s != "" {
return s, nil
}
var arr []string
if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 && arr[0] != "" {
return arr[0], nil
}
return "", fmt.Errorf("output is not a string or non-empty string array (got: %s)", snip(raw))
}
// computeAspectRatio reduces width:height to a Replicate-supported aspect
// ratio when the reduction lands on one of the canonical values; otherwise
// returns fallback.
func computeAspectRatio(w, h int, fallback string) string {
if w <= 0 || h <= 0 {
return fallback
}
g := gcd(w, h)
a, b := w/g, h/g
s := fmt.Sprintf("%d:%d", a, b)
if isReplicateAspectRatio(s) {
return s
}
return fallback
}
func isReplicateAspectRatio(s string) bool {
switch s {
case "1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21":
return true
}
return false
}
func gcd(a, b int) int {
for b != 0 {
a, b = b, a%b
}
if a < 0 {
return -a
}
return a
}
// hashPrompt returns the sha256 hex digest of the prompt. The raw prompt
// is intentionally never written to the cost-tracking table.
func hashPrompt(p string) string {
sum := sha256.Sum256([]byte(p))
return hex.EncodeToString(sum[:])
}
// ResolveCaller returns the agent identity the cost-tracking row is
// attributed to. Order of resolution mirrors the maimcp identity logic:
// MAI_FROM_ID env var first, then the tmux pane's @mai-name option, then
// "unknown".
func ResolveCaller() string {
if v := strings.TrimSpace(os.Getenv("MAI_FROM_ID")); v != "" {
return v
}
if pane := os.Getenv("TMUX_PANE"); pane != "" {
out, err := exec.Command("tmux", "display-message", "-p", "-t", pane, "#{@mai-name}").Output()
if err == nil {
if name := strings.TrimSpace(string(out)); name != "" {
return name
}
}
}
return "unknown"
}
// backoffFor429 honours a Retry-After header (in seconds) when present
// and within reason, otherwise falls back to the caller's backoff.
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
if retryAfter == "" {
return fallback
}
d, err := time.ParseDuration(retryAfter + "s")
if err != nil || d <= 0 {
return fallback
}
if d > 30*time.Second {
return 30 * time.Second
}
return d
}
func sleepCtx(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(d):
return true
}
}
func bytesReader(b []byte) io.Reader {
if b == nil {
return nil
}
return bytes.NewReader(b)
}
// shortPath strips the host so error messages don't leak the API base.
func shortPath(u string) string {
if i := strings.Index(u, "/v1/"); i >= 0 {
return u[i:]
}
return u
}
func init() {
Register(ReplicateType, NewReplicate)
}

View File

@@ -0,0 +1,42 @@
package backend
import "strings"
// Replicate pricing snapshot.
//
// Source: https://replicate.com/pricing and the per-model "Run" tab on
// each model page. Replicate bills per second of GPU time, but the
// black-forest-labs FLUX models also publish a flat per-image price for
// the typical settings — that flat number is what we hardcode here.
//
// Snapshot date: 2026-05-08. TODO(refresh): re-check quarterly. If the
// rates drift more than ~10%, update the table and bump snapshotDate.
const replicatePricingSnapshotDate = "2026-05-08"
// replicatePerImageUSD is the per-image cost estimate keyed by Replicate
// model identifier ("owner/name", with any ":version" trimmed). Returns
// the rate and true if the model is known, 0 and false otherwise — an
// unknown model writes a row with NULL cost rather than a wrong number.
func replicatePerImageUSD(model string) (float64, bool) {
key := normalisePricingKey(model)
switch key {
case "black-forest-labs/flux-schnell":
return 0.003, true
case "black-forest-labs/flux-dev":
return 0.025, true
case "black-forest-labs/flux-pro":
return 0.055, true
case "black-forest-labs/flux-1.1-pro":
return 0.040, true
}
return 0, false
}
// normalisePricingKey strips the optional ":version" suffix and lowercases
// the owner/name pair. "Owner/Name:hash" → "owner/name".
func normalisePricingKey(model string) string {
if i := strings.IndexByte(model, ':'); i >= 0 {
model = model[:i]
}
return strings.ToLower(model)
}

View File

@@ -0,0 +1,675 @@
package backend
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// fakeReplicate is a programmable mock of the Replicate REST API.
type fakeReplicate struct {
t *testing.T
mu sync.Mutex
// Responses for the predictions submission. If the request matches the
// model-based path, modelEndpointHits increments; for /v1/predictions,
// versionEndpointHits increments.
createStatus int
createBody []byte
createCalls atomic.Int32
// Auth-fail policy: if true, every request to the prediction endpoints
// returns 401 with a stock body.
auth401 bool
// 429-then-OK policy: first N create calls return 429.
create429Until int32
retryAfter string
// Sequence of /predictions/{id} responses, walked in order; once
// exhausted, the last entry is returned indefinitely.
pollResponses []string
pollIdx atomic.Int32
pollCalls atomic.Int32
// Image download policy.
imageStatus int
imageBody []byte
image5xxFirst int32
imageCalls atomic.Int32
server *httptest.Server
}
func (f *fakeReplicate) handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"):
f.handleCreate(w, r)
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
f.handleCreate(w, r)
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
f.handlePoll(w, r)
case r.Method == http.MethodGet && r.URL.Path == "/img":
f.handleImage(w, r)
default:
f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path)
http.NotFound(w, r)
}
})
}
func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) {
n := f.createCalls.Add(1)
if f.auth401 {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"detail":"Invalid token"}`))
return
}
if n <= f.create429Until {
if f.retryAfter != "" {
w.Header().Set("Retry-After", f.retryAfter)
}
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"detail":"too many requests"}`))
return
}
w.WriteHeader(f.createStatus)
_, _ = w.Write(f.createBody)
}
func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) {
f.pollCalls.Add(1)
idx := int(f.pollIdx.Add(1)) - 1
if idx >= len(f.pollResponses) {
idx = len(f.pollResponses) - 1
}
if idx < 0 {
http.Error(w, "no poll response configured", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(f.pollResponses[idx]))
}
func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) {
n := f.imageCalls.Add(1)
if n <= f.image5xxFirst {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte("upstream unavailable"))
return
}
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(f.imageStatus)
_, _ = w.Write(f.imageBody)
}
func (f *fakeReplicate) start() {
f.server = httptest.NewServer(f.handler())
f.t.Cleanup(f.server.Close)
}
func (f *fakeReplicate) imageURL() string { return f.server.URL + "/img" }
func newFakeReplicate(t *testing.T) *fakeReplicate {
t.Helper()
f := &fakeReplicate{
t: t,
createStatus: http.StatusCreated,
imageStatus: http.StatusOK,
imageBody: mustPNG(t, 8, 8),
}
f.start()
// Default happy-path responses now that the server URL is known.
f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`)
f.pollResponses = []string{
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
`{"id":"pred-abc","status":"processing","version":"v1","output":null}`,
fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()),
}
return f
}
func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate {
t.Helper()
be, err := NewReplicate("flux-test", map[string]any{
"api_token_env": "TEST_REPLICATE_TOKEN",
"model": model,
"api_base": f.server.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake-token"
r.pollInterval = time.Millisecond
r.pollTimeout = 5 * time.Second
r.initialBackoff = time.Millisecond
r.randSeed = func() int64 { return 42 }
return r
}
func TestReplicateConstructorRejectsBadInputs(t *testing.T) {
if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil {
t.Errorf("expected error for empty instance name")
}
if _, err := NewReplicate("x", map[string]any{}); err == nil {
t.Errorf("expected error for missing model")
}
if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil {
t.Errorf("expected error for malformed model")
}
}
func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.apiToken = "" // simulate the env var being unset
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error when token is missing")
}
if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") {
t.Errorf("error should name the env var: %v", err)
}
}
func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) {
f := newFakeReplicate(t)
sink := &recordingSink{}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.Sink = sink
res, err := r.Generate(context.Background(), Request{
Prompt: "a tiny dragon",
Width: 1024, Height: 1024,
Seed: 0,
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
defer res.ImageReader.Close()
body, _ := io.ReadAll(res.ImageReader)
if !bytes.Equal(body, f.imageBody) {
t.Errorf("image body did not round-trip")
}
if mime := res.MimeType; mime != "image/png" {
t.Errorf("mime = %q", mime)
}
if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" {
t.Errorf("metadata model = %v", got)
}
if got, _ := res.Metadata["model_version"].(string); got != "v1" {
t.Errorf("metadata model_version = %v", got)
}
if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 {
t.Errorf("metadata predict_time_seconds = %v", got)
}
if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 {
t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok)
}
if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" {
t.Errorf("aspect_ratio = %q", got)
}
if got := f.pollCalls.Load(); got < 3 {
t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got)
}
if len(sink.rows) != 1 {
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
}
row := sink.rows[0]
if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" {
t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model)
}
if row.PromptHash == "" || row.PromptHash == "a tiny dragon" {
t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash)
}
if len(row.PromptHash) != 64 {
t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash))
}
if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 {
t.Errorf("sink row cost = %v", row.CostUSDEstimate)
}
if row.LatencyMs <= 0 {
t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs)
}
}
func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) {
f := newFakeReplicate(t)
var sentBody []byte
var sentPath string
mu := sync.Mutex{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
mu.Lock()
sentBody, _ = io.ReadAll(r.Body)
sentPath = r.URL.Path
mu.Unlock()
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`))
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
_, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL())
default:
f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-dev:abc123",
"api_base": srv.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
res, err := r.Generate(context.Background(), Request{Prompt: "x"})
if err != nil {
t.Fatalf("Generate: %v", err)
}
res.ImageReader.Close()
mu.Lock()
defer mu.Unlock()
if sentPath != "/v1/predictions" {
t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath)
}
var body map[string]any
if err := json.Unmarshal(sentBody, &body); err != nil {
t.Fatalf("unmarshal sent body: %v", err)
}
if body["version"] != "abc123" {
t.Errorf("expected version=abc123 in body, got %v", body["version"])
}
}
func TestReplicate401SurfacesEnvHint(t *testing.T) {
f := newFakeReplicate(t)
f.auth401 = true
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected 401 to surface as error")
}
msg := err.Error()
if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") {
t.Errorf("error should name env var: %v", err)
}
if !strings.Contains(msg, "401") {
t.Errorf("error should mention 401: %v", err)
}
}
func TestReplicate429RetriesThenSucceeds(t *testing.T) {
f := newFakeReplicate(t)
f.create429Until = 2 // first two calls 429 then succeed
f.retryAfter = "" // force the adapter's exp-backoff path
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = time.Millisecond
// Squash the backoff so the test is fast.
r.httpClient = &http.Client{Timeout: 5 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := r.Generate(ctx, Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate after 429s: %v", err)
}
res.ImageReader.Close()
if got := f.createCalls.Load(); got != 3 {
t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got)
}
}
func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) {
f := newFakeReplicate(t)
f.create429Until = 99 // every call 429
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
_, err := r.Generate(ctx, Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error after sustained 429s")
}
if !strings.Contains(err.Error(), "429") {
t.Errorf("expected 429 in error: %v", err)
}
// max429Retries=3 → 1 initial + 3 retries = 4 total
if got := f.createCalls.Load(); got != 4 {
t.Errorf("expected 4 create calls (1+3 retries), got %d", got)
}
}
func TestReplicateFailedPredictionSurfacesError(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
`{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`,
}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error for failed prediction")
}
if !strings.Contains(err.Error(), "failed") {
t.Errorf("error should mention failure: %v", err)
}
if !strings.Contains(err.Error(), "NSFW") {
t.Errorf("error should include the API error message: %v", err)
}
}
func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) {
f := newFakeReplicate(t)
// Always return processing → adapter times out.
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = 5 * time.Millisecond
r.pollTimeout = 30 * time.Millisecond
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected timeout error")
}
if !strings.Contains(err.Error(), "did not complete") {
t.Errorf("expected 'did not complete' in error, got %v", err)
}
if !strings.Contains(err.Error(), "waited") {
t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err)
}
}
func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) {
f := newFakeReplicate(t)
f.image5xxFirst = 1 // first download 502, second OK
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (download retry): %v", err)
}
res.ImageReader.Close()
if got := f.imageCalls.Load(); got != 2 {
t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got)
}
}
func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) {
f := newFakeReplicate(t)
f.image5xxFirst = 99 // every download fails
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error after sustained image-download 5xx")
}
if got := f.imageCalls.Load(); got != 2 {
t.Errorf("expected 2 image fetches (no further retries), got %d", got)
}
}
func TestReplicateContextCancelStopsPolling(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = 5 * time.Millisecond
r.pollTimeout = 5 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
_, err := r.Generate(ctx, Request{Prompt: "p"})
if err == nil {
t.Fatal("expected ctx error")
}
if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") {
t.Errorf("expected deadline exceeded, got %v", err)
}
}
func TestReplicateBackendOptsMergedIntoInput(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
body, _ := io.ReadAll(r.Body)
var top struct {
Input map[string]any `json:"input"`
}
_ = json.Unmarshal(body, &top)
captured = top.Input
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-schnell",
"api_base": srv.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
// We expect this to fail at "pickFirstOutputURL" (empty output) but the
// captured input should still have been recorded.
_, _ = r.Generate(context.Background(), Request{
Prompt: "p",
Width: 1024, Height: 1024,
BackendOpts: map[string]any{
"output_quality": 90,
"go_fast": true,
},
})
if captured == nil {
t.Fatal("create endpoint not hit")
}
if captured["output_quality"] != float64(90) {
t.Errorf("output_quality not threaded: %v", captured["output_quality"])
}
if captured["go_fast"] != true {
t.Errorf("go_fast not threaded: %v", captured["go_fast"])
}
if captured["prompt"] != "p" {
t.Errorf("prompt not threaded: %v", captured["prompt"])
}
}
func TestReplicateOutputArrayAccepted(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{
fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()),
}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (output as array): %v", err)
}
res.ImageReader.Close()
}
func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "stability-ai/sdxl")
sink := &recordingSink{}
r.Sink = sink
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (unknown model): %v", err)
}
res.ImageReader.Close()
if _, present := res.Metadata["cost_usd_estimate"]; present {
t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"])
}
if len(sink.rows) != 1 {
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
}
if sink.rows[0].CostUSDEstimate != nil {
t.Errorf("expected nil cost in sink row for unknown model")
}
}
func TestReplicateSinkFailureIsWarningNotError(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") })
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("sink failure should not fail Generate: %v", err)
}
res.ImageReader.Close()
}
func TestReplicateDefaultStepsApplied(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
body, _ := io.ReadAll(r.Body)
var top struct {
Input map[string]any `json:"input"`
}
_ = json.Unmarshal(body, &top)
captured = top.Input
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-dev",
"api_base": srv.URL,
"default_steps": 28,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
_, _ = r.Generate(context.Background(), Request{Prompt: "p"})
if captured == nil {
t.Fatal("create endpoint not hit")
}
if captured["num_inference_steps"] != float64(28) {
t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"])
}
}
func TestComputeAspectRatio(t *testing.T) {
cases := []struct {
w, h int
fallback string
want string
}{
{1024, 1024, "1:1", "1:1"},
{1920, 1080, "1:1", "16:9"},
{2560, 1440, "1:1", "16:9"},
{1024, 768, "1:1", "4:3"},
{1024, 1280, "1:1", "4:5"},
{1000, 1234, "1:1", "1:1"}, // weird ratio falls back
{0, 1024, "1:1", "1:1"},
}
for _, c := range cases {
got := computeAspectRatio(c.w, c.h, c.fallback)
if got != c.want {
t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want)
}
}
}
func TestParseModelRef(t *testing.T) {
owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell")
if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" {
t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err)
}
owner, name, ver, err = parseModelRef("owner/name:hash123")
if err != nil || owner != "owner" || name != "name" || ver != "hash123" {
t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err)
}
if _, _, _, err := parseModelRef("noslash"); err == nil {
t.Errorf("expected error for malformed ref")
}
}
func TestHashPromptStable(t *testing.T) {
a := hashPrompt("hello")
b := hashPrompt("hello")
c := hashPrompt("hello!")
if a != b {
t.Errorf("hashPrompt should be deterministic")
}
if a == c {
t.Errorf("different prompts should hash differently")
}
if len(a) != 64 {
t.Errorf("sha256 hex should be 64 chars, got %d", len(a))
}
}
func TestReplicatePricingKnownModels(t *testing.T) {
if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 {
t.Errorf("schnell rate = %v (ok=%v)", v, ok)
}
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 {
t.Errorf("dev rate = %v (ok=%v)", v, ok)
}
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 {
t.Errorf("versioned ref should resolve to base price: %v %v", v, ok)
}
if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok {
t.Errorf("unknown model should report ok=false")
}
}
func TestReplicateTypeIsRegistered(t *testing.T) {
if !Default.Has(ReplicateType) {
t.Errorf("replicate type not registered in Default")
}
}
// recordingSink captures rows for assertion.
type recordingSink struct {
mu sync.Mutex
rows []UsageRow
}
func (s *recordingSink) Record(_ context.Context, row UsageRow) error {
s.mu.Lock()
defer s.mu.Unlock()
s.rows = append(s.rows, row)
return nil
}
type sinkFunc func(context.Context, UsageRow) error
func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) }

View File

@@ -131,11 +131,19 @@ backends:
mock: mock:
type: mock type: mock
flux-schnell-replicate:
type: replicate
api_token_env: REPLICATE_API_TOKEN
model: black-forest-labs/flux-schnell
default_steps: 4
default_aspect_ratio: "1:1"
flux-dev-replicate: flux-dev-replicate:
type: replicate type: replicate
api_token_env: REPLICATE_API_TOKEN api_token_env: REPLICATE_API_TOKEN
model: black-forest-labs/flux-dev model: black-forest-labs/flux-dev
default_steps: 28 default_steps: 28
default_aspect_ratio: "1:1"
dalle3: dalle3:
type: openai type: openai

160
internal/usage/usage.go Normal file
View File

@@ -0,0 +1,160 @@
// Package usage records per-call cost-tracking rows for the imagen CLI
// to mai.imagen_usage on Supabase. The writer is best-effort by design —
// the calling adapter logs failures and proceeds, because the image
// itself has already landed on disk by the time we record.
package usage
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
"mgit.msbls.de/m/ImaGen/internal/backend"
)
// Default REST schema is the mai schema where mai.imagen_usage lives.
const supabaseSchema = "mai"
// SupabaseSink writes rows via PostgREST. It uses Accept-Profile/
// Content-Profile headers to target the mai schema instead of public.
type SupabaseSink struct {
URL string // SUPABASE_URL — e.g. https://msup.msbls.de
APIKey string // SUPABASE_SERVICE_KEY
HTTP *http.Client
}
// NewSupabaseSinkFromEnv reads SUPABASE_URL and SUPABASE_SERVICE_KEY
// (falling back to MAI_SUPABASE_KEY) and returns a sink ready to use.
// Returns nil + ok=false if the env vars are not configured — the CLI
// uses that to skip cost-tracking gracefully.
func NewSupabaseSinkFromEnv() (*SupabaseSink, bool) {
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
if u == "" {
return nil, false
}
key := os.Getenv("SUPABASE_SERVICE_KEY")
if key == "" {
key = os.Getenv("MAI_SUPABASE_KEY")
}
if key == "" {
return nil, false
}
return &SupabaseSink{
URL: u,
APIKey: key,
HTTP: &http.Client{Timeout: 10 * time.Second},
}, true
}
type supabaseRow struct {
Backend string `json:"backend"`
Model string `json:"model"`
Seed *int64 `json:"seed,omitempty"`
PromptHash string `json:"prompt_hash"`
LatencyMs int `json:"latency_ms"`
CostUSDEstimate *float64 `json:"cost_usd_estimate,omitempty"`
Caller string `json:"caller,omitempty"`
}
// Record inserts one row into mai.imagen_usage.
func (s *SupabaseSink) Record(ctx context.Context, row backend.UsageRow) error {
body, err := json.Marshal(supabaseRow{
Backend: row.Backend,
Model: row.Model,
Seed: row.Seed,
PromptHash: row.PromptHash,
LatencyMs: row.LatencyMs,
CostUSDEstimate: row.CostUSDEstimate,
Caller: row.Caller,
})
if err != nil {
return fmt.Errorf("usage: marshal: %w", err)
}
endpoint := s.URL + "/rest/v1/imagen_usage"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Profile", supabaseSchema)
req.Header.Set("Content-Profile", supabaseSchema)
req.Header.Set("Prefer", "return=minimal")
resp, err := s.HTTP.Do(req)
if err != nil {
return fmt.Errorf("usage: POST: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("usage: POST %d: %s", resp.StatusCode, snip(respBody))
}
return nil
}
// Row is the read-side row shape (only the fields the CLI needs).
type Row struct {
CreatedAt time.Time `json:"created_at"`
Backend string `json:"backend"`
Model string `json:"model"`
Seed *int64 `json:"seed"`
PromptHash string `json:"prompt_hash"`
LatencyMs *int `json:"latency_ms"`
CostUSDEstimate *float64 `json:"cost_usd_estimate"`
Caller *string `json:"caller"`
}
// Query returns rows from mai.imagen_usage filtered by created_at >= since.
// Pass zero time to fetch the full table (capped server-side by PostgREST
// — we set a hard 5000-row limit here too).
func (s *SupabaseSink) Query(ctx context.Context, since time.Time) ([]Row, error) {
q := url.Values{}
q.Set("select", "created_at,backend,model,seed,prompt_hash,latency_ms,cost_usd_estimate,caller")
q.Set("order", "created_at.desc")
q.Set("limit", "5000")
if !since.IsZero() {
q.Set("created_at", "gte."+since.UTC().Format(time.RFC3339))
}
endpoint := s.URL + "/rest/v1/imagen_usage?" + q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Accept-Profile", supabaseSchema)
resp, err := s.HTTP.Do(req)
if err != nil {
return nil, fmt.Errorf("usage: GET: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("usage: GET %d: %s", resp.StatusCode, snip(body))
}
var rows []Row
if err := json.Unmarshal(body, &rows); err != nil {
return nil, fmt.Errorf("usage: parse rows: %w (body: %s)", err, snip(body))
}
return rows, nil
}
func snip(b []byte) string {
const max = 500
s := strings.TrimSpace(string(b))
if len(s) > max {
s = s[:max] + "..."
}
return s
}