Files
ImaGen/internal/backend/replicate.go
mAi b282325663 mAi: #3 - Replicate adapter, mai.imagen_usage cost-tracking, usage CLI
Implements the Replicate API backend (FLUX schnell / FLUX dev) per ImaGen
issue #3:

- internal/backend/replicate.go — Backend adapter. Supports model
  refs as "owner/name" (uses /v1/models/{owner}/{name}/predictions) and
  "owner/name:hash" (uses /v1/predictions with explicit version). Polls
  /v1/predictions/{id} every 500ms with model-aware timeout (60s schnell,
  120s dev). Resilience: a 401 yields an error naming the api_token_env
  variable; 429 retries with exponential backoff up to 3 times (honouring
  Retry-After); 5xx retries once; the image download retries once on
  transient failure.
- internal/backend/replicate_pricing.go — hardcoded per-image USD rates
  for known FLUX models, snapshotted from replicate.com/pricing with a
  refresh TODO.
- internal/backend/replicate_test.go — mocked-HTTP unit tests covering
  happy path (model + version-pinned), 401, 429 retry policy, failed
  prediction, poll timeout, image-download retry, ctx cancel, BackendOpts
  passthrough, default_steps, aspect-ratio reduction, sha256 prompt hash.
- internal/usage/usage.go — Supabase REST sink + read-side query for
  mai.imagen_usage. Adapter writes are best-effort: a failed write logs a
  warning, and the generated image is still returned.
- cmd/imagen/usage.go — `imagen usage [--since DATE] [--raw]` reads
  the table and prints a tab-aligned grouped or raw table with totals.
- cmd/imagen/backends.go — instances of type=replicate now report
  "ok" or "not configured (set REPLICATE_API_TOKEN)" depending on env.
- internal/config/config.go — sample adds flux-schnell-replicate +
  flux-dev-replicate; default_backend stays flux-schnell-local.
- Supabase migration mai.imagen_usage (id, created_at, backend, model,
  seed, prompt_hash, latency_ms, cost_usd_estimate, caller) + indexes
  on (created_at DESC) and (caller). The raw prompt is never stored.

Caller identity resolves from MAI_FROM_ID, then the tmux pane's
@mai-name option, mirroring the maimcp identity logic. Prompt hash is
sha256 of the user-facing prompt; raw prompt never reaches the table.
2026-05-08 17:28:29 +02:00

568 lines
16 KiB
Go

