package backend import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "strings" "sync" "sync/atomic" "testing" "time" ) // fakeReplicate is a programmable mock of the Replicate REST API. type fakeReplicate struct { t *testing.T mu sync.Mutex // Responses for the predictions submission. If the request matches the // model-based path, modelEndpointHits increments; for /v1/predictions, // versionEndpointHits increments. createStatus int createBody []byte createCalls atomic.Int32 // Auth-fail policy: if true, every request to the prediction endpoints // returns 401 with a stock body. auth401 bool // 429-then-OK policy: first N create calls return 429. create429Until int32 retryAfter string // Sequence of /predictions/{id} responses, walked in order; once // exhausted, the last entry is returned indefinitely. pollResponses []string pollIdx atomic.Int32 pollCalls atomic.Int32 // Image download policy. imageStatus int imageBody []byte image5xxFirst int32 imageCalls atomic.Int32 server *httptest.Server } func (f *fakeReplicate) handler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"): f.handleCreate(w, r) case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions": f.handleCreate(w, r) case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"): f.handlePoll(w, r) case r.Method == http.MethodGet && r.URL.Path == "/img": f.handleImage(w, r) default: f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path) http.NotFound(w, r) } }) } func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) { n := f.createCalls.Add(1) if f.auth401 { w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{"detail":"Invalid token"}`)) return } if n <= f.create429Until { if f.retryAfter != "" { w.Header().Set("Retry-After", f.retryAfter) } w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(`{"detail":"too many requests"}`)) return } w.WriteHeader(f.createStatus) _, _ = w.Write(f.createBody) } func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) { f.pollCalls.Add(1) idx := int(f.pollIdx.Add(1)) - 1 if idx >= len(f.pollResponses) { idx = len(f.pollResponses) - 1 } if idx < 0 { http.Error(w, "no poll response configured", http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(f.pollResponses[idx])) } func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) { n := f.imageCalls.Add(1) if n <= f.image5xxFirst { w.WriteHeader(http.StatusBadGateway) _, _ = w.Write([]byte("upstream unavailable")) return } w.Header().Set("Content-Type", "image/png") w.WriteHeader(f.imageStatus) _, _ = w.Write(f.imageBody) } func (f *fakeReplicate) start() { f.server = httptest.NewServer(f.handler()) f.t.Cleanup(f.server.Close) } func (f *fakeReplicate) imageURL() string { return f.server.URL + "/img" } func newFakeReplicate(t *testing.T) *fakeReplicate { t.Helper() f := &fakeReplicate{ t: t, createStatus: http.StatusCreated, imageStatus: http.StatusOK, imageBody: mustPNG(t, 8, 8), } f.start() // Default happy-path responses now that the server URL is known. f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`) f.pollResponses = []string{ `{"id":"pred-abc","status":"starting","version":"v1","output":null}`, `{"id":"pred-abc","status":"processing","version":"v1","output":null}`, fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()), } return f } func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate { t.Helper() be, err := NewReplicate("flux-test", map[string]any{ "api_token_env": "TEST_REPLICATE_TOKEN", "model": model, "api_base": f.server.URL, }) if err != nil { t.Fatalf("NewReplicate: %v", err) } r := be.(*Replicate) r.apiToken = "fake-token" r.pollInterval = time.Millisecond r.pollTimeout = 5 * time.Second r.initialBackoff = time.Millisecond r.randSeed = func() int64 { return 42 } return r } func TestReplicateConstructorRejectsBadInputs(t *testing.T) { if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil { t.Errorf("expected error for empty instance name") } if _, err := NewReplicate("x", map[string]any{}); err == nil { t.Errorf("expected error for missing model") } if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil { t.Errorf("expected error for malformed model") } } func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) { f := newFakeReplicate(t) r := newReplicate(t, f, "black-forest-labs/flux-schnell") r.apiToken = "" // simulate the env var being unset _, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err == nil { t.Fatal("expected error when token is missing") } if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") { t.Errorf("error should name the env var: %v", err) } } func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) { f := newFakeReplicate(t) sink := &recordingSink{} r := newReplicate(t, f, "black-forest-labs/flux-schnell") r.Sink = sink res, err := r.Generate(context.Background(), Request{ Prompt: "a tiny dragon", Width: 1024, Height: 1024, Seed: 0, }) if err != nil { t.Fatalf("Generate: %v", err) } defer res.ImageReader.Close() body, _ := io.ReadAll(res.ImageReader) if !bytes.Equal(body, f.imageBody) { t.Errorf("image body did not round-trip") } if mime := res.MimeType; mime != "image/png" { t.Errorf("mime = %q", mime) } if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" { t.Errorf("metadata model = %v", got) } if got, _ := res.Metadata["model_version"].(string); got != "v1" { t.Errorf("metadata model_version = %v", got) } if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 { t.Errorf("metadata predict_time_seconds = %v", got) } if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 { t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok) } if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" { t.Errorf("aspect_ratio = %q", got) } if got := f.pollCalls.Load(); got < 3 { t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got) } if len(sink.rows) != 1 { t.Fatalf("expected 1 sink row, got %d", len(sink.rows)) } row := sink.rows[0] if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" { t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model) } if row.PromptHash == "" || row.PromptHash == "a tiny dragon" { t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash) } if len(row.PromptHash) != 64 { t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash)) } if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 { t.Errorf("sink row cost = %v", row.CostUSDEstimate) } if row.LatencyMs <= 0 { t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs) } } func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) { f := newFakeReplicate(t) var sentBody []byte var sentPath string mu := sync.Mutex{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions": mu.Lock() sentBody, _ = io.ReadAll(r.Body) sentPath = r.URL.Path mu.Unlock() w.WriteHeader(http.StatusCreated) _, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`)) case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"): _, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL()) default: f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path) } })) t.Cleanup(srv.Close) be, err := NewReplicate("flux-test", map[string]any{ "model": "black-forest-labs/flux-dev:abc123", "api_base": srv.URL, }) if err != nil { t.Fatalf("NewReplicate: %v", err) } r := be.(*Replicate) r.apiToken = "fake" r.pollInterval = time.Millisecond res, err := r.Generate(context.Background(), Request{Prompt: "x"}) if err != nil { t.Fatalf("Generate: %v", err) } res.ImageReader.Close() mu.Lock() defer mu.Unlock() if sentPath != "/v1/predictions" { t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath) } var body map[string]any if err := json.Unmarshal(sentBody, &body); err != nil { t.Fatalf("unmarshal sent body: %v", err) } if body["version"] != "abc123" { t.Errorf("expected version=abc123 in body, got %v", body["version"]) } } func TestReplicate401SurfacesEnvHint(t *testing.T) { f := newFakeReplicate(t) f.auth401 = true r := newReplicate(t, f, "black-forest-labs/flux-schnell") _, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err == nil { t.Fatal("expected 401 to surface as error") } msg := err.Error() if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") { t.Errorf("error should name env var: %v", err) } if !strings.Contains(msg, "401") { t.Errorf("error should mention 401: %v", err) } } func TestReplicate429RetriesThenSucceeds(t *testing.T) { f := newFakeReplicate(t) f.create429Until = 2 // first two calls 429 then succeed f.retryAfter = "" // force the adapter's exp-backoff path r := newReplicate(t, f, "black-forest-labs/flux-schnell") r.pollInterval = time.Millisecond // Squash the backoff so the test is fast. r.httpClient = &http.Client{Timeout: 5 * time.Second} ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() res, err := r.Generate(ctx, Request{Prompt: "p"}) if err != nil { t.Fatalf("Generate after 429s: %v", err) } res.ImageReader.Close() if got := f.createCalls.Load(); got != 3 { t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got) } } func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) { f := newFakeReplicate(t) f.create429Until = 99 // every call 429 r := newReplicate(t, f, "black-forest-labs/flux-schnell") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() _, err := r.Generate(ctx, Request{Prompt: "p"}) if err == nil { t.Fatal("expected error after sustained 429s") } if !strings.Contains(err.Error(), "429") { t.Errorf("expected 429 in error: %v", err) } // max429Retries=3 → 1 initial + 3 retries = 4 total if got := f.createCalls.Load(); got != 4 { t.Errorf("expected 4 create calls (1+3 retries), got %d", got) } } func TestReplicateFailedPredictionSurfacesError(t *testing.T) { f := newFakeReplicate(t) f.pollResponses = []string{ `{"id":"pred-abc","status":"starting","version":"v1","output":null}`, `{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`, } r := newReplicate(t, f, "black-forest-labs/flux-schnell") _, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err == nil { t.Fatal("expected error for failed prediction") } if !strings.Contains(err.Error(), "failed") { t.Errorf("error should mention failure: %v", err) } if !strings.Contains(err.Error(), "NSFW") { t.Errorf("error should include the API error message: %v", err) } } func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) { f := newFakeReplicate(t) // Always return processing → adapter times out. f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`} r := newReplicate(t, f, "black-forest-labs/flux-schnell") r.pollInterval = 5 * time.Millisecond r.pollTimeout = 30 * time.Millisecond _, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err == nil { t.Fatal("expected timeout error") } if !strings.Contains(err.Error(), "did not complete") { t.Errorf("expected 'did not complete' in error, got %v", err) } if !strings.Contains(err.Error(), "waited") { t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err) } } func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) { f := newFakeReplicate(t) f.image5xxFirst = 1 // first download 502, second OK r := newReplicate(t, f, "black-forest-labs/flux-schnell") res, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err != nil { t.Fatalf("Generate (download retry): %v", err) } res.ImageReader.Close() if got := f.imageCalls.Load(); got != 2 { t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got) } } func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) { f := newFakeReplicate(t) f.image5xxFirst = 99 // every download fails r := newReplicate(t, f, "black-forest-labs/flux-schnell") _, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err == nil { t.Fatal("expected error after sustained image-download 5xx") } if got := f.imageCalls.Load(); got != 2 { t.Errorf("expected 2 image fetches (no further retries), got %d", got) } } func TestReplicateContextCancelStopsPolling(t *testing.T) { f := newFakeReplicate(t) f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`} r := newReplicate(t, f, "black-forest-labs/flux-schnell") r.pollInterval = 5 * time.Millisecond r.pollTimeout = 5 * time.Second ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) defer cancel() _, err := r.Generate(ctx, Request{Prompt: "p"}) if err == nil { t.Fatal("expected ctx error") } if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") { t.Errorf("expected deadline exceeded, got %v", err) } } func TestReplicateBackendOptsMergedIntoInput(t *testing.T) { var captured map[string]any srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"): body, _ := io.ReadAll(r.Body) var top struct { Input map[string]any `json:"input"` } _ = json.Unmarshal(body, &top) captured = top.Input w.WriteHeader(http.StatusCreated) _, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`)) } })) t.Cleanup(srv.Close) be, err := NewReplicate("flux-test", map[string]any{ "model": "black-forest-labs/flux-schnell", "api_base": srv.URL, }) if err != nil { t.Fatalf("NewReplicate: %v", err) } r := be.(*Replicate) r.apiToken = "fake" r.pollInterval = time.Millisecond // We expect this to fail at "pickFirstOutputURL" (empty output) but the // captured input should still have been recorded. _, _ = r.Generate(context.Background(), Request{ Prompt: "p", Width: 1024, Height: 1024, BackendOpts: map[string]any{ "output_quality": 90, "go_fast": true, }, }) if captured == nil { t.Fatal("create endpoint not hit") } if captured["output_quality"] != float64(90) { t.Errorf("output_quality not threaded: %v", captured["output_quality"]) } if captured["go_fast"] != true { t.Errorf("go_fast not threaded: %v", captured["go_fast"]) } if captured["prompt"] != "p" { t.Errorf("prompt not threaded: %v", captured["prompt"]) } } func TestReplicateOutputArrayAccepted(t *testing.T) { f := newFakeReplicate(t) f.pollResponses = []string{ fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()), } r := newReplicate(t, f, "black-forest-labs/flux-schnell") res, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err != nil { t.Fatalf("Generate (output as array): %v", err) } res.ImageReader.Close() } func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) { f := newFakeReplicate(t) r := newReplicate(t, f, "stability-ai/sdxl") sink := &recordingSink{} r.Sink = sink res, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err != nil { t.Fatalf("Generate (unknown model): %v", err) } res.ImageReader.Close() if _, present := res.Metadata["cost_usd_estimate"]; present { t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"]) } if len(sink.rows) != 1 { t.Fatalf("expected 1 sink row, got %d", len(sink.rows)) } if sink.rows[0].CostUSDEstimate != nil { t.Errorf("expected nil cost in sink row for unknown model") } } func TestReplicateSinkFailureIsWarningNotError(t *testing.T) { f := newFakeReplicate(t) r := newReplicate(t, f, "black-forest-labs/flux-schnell") r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") }) res, err := r.Generate(context.Background(), Request{Prompt: "p"}) if err != nil { t.Fatalf("sink failure should not fail Generate: %v", err) } res.ImageReader.Close() } func TestReplicateDefaultStepsApplied(t *testing.T) { var captured map[string]any srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"): body, _ := io.ReadAll(r.Body) var top struct { Input map[string]any `json:"input"` } _ = json.Unmarshal(body, &top) captured = top.Input w.WriteHeader(http.StatusCreated) _, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`)) } })) t.Cleanup(srv.Close) be, err := NewReplicate("flux-test", map[string]any{ "model": "black-forest-labs/flux-dev", "api_base": srv.URL, "default_steps": 28, }) if err != nil { t.Fatalf("NewReplicate: %v", err) } r := be.(*Replicate) r.apiToken = "fake" r.pollInterval = time.Millisecond _, _ = r.Generate(context.Background(), Request{Prompt: "p"}) if captured == nil { t.Fatal("create endpoint not hit") } if captured["num_inference_steps"] != float64(28) { t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"]) } } func TestComputeAspectRatio(t *testing.T) { cases := []struct { w, h int fallback string want string }{ {1024, 1024, "1:1", "1:1"}, {1920, 1080, "1:1", "16:9"}, {2560, 1440, "1:1", "16:9"}, {1024, 768, "1:1", "4:3"}, {1024, 1280, "1:1", "4:5"}, {1000, 1234, "1:1", "1:1"}, // weird ratio falls back {0, 1024, "1:1", "1:1"}, } for _, c := range cases { got := computeAspectRatio(c.w, c.h, c.fallback) if got != c.want { t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want) } } } func TestParseModelRef(t *testing.T) { owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell") if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" { t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err) } owner, name, ver, err = parseModelRef("owner/name:hash123") if err != nil || owner != "owner" || name != "name" || ver != "hash123" { t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err) } if _, _, _, err := parseModelRef("noslash"); err == nil { t.Errorf("expected error for malformed ref") } } func TestHashPromptStable(t *testing.T) { a := hashPrompt("hello") b := hashPrompt("hello") c := hashPrompt("hello!") if a != b { t.Errorf("hashPrompt should be deterministic") } if a == c { t.Errorf("different prompts should hash differently") } if len(a) != 64 { t.Errorf("sha256 hex should be 64 chars, got %d", len(a)) } } func TestReplicatePricingKnownModels(t *testing.T) { if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 { t.Errorf("schnell rate = %v (ok=%v)", v, ok) } if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 { t.Errorf("dev rate = %v (ok=%v)", v, ok) } if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 { t.Errorf("versioned ref should resolve to base price: %v %v", v, ok) } if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok { t.Errorf("unknown model should report ok=false") } } func TestReplicateTypeIsRegistered(t *testing.T) { if !Default.Has(ReplicateType) { t.Errorf("replicate type not registered in Default") } } // recordingSink captures rows for assertion. type recordingSink struct { mu sync.Mutex rows []UsageRow } func (s *recordingSink) Record(_ context.Context, row UsageRow) error { s.mu.Lock() defer s.mu.Unlock() s.rows = append(s.rows, row) return nil } type sinkFunc func(context.Context, UsageRow) error func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) }