package backend import ( "bytes" "context" "encoding/json" "fmt" "image" "image/color" "image/png" "io" "net/http" "net/http/httptest" "strings" "sync/atomic" "testing" "time" ) // fakeComfy is a programmable mock of the ComfyUI HTTP API. Tests configure // its behaviour by adjusting the public fields before issuing the request. type fakeComfy struct { t *testing.T // /prompt promptStatus int promptBody []byte promptCalls atomic.Int32 failPromptUntil int32 // first N /prompt calls return promptFailStatus promptFailStatus int promptFailBody []byte // /history — start by returning {} (no entry), flip to completed once // historyReadyAfter polls have happened. historyReadyAfter int32 historyCalls atomic.Int32 historyError bool // /view viewStatus int viewBody []byte viewType string // /system_stats statsTotal int64 statsFree int64 server *httptest.Server } func (f *fakeComfy) handler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.URL.Path == "/prompt" && r.Method == http.MethodPost: n := f.promptCalls.Add(1) if n <= int32(f.failPromptUntil) { w.WriteHeader(f.promptFailStatus) _, _ = w.Write(f.promptFailBody) return } w.WriteHeader(f.promptStatus) _, _ = w.Write(f.promptBody) case strings.HasPrefix(r.URL.Path, "/history/") && r.Method == http.MethodGet: n := f.historyCalls.Add(1) id := strings.TrimPrefix(r.URL.Path, "/history/") w.WriteHeader(http.StatusOK) if f.historyError { _, _ = fmt.Fprintf(w, `{"%s":{"status":{"completed":false,"status_str":"error"},"outputs":{}}}`, id) return } if n <= f.historyReadyAfter { _, _ = w.Write([]byte(`{}`)) return } _, _ = fmt.Fprintf(w, `{"%s":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`, id, ) case r.URL.Path == "/view" && r.Method == http.MethodGet: ct := f.viewType if ct == "" { ct = "image/png" } w.Header().Set("Content-Type", ct) w.WriteHeader(f.viewStatus) _, _ = w.Write(f.viewBody) case r.URL.Path == "/system_stats" && r.Method == http.MethodGet: w.Header().Set("Content-Type", "application/json") body := map[string]any{ "system": map[string]any{}, "devices": []map[string]any{ {"vram_total": f.statsTotal, "vram_free": f.statsFree}, }, } _ = json.NewEncoder(w).Encode(body) default: f.t.Errorf("fakeComfy: unexpected request %s %s", r.Method, r.URL.Path) http.NotFound(w, r) } }) } func (f *fakeComfy) start() { f.server = httptest.NewServer(f.handler()) f.t.Cleanup(f.server.Close) } // newFakeComfy spins up a fakeComfy with happy-path defaults. func newFakeComfy(t *testing.T) *fakeComfy { t.Helper() f := &fakeComfy{ t: t, promptStatus: http.StatusOK, promptBody: []byte(`{"prompt_id":"pid-abc","number":1,"node_errors":{}}`), viewStatus: http.StatusOK, viewBody: mustPNG(t, 16, 16), statsTotal: 16 * 1024 * 1024 * 1024, statsFree: 8 * 1024 * 1024 * 1024, } f.start() return f } // newComfy returns a Comfy pointed at f, with poll interval squashed for fast // tests and deterministic seed/client_id. func newComfy(t *testing.T, f *fakeComfy) *Comfy { t.Helper() be, err := NewComfy("flux-test", map[string]any{ "base_url": f.server.URL, "model": "flux1-schnell.safetensors", "default_steps": 4, }) if err != nil { t.Fatalf("NewComfy: %v", err) } c := be.(*Comfy) c.pollInterval = time.Millisecond c.pollTimeout = 5 * time.Second c.randSeed = func() int64 { return 42 } c.clientIDFn = func() string { return "imagen-test" } return c } func mustPNG(t *testing.T, w, h int) []byte { t.Helper() img := image.NewRGBA(image.Rect(0, 0, w, h)) for y := range h { for x := range w { img.Set(x, y, color.RGBA{R: 200, G: 100, B: 50, A: 255}) } } var buf bytes.Buffer if err := png.Encode(&buf, img); err != nil { t.Fatalf("encode png: %v", err) } return buf.Bytes() } func TestComfyConstructorRequiresBaseAndModel(t *testing.T) { if _, err := NewComfy("x", map[string]any{}); err == nil { t.Errorf("expected error for missing base_url") } if _, err := NewComfy("x", map[string]any{"base_url": "http://h:1"}); err == nil { t.Errorf("expected error for missing model") } if _, err := NewComfy("", map[string]any{"base_url": "http://h:1", "model": "m"}); err == nil { t.Errorf("expected error for empty instance name") } } func TestComfyHappyPath(t *testing.T) { f := newFakeComfy(t) f.historyReadyAfter = 2 // exercise the polling loop c := newComfy(t, f) res, err := c.Generate(context.Background(), Request{ Prompt: "a small fishbowl with a cat", Width: 512, Height: 512, Steps: 4, Seed: 1234567, }) if err != nil { t.Fatalf("Generate: %v", err) } defer res.ImageReader.Close() if res.MimeType != "image/png" { t.Errorf("mime = %q", res.MimeType) } body, err := io.ReadAll(res.ImageReader) if err != nil { t.Fatalf("read body: %v", err) } if !bytes.Equal(body, f.viewBody) { t.Errorf("image body did not round-trip") } if seed, _ := res.Metadata["seed"].(int64); seed != 1234567 { t.Errorf("metadata seed = %v", res.Metadata["seed"]) } if model, _ := res.Metadata["model"].(string); model != "flux1-schnell.safetensors" { t.Errorf("metadata model = %v", res.Metadata["model"]) } if steps, _ := res.Metadata["steps"].(int); steps != 4 { t.Errorf("metadata steps = %v", res.Metadata["steps"]) } if pid, _ := res.Metadata["prompt_id"].(string); pid != "pid-abc" { t.Errorf("metadata prompt_id = %v", res.Metadata["prompt_id"]) } if _, ok := res.Metadata["latency_ms"]; !ok { t.Errorf("metadata missing latency_ms") } // vram_used_mib is best-effort but should be present given our mock stats if vram, _ := res.Metadata["vram_used_mib"].(int64); vram != 8192 { t.Errorf("metadata vram_used_mib = %v, want 8192", res.Metadata["vram_used_mib"]) } if got := f.historyCalls.Load(); got < 3 { t.Errorf("expected at least 3 /history polls, got %d", got) } } func TestComfyDefaultsAppliedWhenZero(t *testing.T) { f := newFakeComfy(t) c := newComfy(t, f) res, err := c.Generate(context.Background(), Request{Prompt: "p"}) // all-zero if err != nil { t.Fatalf("Generate: %v", err) } defer res.ImageReader.Close() _, _ = io.ReadAll(res.ImageReader) if w, _ := res.Metadata["width"].(int); w != 1024 { t.Errorf("width default = %v", res.Metadata["width"]) } if steps, _ := res.Metadata["steps"].(int); steps != 4 { t.Errorf("steps default = %v", res.Metadata["steps"]) } if seed, _ := res.Metadata["seed"].(int64); seed != 42 { t.Errorf("seed default (test rand hook) = %v", res.Metadata["seed"]) } if s, _ := res.Metadata["sampler"].(string); s != "euler" { t.Errorf("sampler default = %q", s) } } func TestComfyPromptRetriesOnce5xx(t *testing.T) { f := newFakeComfy(t) f.failPromptUntil = 1 f.promptFailStatus = http.StatusBadGateway f.promptFailBody = []byte("upstream busy") c := newComfy(t, f) res, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) if err != nil { t.Fatalf("Generate (with one 502 then OK): %v", err) } defer res.ImageReader.Close() _, _ = io.ReadAll(res.ImageReader) if got := f.promptCalls.Load(); got != 2 { t.Errorf("expected exactly 2 /prompt calls (1 fail + 1 retry), got %d", got) } } func TestComfyPromptGivesUpAfterTwo5xx(t *testing.T) { f := newFakeComfy(t) f.failPromptUntil = 99 // every call fails f.promptFailStatus = http.StatusServiceUnavailable f.promptFailBody = []byte("nope") c := newComfy(t, f) _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) if err == nil { t.Fatal("expected error after sustained 503s") } if !strings.Contains(err.Error(), "503") { t.Errorf("expected error to mention 503, got %v", err) } if got := f.promptCalls.Load(); got != 2 { t.Errorf("expected exactly 2 /prompt calls (no further retries), got %d", got) } } func TestComfyPromptDoesNotRetryOn4xx(t *testing.T) { f := newFakeComfy(t) f.failPromptUntil = 99 f.promptFailStatus = http.StatusBadRequest f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation"},"node_errors":{"some":"thing"}}`) c := newComfy(t, f) _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) if err == nil { t.Fatal("expected error for 400") } if got := f.promptCalls.Load(); got != 1 { t.Errorf("expected exactly 1 /prompt call (no retry on 4xx), got %d", got) } } func TestComfyMissingModelHintsAtSetupDoc(t *testing.T) { f := newFakeComfy(t) f.failPromptUntil = 99 f.promptFailStatus = http.StatusBadRequest f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation","message":"Prompt outputs failed validation"},"node_errors":{"12":{"errors":[{"type":"value_not_in_list","message":"Value not in list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`) c := newComfy(t, f) _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) if err == nil { t.Fatal("expected error") } msg := err.Error() if !strings.Contains(msg, "docs/setup-comfyui-mrock.md") { t.Errorf("error should point at the setup doc, got %v", err) } if !strings.Contains(msg, "flux1-schnell.safetensors") { t.Errorf("error should name the missing model, got %v", err) } } func TestComfyMissingModelOn200WithNodeErrors(t *testing.T) { // Older ComfyUI builds 200 a workflow-validation failure. f := newFakeComfy(t) f.promptStatus = http.StatusOK f.promptBody = []byte(`{"prompt_id":"","node_errors":{"12":{"errors":[{"type":"value_not_in_list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`) c := newComfy(t, f) _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) if err == nil { t.Fatal("expected error for node_errors on 200") } if !strings.Contains(err.Error(), "docs/setup-comfyui-mrock.md") { t.Errorf("error should point at the setup doc, got %v", err) } } func TestComfyHistoryErrorSurfaced(t *testing.T) { f := newFakeComfy(t) f.historyError = true c := newComfy(t, f) _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) if err == nil { t.Fatal("expected error when history reports execution error") } if !strings.Contains(err.Error(), "errored") { t.Errorf("expected 'errored' in message, got %v", err) } } func TestComfyViewFailureSurfaced(t *testing.T) { f := newFakeComfy(t) f.viewStatus = http.StatusNotFound f.viewBody = []byte("nope") c := newComfy(t, f) _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) if err == nil { t.Fatal("expected error when /view 404s") } if !strings.Contains(err.Error(), "404") { t.Errorf("expected status code in error, got %v", err) } } func TestComfyUnreachableHostMentionsBootHelper(t *testing.T) { be, err := NewComfy("flux-test", map[string]any{ "base_url": "http://127.0.0.1:1", // closed port; connection refused "model": "flux1-schnell.safetensors", }) if err != nil { t.Fatalf("NewComfy: %v", err) } c := be.(*Comfy) c.httpClient = &http.Client{Timeout: 500 * time.Millisecond} _, err = c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) if err == nil { t.Fatal("expected error for unreachable host") } if !strings.Contains(err.Error(), "boot-whitetower mrock") { t.Errorf("expected boot-helper hint, got %v", err) } } func TestComfyContextCancelStopsPolling(t *testing.T) { f := newFakeComfy(t) f.historyReadyAfter = 1_000_000 // never finishes c := newComfy(t, f) c.pollInterval = 5 * time.Millisecond ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) defer cancel() _, err := c.Generate(ctx, Request{Prompt: "p", Width: 64, Height: 64}) if err == nil { t.Fatal("expected ctx.Err()") } if !strings.Contains(err.Error(), "context deadline exceeded") { t.Errorf("expected deadline exceeded, got %v", err) } } func TestComfyWorkflowReflectsRequest(t *testing.T) { // Capture the workflow body to assert KSampler + EmptyLatentImage values. var captured []byte srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/prompt": captured, _ = io.ReadAll(r.Body) _, _ = w.Write([]byte(`{"prompt_id":"pid","number":1,"node_errors":{}}`)) case "/history/pid": _, _ = w.Write([]byte(`{"pid":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`)) case "/view": _, _ = w.Write(mustPNG(t, 8, 8)) case "/system_stats": _, _ = w.Write([]byte(`{"devices":[{"vram_total":1,"vram_free":1}]}`)) default: http.NotFound(w, r) } })) t.Cleanup(srv.Close) be, err := NewComfy("flux-test", map[string]any{ "base_url": srv.URL, "model": "custom.safetensors", "default_steps": 7, "default_sampler": "dpmpp_2m", "default_scheduler": "karras", }) if err != nil { t.Fatalf("NewComfy: %v", err) } c := be.(*Comfy) c.pollInterval = time.Millisecond c.randSeed = func() int64 { return 9999 } res, err := c.Generate(context.Background(), Request{ Prompt: "a cat", NegativePrompt: "blurry", Width: 768, Height: 512, Steps: 11, Seed: 555, }) if err != nil { t.Fatalf("Generate: %v", err) } res.ImageReader.Close() var sent struct { Prompt map[string]map[string]any `json:"prompt"` ClientID string `json:"client_id"` } if err := json.Unmarshal(captured, &sent); err != nil { t.Fatalf("unmarshal captured: %v", err) } ks := sent.Prompt["31"]["inputs"].(map[string]any) if ks["seed"].(float64) != 555 { t.Errorf("KSampler seed = %v, want 555", ks["seed"]) } if ks["steps"].(float64) != 11 { t.Errorf("KSampler steps = %v, want 11", ks["steps"]) } if ks["sampler_name"].(string) != "dpmpp_2m" { t.Errorf("sampler_name = %v", ks["sampler_name"]) } if ks["scheduler"].(string) != "karras" { t.Errorf("scheduler = %v", ks["scheduler"]) } latent := sent.Prompt["27"]["inputs"].(map[string]any) if latent["width"].(float64) != 768 || latent["height"].(float64) != 512 { t.Errorf("EmptySD3LatentImage size = %vx%v", latent["width"], latent["height"]) } unet := sent.Prompt["12"]["inputs"].(map[string]any) if unet["unet_name"].(string) != "custom.safetensors" { t.Errorf("unet_name = %v", unet["unet_name"]) } neg := sent.Prompt["13"]["inputs"].(map[string]any) if neg["text"].(string) != "blurry" { t.Errorf("negative prompt not threaded: %v", neg["text"]) } if !strings.HasPrefix(sent.ClientID, "imagen-") && sent.ClientID == "" { t.Errorf("client_id should be set: %q", sent.ClientID) } } func TestComfyTypeIsRegistered(t *testing.T) { if !Default.Has(ComfyType) { t.Errorf("comfyui type not registered in Default") } }