package backend
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"time"
)
// ReplicateType is the adapter type name Replicate instances register
// under: init passes it to Register together with NewReplicate, and
// Generate reports it as the "backend_type" metadata value. Presumably it
// is matched against an instance's `type:` in imagen.yaml — confirm in the
// registry.
const ReplicateType = "replicate"
// Replicate is the Replicate API adapter. It speaks the public REST API
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
// to submit, then polls /v1/predictions/{id} until the prediction
// settles, then downloads the produced image.
//
// Values are built by NewReplicate. No method in this file writes any
// field after construction.
//
// Concurrency: a single Replicate is safe to share across goroutines,
// provided Sink is assigned before the value is shared (Generate reads it
// without synchronisation).
type Replicate struct {
// instance is the imagen.yaml instance name; used in error messages and as
// the usage row's Backend.
instance string
// apiBase is the API root with trailing slashes trimmed.
apiBase string
// apiToken is read once from the tokenEnv environment variable by
// NewReplicate; when empty, Generate fails before any request is made.
apiToken string
// tokenEnv names the environment variable holding the token; it is quoted
// in "token missing/invalid" error messages.
tokenEnv string
model string // "owner/name" or "owner/name:version-hash"
owner string
name string
version string // optional; empty means "use the model-based predictions endpoint"
// defaultSteps is passed to orDefaultInt with req.Steps (presumably used
// when req.Steps is zero); a resulting 0 means num_inference_steps is not
// sent.
defaultSteps int
// defaultAspect is the aspect_ratio sent when the request's width:height
// does not reduce to a ratio Replicate supports.
// NOTE(review): Generate defaults missing dimensions to 1024×1024
// (assuming orDefaultInt substitutes zero values), so a request without
// dimensions resolves to 1:1 and never reaches this fallback — confirm
// that is intended for the default_aspect_ratio setting.
defaultAspect string
// httpClient is used for API calls and the image download (30s
// per-request timeout as built by NewReplicate).
httpClient *http.Client
// pollInterval is the pause between status polls.
pollInterval time.Duration
// pollTimeout bounds how long waitForCompletion polls (see
// timeoutForModel).
pollTimeout time.Duration
// Hooks for tests; NewReplicate sets randSeed to cryptoSeed and
// initialBackoff to one second.
// randSeed supplies a seed when the request's seed is 0.
randSeed func() int64
// initialBackoff is doWithRetry's first retry delay; it doubles after each
// 429.
initialBackoff time.Duration
// Sink is where successful generations are recorded for cost-tracking.
// nil means "do not record". The framework wires this up in the CLI;
// adapter tests inject a fake.
Sink UsageSink
}
// UsageSink receives one cost-tracking row per successful generation.
// The Replicate adapter calls Record after the image has been downloaded
// but before Generate returns it, passing Generate's ctx. Implementations
// should treat write failures as warnings, not errors: failing the call
// would throw away an image that has already been produced. The adapter
// itself only logs a returned error to stderr and still returns the image.
type UsageSink interface {
// Record persists row. A non-nil error is reported but does not fail the
// generation.
Record(ctx context.Context, row UsageRow) error
}
// UsageRow is the cost-tracking row stored in mai.imagen_usage. Note the
// prompt itself is intentionally NOT included — only the sha256 hash.
type UsageRow struct {
// Backend is the imagen.yaml instance name (not the adapter type).
Backend string
// Model is the configured model reference, including any ":version".
Model string
// Seed is the seed Generate chose: the request's, or a random one when the
// request's was 0. The Replicate adapter always sets it; nil presumably
// means "unknown" for other writers.
Seed *int64
// PromptHash is the lowercase hex sha256 of the prompt.
PromptHash string
// LatencyMs is wall-clock milliseconds from submitting the prediction to
// finishing the image download.
LatencyMs int
// CostUSDEstimate is the per-image price estimate, or nil when the model
// has no known rate.
CostUSDEstimate *float64
// Caller is the agent identity the row is attributed to (ResolveCaller).
Caller string
}
// NewReplicate is the registry constructor for Replicate instances. name
// is the instance name from imagen.yaml and cfg is that instance's slice
// of the configuration.
//
// Keys read from cfg (defaults in parentheses):
//   - model (required): "owner/name" or "owner/name:version"
//   - api_token_env (REPLICATE_API_TOKEN): env var holding the API token
//   - api_base (https://api.replicate.com): trailing slashes are trimmed
//   - default_steps (0): fallback for the request's step count; a
//     resulting 0 means num_inference_steps is not sent
//   - default_aspect_ratio (1:1)
//
// The token is looked up in the environment here, once. An empty token is
// not a construction error; Generate reports it instead.
func NewReplicate(name string, cfg map[string]any) (Backend, error) {
	if name == "" {
		return nil, fmt.Errorf("replicate: empty instance name")
	}
	tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN")
	modelRef := getString(cfg, "model", "")
	if modelRef == "" {
		return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name)
	}
	owner, shortName, version, err := parseModelRef(modelRef)
	if err != nil {
		return nil, fmt.Errorf("replicate[%s]: %w", name, err)
	}
	base := getString(cfg, "api_base", "https://api.replicate.com")

	return &Replicate{
		instance:       name,
		apiBase:        strings.TrimRight(base, "/"),
		tokenEnv:       tokenEnv,
		apiToken:       os.Getenv(tokenEnv),
		model:          modelRef,
		owner:          owner,
		name:           shortName,
		version:        version,
		defaultSteps:   getInt(cfg, "default_steps", 0),
		defaultAspect:  getString(cfg, "default_aspect_ratio", "1:1"),
		httpClient:     &http.Client{Timeout: 30 * time.Second},
		pollInterval:   500 * time.Millisecond,
		pollTimeout:    timeoutForModel(shortName),
		randSeed:       cryptoSeed,
		initialBackoff: time.Second,
	}, nil
}
// Name reports the instance name this adapter was configured under in
// imagen.yaml.
func (r *Replicate) Name() string {
	return r.instance
}
// Generate runs one prediction end to end and returns the resulting image.
//
// Steps:
//  1. Build the prediction input.
//  2. Submit it and poll until Replicate reports a terminal status.
//  3. Download the first output URL.
//  4. Record a usage row when r.Sink is set.
//  5. Return the image bytes with descriptive metadata.
//
// Missing width/height/steps are filled via orDefaultInt. A zero req.Seed
// is replaced by r.randSeed(). req.BackendOpts is copied over the built
// input last, so its keys override prompt, aspect_ratio, seed and the rest.
// NOTE(review): a "seed" supplied through BackendOpts is not reflected in
// the metadata or usage row, which report the seed chosen here.
//
// latency_ms spans submission through the end of the image download.
func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) {
	if r.apiToken == "" {
		return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv)
	}

	w, h := orDefaultInt(req.Width, 1024), orDefaultInt(req.Height, 1024)
	ratio := computeAspectRatio(w, h, r.defaultAspect)
	nSteps := orDefaultInt(req.Steps, r.defaultSteps)
	seed := req.Seed
	if seed == 0 {
		seed = r.randSeed()
	}

	input := map[string]any{
		"prompt":        req.Prompt,
		"aspect_ratio":  ratio,
		"num_outputs":   1,
		"output_format": "png",
		"seed":          seed,
	}
	if nSteps > 0 {
		input["num_inference_steps"] = nSteps
	}
	if neg := req.NegativePrompt; neg != "" {
		input["negative_prompt"] = neg
	}
	// Caller-supplied backend options win over everything set above.
	maps.Copy(input, req.BackendOpts)

	began := time.Now()
	submitted, err := r.submitPrediction(ctx, input)
	if err != nil {
		return nil, err
	}
	id := submitted.ID

	done, err := r.waitForCompletion(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, id, err)
	}
	outURL, err := pickFirstOutputURL(done.Output)
	if err != nil {
		return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, id, err)
	}
	data, contentType, err := r.fetchImage(ctx, outURL)
	if err != nil {
		return nil, err
	}
	elapsedMs := int(time.Since(began).Milliseconds())

	cost, haveCost := replicatePerImageUSD(r.model)
	meta := map[string]any{
		"backend":       r.instance,
		"backend_type":  ReplicateType,
		"model":         r.model,
		"model_version": done.Version,
		"prediction_id": id,
		"seed":          seed,
		"width":         w,
		"height":        h,
		"aspect_ratio":  ratio,
		"latency_ms":    int64(elapsedMs),
	}
	if pt := done.Metrics.PredictTime; pt > 0 {
		meta["predict_time_seconds"] = pt
	}
	if haveCost {
		meta["cost_usd_estimate"] = cost
	}

	r.recordUsage(ctx, req.Prompt, seed, elapsedMs, cost, haveCost)

	return &Result{
		ImageReader: io.NopCloser(bytes.NewReader(data)),
		MimeType:    contentType,
		Metadata:    meta,
	}, nil
}

