Implements the Replicate API backend (FLUX schnell / FLUX dev) per ImaGen issue #3: - internal/backend/replicate.go — Backend adapter. Supports model refs as "owner/name" (uses /v1/models/{owner}/{name}/predictions) and "owner/name:hash" (uses /v1/predictions with explicit version). Polls /v1/predictions/{id} every 500ms with model-aware timeout (60s schnell, 120s dev). Resilience: 401 names api_token_env, 429 with exp backoff up to 3 retries (honours Retry-After), 5xx retries once, image download retries once on transient failure. - internal/backend/replicate_pricing.go — hardcoded per-image USD rates for known FLUX models, snapshotted from replicate.com/pricing with a refresh TODO. - internal/backend/replicate_test.go — mocked-HTTP unit tests covering happy path (model + version-pinned), 401, 429 retry policy, failed prediction, poll timeout, image-download retry, ctx cancel, BackendOpts passthrough, default_steps, aspect-ratio reduction, sha256 prompt hash. - internal/usage/usage.go — Supabase REST sink + read-side query for mai.imagen_usage. Adapter writes are best-effort: failures warn but the image still lands. - cmd/imagen/usage.go — `imagen usage [--since DATE] [--raw]` reads the table and prints a tab-aligned grouped or raw table with totals. - cmd/imagen/backends.go — instances of type=replicate now report "ok" or "not configured (set REPLICATE_API_TOKEN)" depending on env. - internal/config/config.go — sample adds flux-schnell-replicate + flux-dev-replicate; default_backend stays flux-schnell-local. - Supabase migration mai.imagen_usage (id, created_at, backend, model, seed, prompt_hash, latency_ms, cost_usd_estimate, caller) + indexes on (created_at DESC) and (caller). The raw prompt is never stored. Caller identity resolves from MAI_FROM_ID, then the tmux pane's @mai-name option, mirroring the maimcp identity logic. Prompt hash is sha256 of the user-facing prompt; raw prompt never reaches the table.
676 lines
21 KiB
Go
676 lines
21 KiB
Go
package backend
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// fakeReplicate is a programmable mock of the Replicate REST API.
//
// Tests configure the exported-to-the-package knobs below before issuing
// requests; the atomic counters record how often each endpoint family was hit.
type fakeReplicate struct {
	t *testing.T

	// mu is not used by any method visible in this file.
	// NOTE(review): confirm whether another test file relies on it before
	// removing.
	mu sync.Mutex

	// Response for prediction submission. Both the model-based path
	// (/v1/models/{owner}/{name}/predictions) and the version-pinned path
	// (/v1/predictions) are served by the same handler, so createCalls counts
	// submissions to either endpoint.
	createStatus int
	createBody   []byte
	createCalls  atomic.Int32

	// Auth-fail policy: if true, every create request returns 401 with a
	// stock body.
	auth401 bool

	// 429-then-OK policy: the first create429Until create calls return 429.
	// retryAfter, when non-empty, is sent as the Retry-After header on those
	// 429 responses.
	create429Until int32
	retryAfter     string

	// Sequence of /v1/predictions/{id} responses, walked in order; once
	// exhausted, the last entry is returned indefinitely.
	pollResponses []string
	pollIdx       atomic.Int32
	pollCalls     atomic.Int32

	// Image download policy: the first image5xxFirst downloads answer 502,
	// later ones answer imageStatus with imageBody.
	imageStatus   int
	imageBody     []byte
	image5xxFirst int32
	imageCalls    atomic.Int32

	server *httptest.Server
}
|
|
|
|
func (f *fakeReplicate) handler() http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"):
|
|
f.handleCreate(w, r)
|
|
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
|
|
f.handleCreate(w, r)
|
|
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
|
|
f.handlePoll(w, r)
|
|
case r.Method == http.MethodGet && r.URL.Path == "/img":
|
|
f.handleImage(w, r)
|
|
default:
|
|
f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path)
|
|
http.NotFound(w, r)
|
|
}
|
|
})
|
|
}
|
|
|
|
func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) {
|
|
n := f.createCalls.Add(1)
|
|
if f.auth401 {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = w.Write([]byte(`{"detail":"Invalid token"}`))
|
|
return
|
|
}
|
|
if n <= f.create429Until {
|
|
if f.retryAfter != "" {
|
|
w.Header().Set("Retry-After", f.retryAfter)
|
|
}
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
_, _ = w.Write([]byte(`{"detail":"too many requests"}`))
|
|
return
|
|
}
|
|
w.WriteHeader(f.createStatus)
|
|
_, _ = w.Write(f.createBody)
|
|
}
|
|
|
|
func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) {
|
|
f.pollCalls.Add(1)
|
|
idx := int(f.pollIdx.Add(1)) - 1
|
|
if idx >= len(f.pollResponses) {
|
|
idx = len(f.pollResponses) - 1
|
|
}
|
|
if idx < 0 {
|
|
http.Error(w, "no poll response configured", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(f.pollResponses[idx]))
|
|
}
|
|
|
|
func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) {
|
|
n := f.imageCalls.Add(1)
|
|
if n <= f.image5xxFirst {
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
_, _ = w.Write([]byte("upstream unavailable"))
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "image/png")
|
|
w.WriteHeader(f.imageStatus)
|
|
_, _ = w.Write(f.imageBody)
|
|
}
|
|
|
|
func (f *fakeReplicate) start() {
|
|
f.server = httptest.NewServer(f.handler())
|
|
f.t.Cleanup(f.server.Close)
|
|
}
|
|
|
|
// imageURL returns the absolute URL of the fake server's image endpoint.
func (f *fakeReplicate) imageURL() string {
	base := f.server.URL
	return base + "/img"
}
|
|
|
|
func newFakeReplicate(t *testing.T) *fakeReplicate {
|
|
t.Helper()
|
|
f := &fakeReplicate{
|
|
t: t,
|
|
createStatus: http.StatusCreated,
|
|
imageStatus: http.StatusOK,
|
|
imageBody: mustPNG(t, 8, 8),
|
|
}
|
|
f.start()
|
|
// Default happy-path responses now that the server URL is known.
|
|
f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`)
|
|
f.pollResponses = []string{
|
|
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
|
|
`{"id":"pred-abc","status":"processing","version":"v1","output":null}`,
|
|
fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()),
|
|
}
|
|
return f
|
|
}
|
|
|
|
func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate {
|
|
t.Helper()
|
|
be, err := NewReplicate("flux-test", map[string]any{
|
|
"api_token_env": "TEST_REPLICATE_TOKEN",
|
|
"model": model,
|
|
"api_base": f.server.URL,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewReplicate: %v", err)
|
|
}
|
|
r := be.(*Replicate)
|
|
r.apiToken = "fake-token"
|
|
r.pollInterval = time.Millisecond
|
|
r.pollTimeout = 5 * time.Second
|
|
r.initialBackoff = time.Millisecond
|
|
r.randSeed = func() int64 { return 42 }
|
|
return r
|
|
}
|
|
|
|
func TestReplicateConstructorRejectsBadInputs(t *testing.T) {
|
|
if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil {
|
|
t.Errorf("expected error for empty instance name")
|
|
}
|
|
if _, err := NewReplicate("x", map[string]any{}); err == nil {
|
|
t.Errorf("expected error for missing model")
|
|
}
|
|
if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil {
|
|
t.Errorf("expected error for malformed model")
|
|
}
|
|
}
|
|
|
|
func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
r.apiToken = "" // simulate the env var being unset
|
|
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err == nil {
|
|
t.Fatal("expected error when token is missing")
|
|
}
|
|
if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") {
|
|
t.Errorf("error should name the env var: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
sink := &recordingSink{}
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
r.Sink = sink
|
|
|
|
res, err := r.Generate(context.Background(), Request{
|
|
Prompt: "a tiny dragon",
|
|
Width: 1024, Height: 1024,
|
|
Seed: 0,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
defer res.ImageReader.Close()
|
|
body, _ := io.ReadAll(res.ImageReader)
|
|
if !bytes.Equal(body, f.imageBody) {
|
|
t.Errorf("image body did not round-trip")
|
|
}
|
|
|
|
if mime := res.MimeType; mime != "image/png" {
|
|
t.Errorf("mime = %q", mime)
|
|
}
|
|
if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" {
|
|
t.Errorf("metadata model = %v", got)
|
|
}
|
|
if got, _ := res.Metadata["model_version"].(string); got != "v1" {
|
|
t.Errorf("metadata model_version = %v", got)
|
|
}
|
|
if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 {
|
|
t.Errorf("metadata predict_time_seconds = %v", got)
|
|
}
|
|
if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 {
|
|
t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok)
|
|
}
|
|
if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" {
|
|
t.Errorf("aspect_ratio = %q", got)
|
|
}
|
|
|
|
if got := f.pollCalls.Load(); got < 3 {
|
|
t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got)
|
|
}
|
|
|
|
if len(sink.rows) != 1 {
|
|
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
|
|
}
|
|
row := sink.rows[0]
|
|
if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" {
|
|
t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model)
|
|
}
|
|
if row.PromptHash == "" || row.PromptHash == "a tiny dragon" {
|
|
t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash)
|
|
}
|
|
if len(row.PromptHash) != 64 {
|
|
t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash))
|
|
}
|
|
if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 {
|
|
t.Errorf("sink row cost = %v", row.CostUSDEstimate)
|
|
}
|
|
if row.LatencyMs <= 0 {
|
|
t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs)
|
|
}
|
|
}
|
|
|
|
func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
|
|
var sentBody []byte
|
|
var sentPath string
|
|
mu := sync.Mutex{}
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
|
|
mu.Lock()
|
|
sentBody, _ = io.ReadAll(r.Body)
|
|
sentPath = r.URL.Path
|
|
mu.Unlock()
|
|
w.WriteHeader(http.StatusCreated)
|
|
_, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`))
|
|
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
|
|
_, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL())
|
|
default:
|
|
f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
}))
|
|
t.Cleanup(srv.Close)
|
|
|
|
be, err := NewReplicate("flux-test", map[string]any{
|
|
"model": "black-forest-labs/flux-dev:abc123",
|
|
"api_base": srv.URL,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewReplicate: %v", err)
|
|
}
|
|
r := be.(*Replicate)
|
|
r.apiToken = "fake"
|
|
r.pollInterval = time.Millisecond
|
|
|
|
res, err := r.Generate(context.Background(), Request{Prompt: "x"})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
res.ImageReader.Close()
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if sentPath != "/v1/predictions" {
|
|
t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath)
|
|
}
|
|
var body map[string]any
|
|
if err := json.Unmarshal(sentBody, &body); err != nil {
|
|
t.Fatalf("unmarshal sent body: %v", err)
|
|
}
|
|
if body["version"] != "abc123" {
|
|
t.Errorf("expected version=abc123 in body, got %v", body["version"])
|
|
}
|
|
}
|
|
|
|
func TestReplicate401SurfacesEnvHint(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
f.auth401 = true
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
|
|
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err == nil {
|
|
t.Fatal("expected 401 to surface as error")
|
|
}
|
|
msg := err.Error()
|
|
if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") {
|
|
t.Errorf("error should name env var: %v", err)
|
|
}
|
|
if !strings.Contains(msg, "401") {
|
|
t.Errorf("error should mention 401: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestReplicate429RetriesThenSucceeds(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
f.create429Until = 2 // first two calls 429 then succeed
|
|
f.retryAfter = "" // force the adapter's exp-backoff path
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
r.pollInterval = time.Millisecond
|
|
|
|
// Squash the backoff so the test is fast.
|
|
r.httpClient = &http.Client{Timeout: 5 * time.Second}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
res, err := r.Generate(ctx, Request{Prompt: "p"})
|
|
if err != nil {
|
|
t.Fatalf("Generate after 429s: %v", err)
|
|
}
|
|
res.ImageReader.Close()
|
|
if got := f.createCalls.Load(); got != 3 {
|
|
t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
f.create429Until = 99 // every call 429
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
_, err := r.Generate(ctx, Request{Prompt: "p"})
|
|
if err == nil {
|
|
t.Fatal("expected error after sustained 429s")
|
|
}
|
|
if !strings.Contains(err.Error(), "429") {
|
|
t.Errorf("expected 429 in error: %v", err)
|
|
}
|
|
// max429Retries=3 → 1 initial + 3 retries = 4 total
|
|
if got := f.createCalls.Load(); got != 4 {
|
|
t.Errorf("expected 4 create calls (1+3 retries), got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestReplicateFailedPredictionSurfacesError(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
f.pollResponses = []string{
|
|
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
|
|
`{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`,
|
|
}
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err == nil {
|
|
t.Fatal("expected error for failed prediction")
|
|
}
|
|
if !strings.Contains(err.Error(), "failed") {
|
|
t.Errorf("error should mention failure: %v", err)
|
|
}
|
|
if !strings.Contains(err.Error(), "NSFW") {
|
|
t.Errorf("error should include the API error message: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
// Always return processing → adapter times out.
|
|
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
r.pollInterval = 5 * time.Millisecond
|
|
r.pollTimeout = 30 * time.Millisecond
|
|
|
|
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err == nil {
|
|
t.Fatal("expected timeout error")
|
|
}
|
|
if !strings.Contains(err.Error(), "did not complete") {
|
|
t.Errorf("expected 'did not complete' in error, got %v", err)
|
|
}
|
|
if !strings.Contains(err.Error(), "waited") {
|
|
t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
f.image5xxFirst = 1 // first download 502, second OK
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
|
|
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err != nil {
|
|
t.Fatalf("Generate (download retry): %v", err)
|
|
}
|
|
res.ImageReader.Close()
|
|
if got := f.imageCalls.Load(); got != 2 {
|
|
t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
f.image5xxFirst = 99 // every download fails
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
|
|
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err == nil {
|
|
t.Fatal("expected error after sustained image-download 5xx")
|
|
}
|
|
if got := f.imageCalls.Load(); got != 2 {
|
|
t.Errorf("expected 2 image fetches (no further retries), got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestReplicateContextCancelStopsPolling(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
r.pollInterval = 5 * time.Millisecond
|
|
r.pollTimeout = 5 * time.Second
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
|
defer cancel()
|
|
_, err := r.Generate(ctx, Request{Prompt: "p"})
|
|
if err == nil {
|
|
t.Fatal("expected ctx error")
|
|
}
|
|
if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") {
|
|
t.Errorf("expected deadline exceeded, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestReplicateBackendOptsMergedIntoInput(t *testing.T) {
|
|
var captured map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
|
|
body, _ := io.ReadAll(r.Body)
|
|
var top struct {
|
|
Input map[string]any `json:"input"`
|
|
}
|
|
_ = json.Unmarshal(body, &top)
|
|
captured = top.Input
|
|
w.WriteHeader(http.StatusCreated)
|
|
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
|
|
}
|
|
}))
|
|
t.Cleanup(srv.Close)
|
|
|
|
be, err := NewReplicate("flux-test", map[string]any{
|
|
"model": "black-forest-labs/flux-schnell",
|
|
"api_base": srv.URL,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewReplicate: %v", err)
|
|
}
|
|
r := be.(*Replicate)
|
|
r.apiToken = "fake"
|
|
r.pollInterval = time.Millisecond
|
|
|
|
// We expect this to fail at "pickFirstOutputURL" (empty output) but the
|
|
// captured input should still have been recorded.
|
|
_, _ = r.Generate(context.Background(), Request{
|
|
Prompt: "p",
|
|
Width: 1024, Height: 1024,
|
|
BackendOpts: map[string]any{
|
|
"output_quality": 90,
|
|
"go_fast": true,
|
|
},
|
|
})
|
|
if captured == nil {
|
|
t.Fatal("create endpoint not hit")
|
|
}
|
|
if captured["output_quality"] != float64(90) {
|
|
t.Errorf("output_quality not threaded: %v", captured["output_quality"])
|
|
}
|
|
if captured["go_fast"] != true {
|
|
t.Errorf("go_fast not threaded: %v", captured["go_fast"])
|
|
}
|
|
if captured["prompt"] != "p" {
|
|
t.Errorf("prompt not threaded: %v", captured["prompt"])
|
|
}
|
|
}
|
|
|
|
func TestReplicateOutputArrayAccepted(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
f.pollResponses = []string{
|
|
fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()),
|
|
}
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err != nil {
|
|
t.Fatalf("Generate (output as array): %v", err)
|
|
}
|
|
res.ImageReader.Close()
|
|
}
|
|
|
|
func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
r := newReplicate(t, f, "stability-ai/sdxl")
|
|
sink := &recordingSink{}
|
|
r.Sink = sink
|
|
|
|
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err != nil {
|
|
t.Fatalf("Generate (unknown model): %v", err)
|
|
}
|
|
res.ImageReader.Close()
|
|
if _, present := res.Metadata["cost_usd_estimate"]; present {
|
|
t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"])
|
|
}
|
|
if len(sink.rows) != 1 {
|
|
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
|
|
}
|
|
if sink.rows[0].CostUSDEstimate != nil {
|
|
t.Errorf("expected nil cost in sink row for unknown model")
|
|
}
|
|
}
|
|
|
|
func TestReplicateSinkFailureIsWarningNotError(t *testing.T) {
|
|
f := newFakeReplicate(t)
|
|
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
|
r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") })
|
|
|
|
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if err != nil {
|
|
t.Fatalf("sink failure should not fail Generate: %v", err)
|
|
}
|
|
res.ImageReader.Close()
|
|
}
|
|
|
|
func TestReplicateDefaultStepsApplied(t *testing.T) {
|
|
var captured map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
|
|
body, _ := io.ReadAll(r.Body)
|
|
var top struct {
|
|
Input map[string]any `json:"input"`
|
|
}
|
|
_ = json.Unmarshal(body, &top)
|
|
captured = top.Input
|
|
w.WriteHeader(http.StatusCreated)
|
|
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
|
|
}
|
|
}))
|
|
t.Cleanup(srv.Close)
|
|
|
|
be, err := NewReplicate("flux-test", map[string]any{
|
|
"model": "black-forest-labs/flux-dev",
|
|
"api_base": srv.URL,
|
|
"default_steps": 28,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewReplicate: %v", err)
|
|
}
|
|
r := be.(*Replicate)
|
|
r.apiToken = "fake"
|
|
r.pollInterval = time.Millisecond
|
|
_, _ = r.Generate(context.Background(), Request{Prompt: "p"})
|
|
if captured == nil {
|
|
t.Fatal("create endpoint not hit")
|
|
}
|
|
if captured["num_inference_steps"] != float64(28) {
|
|
t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"])
|
|
}
|
|
}
|
|
|
|
func TestComputeAspectRatio(t *testing.T) {
|
|
cases := []struct {
|
|
w, h int
|
|
fallback string
|
|
want string
|
|
}{
|
|
{1024, 1024, "1:1", "1:1"},
|
|
{1920, 1080, "1:1", "16:9"},
|
|
{2560, 1440, "1:1", "16:9"},
|
|
{1024, 768, "1:1", "4:3"},
|
|
{1024, 1280, "1:1", "4:5"},
|
|
{1000, 1234, "1:1", "1:1"}, // weird ratio falls back
|
|
{0, 1024, "1:1", "1:1"},
|
|
}
|
|
for _, c := range cases {
|
|
got := computeAspectRatio(c.w, c.h, c.fallback)
|
|
if got != c.want {
|
|
t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestParseModelRef(t *testing.T) {
|
|
owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell")
|
|
if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" {
|
|
t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err)
|
|
}
|
|
owner, name, ver, err = parseModelRef("owner/name:hash123")
|
|
if err != nil || owner != "owner" || name != "name" || ver != "hash123" {
|
|
t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err)
|
|
}
|
|
if _, _, _, err := parseModelRef("noslash"); err == nil {
|
|
t.Errorf("expected error for malformed ref")
|
|
}
|
|
}
|
|
|
|
func TestHashPromptStable(t *testing.T) {
|
|
a := hashPrompt("hello")
|
|
b := hashPrompt("hello")
|
|
c := hashPrompt("hello!")
|
|
if a != b {
|
|
t.Errorf("hashPrompt should be deterministic")
|
|
}
|
|
if a == c {
|
|
t.Errorf("different prompts should hash differently")
|
|
}
|
|
if len(a) != 64 {
|
|
t.Errorf("sha256 hex should be 64 chars, got %d", len(a))
|
|
}
|
|
}
|
|
|
|
func TestReplicatePricingKnownModels(t *testing.T) {
|
|
if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 {
|
|
t.Errorf("schnell rate = %v (ok=%v)", v, ok)
|
|
}
|
|
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 {
|
|
t.Errorf("dev rate = %v (ok=%v)", v, ok)
|
|
}
|
|
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 {
|
|
t.Errorf("versioned ref should resolve to base price: %v %v", v, ok)
|
|
}
|
|
if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok {
|
|
t.Errorf("unknown model should report ok=false")
|
|
}
|
|
}
|
|
|
|
func TestReplicateTypeIsRegistered(t *testing.T) {
|
|
if !Default.Has(ReplicateType) {
|
|
t.Errorf("replicate type not registered in Default")
|
|
}
|
|
}
|
|
|
|
// recordingSink captures rows for assertion.
|
|
type recordingSink struct {
|
|
mu sync.Mutex
|
|
rows []UsageRow
|
|
}
|
|
|
|
func (s *recordingSink) Record(_ context.Context, row UsageRow) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.rows = append(s.rows, row)
|
|
return nil
|
|
}
|
|
|
|
type sinkFunc func(context.Context, UsageRow) error
|
|
|
|
func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) }
|