// recordUsage writes one cost-tracking row to r.Sink and does nothing when
// no sink is configured. Only the sha256 of the prompt is recorded. A
// write failure is printed to stderr and otherwise ignored, because the
// image has already been produced.
func (r *Replicate) recordUsage(ctx context.Context, prompt string, seed int64, latencyMs int, cost float64, costKnown bool) {
	if r.Sink == nil {
		return
	}
	row := UsageRow{
		Backend:    r.instance,
		Model:      r.model,
		Seed:       &seed,
		PromptHash: hashPrompt(prompt),
		LatencyMs:  latencyMs,
		Caller:     ResolveCaller(),
	}
	if costKnown {
		row.CostUSDEstimate = &cost
	}
	if err := r.Sink.Record(ctx, row); err != nil {
		fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err)
	}
}
// replicatePrediction is what the REST API returns under /v1/predictions
// and /v1/models/{owner}/{name}/predictions. Only the fields we use are
// declared.
type replicatePrediction struct {
// ID identifies the prediction for polling; submitPrediction rejects an
// empty one.
ID string `json:"id"`
// Status is the lifecycle state. waitForCompletion treats "succeeded",
// "failed" and "canceled" as terminal and keeps polling on "starting",
// "processing" or "".
Status string `json:"status"`
// Version is the model version that ran; Generate reports it as
// model_version.
Version string `json:"version"`
// Error is Replicate's error payload, kept raw and only quoted in messages.
Error json.RawMessage `json:"error"`
// Output holds the result; for image models a URL string or an array of
// URL strings (see pickFirstOutputURL).
Output json.RawMessage `json:"output"`
// Metrics.PredictTime is surfaced as predict_time_seconds when positive;
// assumed to be seconds per that key name — confirm against the API docs.
Metrics struct {
PredictTime float64 `json:"predict_time"`
} `json:"metrics"`
}
// submitPrediction creates a prediction from input and returns Replicate's
// initial view of it, guaranteed to carry a non-empty ID.
//
// Without a pinned version it posts {"input": ...} to the model-scoped
// endpoint /v1/models/{owner}/{name}/predictions, which is recommended for
// Replicate's official models. With a version it posts
// {"version": ..., "input": ...} to the legacy /v1/predictions.
//
// Owner and name come from configuration and are path-escaped, just as
// waitForCompletion escapes the prediction ID, so a reserved character
// such as '?' or '#' cannot rewrite the request URL.
//
// Transport, retry and HTTP-status handling are delegated to doWithRetry.
func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) {
	payload := map[string]any{"input": input}
	reqURL := r.apiBase + "/v1/predictions"
	if r.version == "" {
		reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions",
			r.apiBase, url.PathEscape(r.owner), url.PathEscape(r.name))
	} else {
		payload["version"] = r.version
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("replicate: marshal prediction body: %w", err)
	}
	respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body)
	if err != nil {
		return nil, err
	}

	var pred replicatePrediction
	if err := json.Unmarshal(respBody, &pred); err != nil {
		return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody))
	}
	if pred.ID == "" {
		return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody))
	}
	return &pred, nil
}
// waitForCompletion polls /v1/predictions/{id} every r.pollInterval until
// the prediction reaches a terminal status.
//
// Outcomes by status:
//   - "succeeded": the prediction is returned.
//   - "failed": an error carrying Replicate's error payload.
//   - "canceled": an error.
//   - "starting", "processing" or "": keep polling.
//   - anything else: an error.
//
// r.pollTimeout bounds the whole wait, including in-flight HTTP requests
// and doWithRetry's backoff sleeps, not just the gaps between polls.
// Expiry is reported as "did not complete within …". If the caller's ctx
// ends first, its ctx.Err() is returned instead.
func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) {
	start := time.Now()
	getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id)

	// pollCtx carries the poll deadline into every request and backoff, so
	// a slow or rate-limited API cannot stretch the wait past pollTimeout.
	pollCtx, cancel := context.WithTimeout(ctx, r.pollTimeout)
	defer cancel()

	// stopErr explains why pollCtx ended: the caller's own cancellation or
	// deadline takes precedence; otherwise it was our poll timeout.
	stopErr := func() error {
		if err := ctx.Err(); err != nil {
			return err
		}
		return fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond))
	}

	for {
		if pollCtx.Err() != nil {
			return nil, stopErr()
		}
		body, err := r.doWithRetry(pollCtx, http.MethodGet, getURL, nil)
		if err != nil {
			if pollCtx.Err() != nil {
				return nil, stopErr()
			}
			return nil, err
		}
		var pred replicatePrediction
		if err := json.Unmarshal(body, &pred); err != nil {
			return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body))
		}
		switch pred.Status {
		case "succeeded":
			return &pred, nil
		case "failed":
			return nil, fmt.Errorf("status=failed: %s", snip(pred.Error))
		case "canceled":
			return nil, fmt.Errorf("status=canceled")
		case "starting", "processing", "":
			// Still running; poll again after the interval.
		default:
			return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body))
		}

		timer := time.NewTimer(r.pollInterval)
		select {
		case <-pollCtx.Done():
			timer.Stop()
			return nil, stopErr()
		case <-timer.C:
		}
	}
}
// doWithRetry performs one Replicate API call under the adapter's
// resilience policy and returns the body of the first 2xx response.
//
// Policy by outcome:
//   - 401: fails immediately with a message naming the token env var.
//   - 429: retried up to three times. It waits the Retry-After value when
//     usable, otherwise the current backoff. The backoff starts at
//     r.initialBackoff (1s if unset) and doubles after each 429.
//   - 5xx, transport errors and unreadable bodies: retried once after one
//     backoff interval.
//   - Anything else (another 4xx, or an exhausted budget): returned with
//     the status and a body snippet.
//
// The 429 and transient budgets are counted separately, so one kind of
// failure does not consume the other's retries. Context cancellation
// aborts immediately, including during a backoff sleep.
//
// NOTE(review): POSTs are retried on transport errors and 5xx too. If
// Replicate had already accepted the first attempt, this may create a
// duplicate (billed) prediction.
func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
	const (
		max429Retries       = 3
		maxTransientRetries = 1
	)
	backoff := r.initialBackoff
	if backoff <= 0 {
		backoff = time.Second
	}
	retries429, transientRetries := 0, 0

	for {
		req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body))
		if err != nil {
			return nil, err
		}
		req.Header.Set("Authorization", "Token "+r.apiToken)
		if body != nil {
			req.Header.Set("Content-Type", "application/json")
		}

		resp, err := r.httpClient.Do(req)
		var respBody []byte
		if err == nil {
			// A failed body read is handled like a transport error below
			// rather than passing a truncated body to the caller.
			respBody, err = io.ReadAll(resp.Body)
			_ = resp.Body.Close()
		}
		if err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				return nil, err
			}
			if transientRetries >= maxTransientRetries {
				return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err)
			}
			transientRetries++
			if !sleepCtx(ctx, backoff) {
				return nil, ctx.Err()
			}
			continue
		}

		switch code := resp.StatusCode; {
		case code >= 200 && code < 300:
			return respBody, nil
		case code == http.StatusUnauthorized:
			return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, code, r.tokenEnv)
		case code == http.StatusTooManyRequests:
			if retries429 >= max429Retries {
				return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody))
			}
			retries429++
			if !sleepCtx(ctx, backoffFor429(resp.Header.Get("Retry-After"), backoff)) {
				return nil, ctx.Err()
			}
			backoff *= 2
		case code >= 500:
			if transientRetries >= maxTransientRetries {
				return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), code, snip(respBody))
			}
			transientRetries++
			if !sleepCtx(ctx, backoff) {
				return nil, ctx.Err()
			}
		default:
			return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), code, snip(respBody))
		}
	}
}
// fetchImage downloads the rendered image from the CDN URL Replicate
// returned and reports its MIME type. The API token is not sent to the CDN.
//
// Retry rules:
//   - A transport error or body-read error is retried once, immediately.
//   - A 5xx on the first attempt is retried once, immediately.
//   - Any other non-2xx response fails at once.
//   - Context cancellation during the request is returned unwrapped.
//
// When the CDN sends no Content-Type, the type is sniffed from the bytes.
// output_format can be overridden through BackendOpts, so PNG is not
// guaranteed. If sniffing does not identify an image type, image/png is
// assumed.
func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) {
	var lastErr error
	for attempt := range 2 {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil)
		if err != nil {
			return nil, "", err
		}
		resp, err := r.httpClient.Do(req)
		if err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				return nil, "", err
			}
			lastErr = err
			continue
		}
		body, readErr := io.ReadAll(resp.Body)
		_ = resp.Body.Close()
		if readErr != nil {
			lastErr = readErr
			continue
		}
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			if resp.StatusCode >= 500 && attempt == 0 {
				lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body))
				continue
			}
			return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body))
		}
		mime := resp.Header.Get("Content-Type")
		if mime == "" {
			mime = "image/png"
			if sniffed := http.DetectContentType(body); strings.HasPrefix(sniffed, "image/") {
				mime = sniffed
			}
		}
		return body, mime, nil
	}
	return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr)
}
// parseModelRef splits a Replicate model reference into its parts.
//
// Accepted forms are "owner/name" and "owner/name:version". The rules:
//   - owner and name must both be non-empty.
//   - name may not contain a further "/", because the predictions URL is
//     built from exactly two path segments.
//   - When a ":" is present, the version after it must be non-empty.
//
// version is "" for the unpinned form. Any other input yields an error.
func parseModelRef(ref string) (owner, name, version string, err error) {
	rest, version, pinned := strings.Cut(ref, ":")
	owner, name, ok := strings.Cut(rest, "/")
	if !ok || owner == "" || name == "" || strings.Contains(name, "/") || (pinned && version == "") {
		return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
	}
	return owner, name, version, nil
}
// timeoutForModel returns how long waitForCompletion may poll for a model,
// given the model's name without its owner. flux-schnell, matched
// case-insensitively, is fast and gets 60s. Every other model, including
// the slower flux-dev, gets 120s for headroom.
func timeoutForModel(name string) time.Duration {
	timeout := 120 * time.Second
	if strings.ToLower(name) == "flux-schnell" {
		timeout = 60 * time.Second
	}
	return timeout
}
// pickFirstOutputURL returns the image URL carried in a prediction's
// output field. Image models report either a bare string or an array of
// strings; for an array, the first element is used.
//
// These all yield an error:
//   - a missing field
//   - an empty string or empty array
//   - an array whose first element is ""
//   - any other JSON shape, including mixed-type arrays
func pickFirstOutputURL(raw json.RawMessage) (string, error) {
	if len(raw) == 0 {
		return "", fmt.Errorf("output is empty")
	}

	var single string
	if json.Unmarshal(raw, &single) == nil && single != "" {
		return single, nil
	}

	var many []string
	if json.Unmarshal(raw, &many) == nil && len(many) > 0 && many[0] != "" {
		return many[0], nil
	}

	return "", fmt.Errorf("output is not a string or non-empty string array (got: %s)", snip(raw))
}
// computeAspectRatio expresses w×h as a fully reduced "a:b" ratio and
// returns it when it is one of the ratios Replicate accepts. Non-positive
// dimensions, or a ratio outside that set, return fallback unchanged.
func computeAspectRatio(w, h int, fallback string) string {
	if w > 0 && h > 0 {
		d := gcd(w, h)
		if ratio := fmt.Sprintf("%d:%d", w/d, h/d); isReplicateAspectRatio(ratio) {
			return ratio
		}
	}
	return fallback
}

// isReplicateAspectRatio reports whether s is one of the aspect_ratio
// values this adapter will send to Replicate.
func isReplicateAspectRatio(s string) bool {
	for _, known := range [...]string{"1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21"} {
		if s == known {
			return true
		}
	}
	return false
}

// gcd returns the non-negative greatest common divisor of a and b using
// Euclid's algorithm; gcd(0, 0) is 0.
func gcd(a, b int) int {
	if b == 0 {
		if a < 0 {
			return -a
		}
		return a
	}
	return gcd(b, a%b)
}
// hashPrompt fingerprints a prompt as the lowercase hexadecimal SHA-256 of
// its bytes. Only this digest goes into the cost-tracking table; the raw
// prompt is intentionally never written there.
func hashPrompt(p string) string {
	h := sha256.New()
	_, _ = io.WriteString(h, p) // hash.Hash writes never return an error
	return hex.EncodeToString(h.Sum(nil))
}
// ResolveCaller returns the agent identity a cost-tracking row is
// attributed to. Resolution mirrors the maimcp identity logic:
//
//  1. the MAI_FROM_ID environment variable, whitespace-trimmed;
//  2. when TMUX_PANE is set, that pane's @mai-name option as printed by
//     `tmux display-message`, whitespace-trimmed;
//  3. otherwise "unknown".
//
// The tmux lookup runs as a subprocess and is capped at two seconds. It
// is called after the image has already been produced, so a wedged tmux
// server must not stall the generation. A timeout or any tmux failure
// falls through to "unknown".
func ResolveCaller() string {
	if v := strings.TrimSpace(os.Getenv("MAI_FROM_ID")); v != "" {
		return v
	}
	pane := os.Getenv("TMUX_PANE")
	if pane == "" {
		return "unknown"
	}

	const tmuxTimeout = 2 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), tmuxTimeout)
	defer cancel()
	out, err := exec.CommandContext(ctx, "tmux", "display-message", "-p", "-t", pane, "#{@mai-name}").Output()
	if err != nil {
		return "unknown"
	}
	if name := strings.TrimSpace(string(out)); name != "" {
		return name
	}
	return "unknown"
}
// backoffFor429 picks how long to wait before retrying a 429 response.
//
// A Retry-After header is honoured in either RFC 9110 form:
//   - delay-seconds, with fractional values such as "1.5" tolerated and
//     surrounding whitespace trimmed;
//   - an HTTP-date.
//
// Waits are capped at 30s. An absent, malformed, zero, negative or
// already-past value falls back to the caller's backoff. Values carrying
// duration units (e.g. "1m") are treated as malformed rather than parsed
// as Go durations.
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
	const maxWait = 30 * time.Second
	v := strings.TrimSpace(retryAfter)
	if v == "" {
		return fallback
	}

	var d time.Duration
	isSeconds := strings.IndexFunc(v, func(c rune) bool {
		return (c < '0' || c > '9') && c != '.'
	}) < 0
	if isSeconds {
		parsed, err := time.ParseDuration(v + "s")
		if err != nil {
			return fallback
		}
		d = parsed
	} else {
		at, err := http.ParseTime(v)
		if err != nil {
			return fallback
		}
		d = time.Until(at)
	}

	if d <= 0 {
		return fallback
	}
	return min(d, maxWait)
}
// sleepCtx pauses for d unless ctx ends first.
//
// It returns true when the full duration elapsed, and also when d is not
// positive: in that case it returns at once without consulting ctx. It
// returns false when ctx ended the wait early.
func sleepCtx(ctx context.Context, d time.Duration) bool {
	if d > 0 {
		t := time.NewTimer(d)
		defer t.Stop()
		select {
		case <-t.C:
		case <-ctx.Done():
			return false
		}
	}
	return true
}
// bytesReader wraps b for use as a request body. A nil slice yields an
// untyped nil io.Reader, so http.NewRequest sees "no body" rather than a
// non-nil interface holding a nil pointer. An empty non-nil slice still
// gets a reader.
func bytesReader(b []byte) io.Reader {
	if b != nil {
		return bytes.NewReader(b)
	}
	return nil
}
// shortPath trims everything before the first "/v1/" in u, so error
// messages show only the API path and never the configured API base. A
// URL without "/v1/" is returned unchanged.
func shortPath(u string) string {
	const marker = "/v1/"
	_, tail, found := strings.Cut(u, marker)
	if !found {
		return u
	}
	return marker + tail
}
// init registers NewReplicate as the constructor for the ReplicateType
// ("replicate") adapter type, presumably so the backend registry can build
// instances of that type from imagen.yaml — confirm against Register.
func init() {
Register(ReplicateType, NewReplicate)
}