diff --git a/internal/backend/comfyui.go b/internal/backend/comfyui.go new file mode 100644 index 0000000..442e1e2 --- /dev/null +++ b/internal/backend/comfyui.go @@ -0,0 +1,557 @@ +package backend + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// ComfyType is the type-name adapters register under for ComfyUI instances. +const ComfyType = "comfyui" + +// Comfy is the ComfyUI adapter. It speaks the public `/prompt` + `/history` +// + `/view` HTTP API and submits a fixed FLUX.1 schnell workflow built from +// the values in Request. +// +// Concurrency: a single Comfy is safe to share across goroutines as long as +// the underlying http.Client is. Generate does not hold long-lived state. +type Comfy struct { + instance string + + base string + model string + vae string + clipL string + clipT5 string + dtype string + + defaultSteps int + defaultSampler string + defaultScheduler string + + httpClient *http.Client + pollInterval time.Duration + pollTimeout time.Duration + + // Hooks for tests; production paths use the package-level defaults. + randSeed func() int64 + clientIDFn func() string +} + +// NewComfy is the registry constructor. cfg is the adapter's slice of +// imagen.yaml. Required keys: base_url, model. The rest have sensible FLUX +// schnell defaults. +func NewComfy(name string, cfg map[string]any) (Backend, error) { + if name == "" { + return nil, fmt.Errorf("comfyui: empty instance name") + } + base := strings.TrimRight(getString(cfg, "base_url", ""), "/") + if base == "" { + return nil, fmt.Errorf("comfyui[%s]: base_url is required", name) + } + if _, err := url.Parse(base); err != nil { + return nil, fmt.Errorf("comfyui[%s]: base_url %q invalid: %w", name, base, err) + } + model := getString(cfg, "model", "") + if model == "" { + return nil, fmt.Errorf("comfyui[%s]: model is required", name) + } + + c := &Comfy{ + instance: name, + base: base, + model: model, + + vae: getString(cfg, "vae", "ae.safetensors"), + clipL: getString(cfg, "clip_l", "clip_l.safetensors"), + clipT5: getString(cfg, "clip_t5", "t5xxl_fp8_e4m3fn.safetensors"), + dtype: getString(cfg, "weight_dtype", "fp8_e4m3fn"), + + defaultSteps: getInt(cfg, "default_steps", 4), + defaultSampler: getString(cfg, "default_sampler", "euler"), + defaultScheduler: getString(cfg, "default_scheduler", "simple"), + + httpClient: &http.Client{Timeout: 60 * time.Second}, + pollInterval: 250 * time.Millisecond, + pollTimeout: 120 * time.Second, + + randSeed: cryptoSeed, + clientIDFn: randClientID, + } + return c, nil +} + +// Name returns the instance name from imagen.yaml. +func (c *Comfy) Name() string { return c.instance } + +// Generate submits one workflow to ComfyUI, waits for it to render, and +// returns the resulting PNG. +func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) { + width := orDefaultInt(req.Width, 1024) + height := orDefaultInt(req.Height, 1024) + steps := orDefaultInt(req.Steps, c.defaultSteps) + + sampler := c.defaultSampler + scheduler := c.defaultScheduler + if v, ok := req.BackendOpts["sampler"].(string); ok && v != "" { + sampler = v + } + if v, ok := req.BackendOpts["scheduler"].(string); ok && v != "" { + scheduler = v + } + + seed := req.Seed + if seed == 0 { + seed = c.randSeed() + } + + workflow := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler) + clientID := c.clientIDFn() + + start := time.Now() + promptID, err := c.submitPrompt(ctx, workflow, clientID) + if err != nil { + return nil, err + } + filename, err := c.waitForCompletion(ctx, promptID) + if err != nil { + return nil, err + } + imgBytes, err := c.fetchImage(ctx, filename) + if err != nil { + return nil, err + } + latencyMs := time.Since(start).Milliseconds() + + meta := map[string]any{ + "backend": c.instance, + "backend_type": ComfyType, + "model": c.model, + "seed": seed, + "steps": steps, + "sampler": sampler, + "scheduler": scheduler, + "width": width, + "height": height, + "latency_ms": latencyMs, + "prompt_id": promptID, + "client_id": clientID, + } + if vram := c.vramUsedMiB(ctx); vram > 0 { + meta["vram_used_mib"] = vram + } + + return &Result{ + ImageReader: io.NopCloser(bytes.NewReader(imgBytes)), + MimeType: "image/png", + Metadata: meta, + }, nil +} + +// submitPrompt POSTs the workflow and extracts the prompt_id. +// +// Retries once on a 5xx or transient network error. 4xx responses are not +// retried — they are treated as configuration bugs (missing model, bad +// workflow shape, etc.) and surfaced with a hint pointing at the docs when +// the body matches a known pattern. +func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clientID string) (string, error) { + body, err := json.Marshal(map[string]any{ + "prompt": workflow, + "client_id": clientID, + }) + if err != nil { + return "", fmt.Errorf("comfyui: marshal workflow: %w", err) + } + + var lastErr error + for attempt := range 2 { + if attempt > 0 { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(time.Second): + } + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/prompt", bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = c.connError(err) + continue + } + respBody, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + switch { + case resp.StatusCode >= 200 && resp.StatusCode < 300: + return parsePromptID(respBody, c.model) + case resp.StatusCode >= 500: + lastErr = fmt.Errorf("comfyui /prompt %d: %s", resp.StatusCode, snip(respBody)) + continue + default: + return "", c.classifyBadRequest(resp.StatusCode, respBody) + } + } + return "", lastErr +} + +// waitForCompletion polls /history/{id} until the prompt finishes and +// returns the filename of the produced image. +func (c *Comfy) waitForCompletion(ctx context.Context, promptID string) (string, error) { + deadline := time.Now().Add(c.pollTimeout) + for { + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } + if time.Now().After(deadline) { + return "", fmt.Errorf("comfyui: prompt %s did not complete within %s", promptID, c.pollTimeout) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/history/"+promptID, nil) + if err != nil { + return "", err + } + resp, err := c.httpClient.Do(req) + if err != nil { + return "", c.connError(err) + } + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("comfyui /history/%s %d: %s", promptID, resp.StatusCode, snip(body)) + } + filename, done, err := parseHistory(body, promptID) + if err != nil { + return "", err + } + if done { + return filename, nil + } + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(c.pollInterval): + } + } +} + +// fetchImage downloads the produced image bytes via /view. +func (c *Comfy) fetchImage(ctx context.Context, filename string) ([]byte, error) { + q := url.Values{ + "filename": {filename}, + "type": {"output"}, + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/view?"+q.Encode(), nil) + if err != nil { + return nil, err + } + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, c.connError(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("comfyui /view %d: %s", resp.StatusCode, snip(body)) + } + return io.ReadAll(resp.Body) +} + +// vramUsedMiB returns total - free VRAM on device 0 from /system_stats, or +// 0 if the endpoint isn't available. Best-effort, never an error. +func (c *Comfy) vramUsedMiB(ctx context.Context) int64 { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/system_stats", nil) + if err != nil { + return 0 + } + resp, err := c.httpClient.Do(req) + if err != nil { + return 0 + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return 0 + } + var s struct { + Devices []struct { + VRAMTotal int64 `json:"vram_total"` + VRAMFree int64 `json:"vram_free"` + } `json:"devices"` + } + if err := json.NewDecoder(resp.Body).Decode(&s); err != nil { + return 0 + } + if len(s.Devices) == 0 { + return 0 + } + used := s.Devices[0].VRAMTotal - s.Devices[0].VRAMFree + if used < 0 { + return 0 + } + return used / (1024 * 1024) +} + +// connError translates a Go networking error into a user-actionable message, +// pointing at the boot-whitetower script when mRock looks asleep. +func (c *Comfy) connError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return err + } + msg := err.Error() + var opErr *net.OpError + asOp := errors.As(err, &opErr) + switch { + case asOp, + strings.Contains(msg, "connection refused"), + strings.Contains(msg, "no such host"), + strings.Contains(msg, "no route to host"), + strings.Contains(msg, "network is unreachable"), + strings.Contains(msg, "i/o timeout"): + return fmt.Errorf("comfyui at %s unreachable (%v) — if mRock is asleep, run: boot-whitetower mrock", c.base, err) + } + return fmt.Errorf("comfyui at %s: %w", c.base, err) +} + +// classifyBadRequest interprets a 4xx body. Some ComfyUI builds use 400 for +// workflow-validation failures and put the diagnostics in node_errors; older +// builds use 200 + node_errors. This handles the 4xx flavour. +func (c *Comfy) classifyBadRequest(status int, body []byte) error { + if hint, ok := missingModelHint(body, c.model); ok { + return fmt.Errorf("comfyui /prompt %d: %s — see docs/setup-comfyui-mrock.md", status, hint) + } + return fmt.Errorf("comfyui /prompt %d: %s", status, snip(body)) +} + +// buildWorkflow assembles the canonical FLUX.1 schnell ComfyUI workflow, +// node-IDs matching the upstream "flux-schnell" template so anyone debugging +// in the ComfyUI UI sees a familiar shape. +func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string) map[string]any { + return map[string]any{ + "6": map[string]any{ + "class_type": "CLIPTextEncode", + "inputs": map[string]any{ + "text": prompt, + "clip": []any{"11", 0}, + }, + }, + "8": map[string]any{ + "class_type": "VAEDecode", + "inputs": map[string]any{ + "samples": []any{"31", 0}, + "vae": []any{"10", 0}, + }, + }, + "9": map[string]any{ + "class_type": "SaveImage", + "inputs": map[string]any{ + "filename_prefix": "imagen", + "images": []any{"8", 0}, + }, + }, + "10": map[string]any{ + "class_type": "VAELoader", + "inputs": map[string]any{"vae_name": c.vae}, + }, + "11": map[string]any{ + "class_type": "DualCLIPLoader", + "inputs": map[string]any{ + "clip_name1": c.clipT5, + "clip_name2": c.clipL, + "type": "flux", + }, + }, + "12": map[string]any{ + "class_type": "UNETLoader", + "inputs": map[string]any{ + "unet_name": c.model, + "weight_dtype": c.dtype, + }, + }, + "13": map[string]any{ + "class_type": "CLIPTextEncode", + "inputs": map[string]any{ + "text": negative, + "clip": []any{"11", 0}, + }, + }, + "27": map[string]any{ + "class_type": "EmptySD3LatentImage", + "inputs": map[string]any{ + "width": w, + "height": h, + "batch_size": 1, + }, + }, + "30": map[string]any{ + "class_type": "ModelSamplingFlux", + "inputs": map[string]any{ + "model": []any{"12", 0}, + "max_shift": 1.15, + "base_shift": 0.5, + "width": w, + "height": h, + }, + }, + "31": map[string]any{ + "class_type": "KSampler", + "inputs": map[string]any{ + "model": []any{"30", 0}, + "seed": seed, + "steps": steps, + "cfg": 1.0, + "sampler_name": sampler, + "scheduler": scheduler, + "denoise": 1.0, + "positive": []any{"6", 0}, + "negative": []any{"13", 0}, + "latent_image": []any{"27", 0}, + }, + }, + } +} + +// parsePromptID handles the 2xx /prompt response. ComfyUI sometimes 200s a +// validation failure and stuffs node_errors in the body — this function +// turns that into the same user-facing error as a 4xx with the same body. +func parsePromptID(body []byte, model string) (string, error) { + var resp struct { + PromptID string `json:"prompt_id"` + NodeErrors map[string]any `json:"node_errors"` + Error json.RawMessage `json:"error"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "", fmt.Errorf("comfyui /prompt: parse response: %w (body: %s)", err, snip(body)) + } + if len(resp.NodeErrors) > 0 || len(resp.Error) > 0 { + if hint, ok := missingModelHint(body, model); ok { + return "", fmt.Errorf("comfyui /prompt: %s — see docs/setup-comfyui-mrock.md", hint) + } + return "", fmt.Errorf("comfyui /prompt rejected workflow: %s", snip(body)) + } + if resp.PromptID == "" { + return "", fmt.Errorf("comfyui /prompt: empty prompt_id (body: %s)", snip(body)) + } + return resp.PromptID, nil +} + +// parseHistory inspects a /history/{id} body and returns either the produced +// filename + done=true, or done=false to signal "keep polling". +func parseHistory(body []byte, promptID string) (string, bool, error) { + var entries map[string]struct { + Status struct { + Completed bool `json:"completed"` + StatusStr string `json:"status_str"` + } `json:"status"` + Outputs map[string]struct { + Images []struct { + Filename string `json:"filename"` + Subfolder string `json:"subfolder"` + Type string `json:"type"` + } `json:"images"` + } `json:"outputs"` + } + if err := json.Unmarshal(body, &entries); err != nil { + return "", false, fmt.Errorf("comfyui /history: parse: %w (body: %s)", err, snip(body)) + } + e, ok := entries[promptID] + if !ok { + return "", false, nil + } + if e.Status.StatusStr == "error" { + return "", false, fmt.Errorf("comfyui prompt %s errored: %s", promptID, snip(body)) + } + if !e.Status.Completed { + return "", false, nil + } + for _, out := range e.Outputs { + if len(out.Images) > 0 { + return out.Images[0].Filename, true, nil + } + } + return "", true, fmt.Errorf("comfyui prompt %s completed but produced no images", promptID) +} + +// missingModelHint returns a user-actionable message when the response body +// indicates the configured unet model isn't loaded on the server. ComfyUI +// uses both the human-readable "Value not in list" message and the enum +// "value_not_in_list" type — match either. +func missingModelHint(body []byte, model string) (string, bool) { + s := string(body) + hasMarker := strings.Contains(s, "Value not in list") || strings.Contains(s, "value_not_in_list") + if hasMarker && strings.Contains(s, "unet_name") { + return fmt.Sprintf("model %q not present in the ComfyUI server's models/unet/", model), true + } + return "", false +} + +func cryptoSeed() int64 { + var b [8]byte + if _, err := rand.Read(b[:]); err != nil { + return time.Now().UnixNano() + } + return int64(binary.BigEndian.Uint64(b[:]) >> 1) +} + +func randClientID() string { + var b [8]byte + _, _ = rand.Read(b[:]) + return fmt.Sprintf("imagen-%x", b) +} + +func getString(m map[string]any, k, def string) string { + if v, ok := m[k].(string); ok && v != "" { + return v + } + return def +} + +func getInt(m map[string]any, k string, def int) int { + if v, ok := m[k]; ok { + switch n := v.(type) { + case int: + return n + case int64: + return int(n) + case float64: + return int(n) + } + } + return def +} + +func orDefaultInt(v, def int) int { + if v == 0 { + return def + } + return v +} + +func snip(b []byte) string { + const max = 500 + s := strings.TrimSpace(string(b)) + if len(s) > max { + s = s[:max] + "..." + } + return s +} + +func init() { + Register(ComfyType, NewComfy) +} diff --git a/internal/backend/comfyui_test.go b/internal/backend/comfyui_test.go new file mode 100644 index 0000000..778961d --- /dev/null +++ b/internal/backend/comfyui_test.go @@ -0,0 +1,494 @@ +package backend + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "image" + "image/color" + "image/png" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +// fakeComfy is a programmable mock of the ComfyUI HTTP API. Tests configure +// its behaviour by adjusting the public fields before issuing the request. +type fakeComfy struct { + t *testing.T + + // /prompt + promptStatus int + promptBody []byte + promptCalls atomic.Int32 + failPromptUntil int32 // first N /prompt calls return promptFailStatus + promptFailStatus int + promptFailBody []byte + + // /history — start by returning {} (no entry), flip to completed once + // historyReadyAfter polls have happened. + historyReadyAfter int32 + historyCalls atomic.Int32 + historyError bool + + // /view + viewStatus int + viewBody []byte + viewType string + + // /system_stats + statsTotal int64 + statsFree int64 + + server *httptest.Server +} + +func (f *fakeComfy) handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/prompt" && r.Method == http.MethodPost: + n := f.promptCalls.Add(1) + if n <= int32(f.failPromptUntil) { + w.WriteHeader(f.promptFailStatus) + _, _ = w.Write(f.promptFailBody) + return + } + w.WriteHeader(f.promptStatus) + _, _ = w.Write(f.promptBody) + case strings.HasPrefix(r.URL.Path, "/history/") && r.Method == http.MethodGet: + n := f.historyCalls.Add(1) + id := strings.TrimPrefix(r.URL.Path, "/history/") + w.WriteHeader(http.StatusOK) + if f.historyError { + _, _ = fmt.Fprintf(w, `{"%s":{"status":{"completed":false,"status_str":"error"},"outputs":{}}}`, id) + return + } + if n <= f.historyReadyAfter { + _, _ = w.Write([]byte(`{}`)) + return + } + _, _ = fmt.Fprintf(w, + `{"%s":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`, + id, + ) + case r.URL.Path == "/view" && r.Method == http.MethodGet: + ct := f.viewType + if ct == "" { + ct = "image/png" + } + w.Header().Set("Content-Type", ct) + w.WriteHeader(f.viewStatus) + _, _ = w.Write(f.viewBody) + case r.URL.Path == "/system_stats" && r.Method == http.MethodGet: + w.Header().Set("Content-Type", "application/json") + body := map[string]any{ + "system": map[string]any{}, + "devices": []map[string]any{ + {"vram_total": f.statsTotal, "vram_free": f.statsFree}, + }, + } + _ = json.NewEncoder(w).Encode(body) + default: + f.t.Errorf("fakeComfy: unexpected request %s %s", r.Method, r.URL.Path) + http.NotFound(w, r) + } + }) +} + +func (f *fakeComfy) start() { + f.server = httptest.NewServer(f.handler()) + f.t.Cleanup(f.server.Close) +} + +// newFakeComfy spins up a fakeComfy with happy-path defaults. +func newFakeComfy(t *testing.T) *fakeComfy { + t.Helper() + f := &fakeComfy{ + t: t, + promptStatus: http.StatusOK, + promptBody: []byte(`{"prompt_id":"pid-abc","number":1,"node_errors":{}}`), + viewStatus: http.StatusOK, + viewBody: mustPNG(t, 16, 16), + statsTotal: 16 * 1024 * 1024 * 1024, + statsFree: 8 * 1024 * 1024 * 1024, + } + f.start() + return f +} + +// newComfy returns a Comfy pointed at f, with poll interval squashed for fast +// tests and deterministic seed/client_id. +func newComfy(t *testing.T, f *fakeComfy) *Comfy { + t.Helper() + be, err := NewComfy("flux-test", map[string]any{ + "base_url": f.server.URL, + "model": "flux1-schnell.safetensors", + "default_steps": 4, + }) + if err != nil { + t.Fatalf("NewComfy: %v", err) + } + c := be.(*Comfy) + c.pollInterval = time.Millisecond + c.pollTimeout = 5 * time.Second + c.randSeed = func() int64 { return 42 } + c.clientIDFn = func() string { return "imagen-test" } + return c +} + +func mustPNG(t *testing.T, w, h int) []byte { + t.Helper() + img := image.NewRGBA(image.Rect(0, 0, w, h)) + for y := range h { + for x := range w { + img.Set(x, y, color.RGBA{R: 200, G: 100, B: 50, A: 255}) + } + } + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + t.Fatalf("encode png: %v", err) + } + return buf.Bytes() +} + +func TestComfyConstructorRequiresBaseAndModel(t *testing.T) { + if _, err := NewComfy("x", map[string]any{}); err == nil { + t.Errorf("expected error for missing base_url") + } + if _, err := NewComfy("x", map[string]any{"base_url": "http://h:1"}); err == nil { + t.Errorf("expected error for missing model") + } + if _, err := NewComfy("", map[string]any{"base_url": "http://h:1", "model": "m"}); err == nil { + t.Errorf("expected error for empty instance name") + } +} + +func TestComfyHappyPath(t *testing.T) { + f := newFakeComfy(t) + f.historyReadyAfter = 2 // exercise the polling loop + c := newComfy(t, f) + + res, err := c.Generate(context.Background(), Request{ + Prompt: "a small fishbowl with a cat", + Width: 512, + Height: 512, + Steps: 4, + Seed: 1234567, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + defer res.ImageReader.Close() + + if res.MimeType != "image/png" { + t.Errorf("mime = %q", res.MimeType) + } + body, err := io.ReadAll(res.ImageReader) + if err != nil { + t.Fatalf("read body: %v", err) + } + if !bytes.Equal(body, f.viewBody) { + t.Errorf("image body did not round-trip") + } + + if seed, _ := res.Metadata["seed"].(int64); seed != 1234567 { + t.Errorf("metadata seed = %v", res.Metadata["seed"]) + } + if model, _ := res.Metadata["model"].(string); model != "flux1-schnell.safetensors" { + t.Errorf("metadata model = %v", res.Metadata["model"]) + } + if steps, _ := res.Metadata["steps"].(int); steps != 4 { + t.Errorf("metadata steps = %v", res.Metadata["steps"]) + } + if pid, _ := res.Metadata["prompt_id"].(string); pid != "pid-abc" { + t.Errorf("metadata prompt_id = %v", res.Metadata["prompt_id"]) + } + if _, ok := res.Metadata["latency_ms"]; !ok { + t.Errorf("metadata missing latency_ms") + } + // vram_used_mib is best-effort but should be present given our mock stats + if vram, _ := res.Metadata["vram_used_mib"].(int64); vram != 8192 { + t.Errorf("metadata vram_used_mib = %v, want 8192", res.Metadata["vram_used_mib"]) + } + + if got := f.historyCalls.Load(); got < 3 { + t.Errorf("expected at least 3 /history polls, got %d", got) + } +} + +func TestComfyDefaultsAppliedWhenZero(t *testing.T) { + f := newFakeComfy(t) + c := newComfy(t, f) + + res, err := c.Generate(context.Background(), Request{Prompt: "p"}) // all-zero + if err != nil { + t.Fatalf("Generate: %v", err) + } + defer res.ImageReader.Close() + _, _ = io.ReadAll(res.ImageReader) + + if w, _ := res.Metadata["width"].(int); w != 1024 { + t.Errorf("width default = %v", res.Metadata["width"]) + } + if steps, _ := res.Metadata["steps"].(int); steps != 4 { + t.Errorf("steps default = %v", res.Metadata["steps"]) + } + if seed, _ := res.Metadata["seed"].(int64); seed != 42 { + t.Errorf("seed default (test rand hook) = %v", res.Metadata["seed"]) + } + if s, _ := res.Metadata["sampler"].(string); s != "euler" { + t.Errorf("sampler default = %q", s) + } +} + +func TestComfyPromptRetriesOnce5xx(t *testing.T) { + f := newFakeComfy(t) + f.failPromptUntil = 1 + f.promptFailStatus = http.StatusBadGateway + f.promptFailBody = []byte("upstream busy") + c := newComfy(t, f) + + res, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) + if err != nil { + t.Fatalf("Generate (with one 502 then OK): %v", err) + } + defer res.ImageReader.Close() + _, _ = io.ReadAll(res.ImageReader) + + if got := f.promptCalls.Load(); got != 2 { + t.Errorf("expected exactly 2 /prompt calls (1 fail + 1 retry), got %d", got) + } +} + +func TestComfyPromptGivesUpAfterTwo5xx(t *testing.T) { + f := newFakeComfy(t) + f.failPromptUntil = 99 // every call fails + f.promptFailStatus = http.StatusServiceUnavailable + f.promptFailBody = []byte("nope") + c := newComfy(t, f) + + _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) + if err == nil { + t.Fatal("expected error after sustained 503s") + } + if !strings.Contains(err.Error(), "503") { + t.Errorf("expected error to mention 503, got %v", err) + } + if got := f.promptCalls.Load(); got != 2 { + t.Errorf("expected exactly 2 /prompt calls (no further retries), got %d", got) + } +} + +func TestComfyPromptDoesNotRetryOn4xx(t *testing.T) { + f := newFakeComfy(t) + f.failPromptUntil = 99 + f.promptFailStatus = http.StatusBadRequest + f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation"},"node_errors":{"some":"thing"}}`) + c := newComfy(t, f) + + _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) + if err == nil { + t.Fatal("expected error for 400") + } + if got := f.promptCalls.Load(); got != 1 { + t.Errorf("expected exactly 1 /prompt call (no retry on 4xx), got %d", got) + } +} + +func TestComfyMissingModelHintsAtSetupDoc(t *testing.T) { + f := newFakeComfy(t) + f.failPromptUntil = 99 + f.promptFailStatus = http.StatusBadRequest + f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation","message":"Prompt outputs failed validation"},"node_errors":{"12":{"errors":[{"type":"value_not_in_list","message":"Value not in list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`) + c := newComfy(t, f) + + _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) + if err == nil { + t.Fatal("expected error") + } + msg := err.Error() + if !strings.Contains(msg, "docs/setup-comfyui-mrock.md") { + t.Errorf("error should point at the setup doc, got %v", err) + } + if !strings.Contains(msg, "flux1-schnell.safetensors") { + t.Errorf("error should name the missing model, got %v", err) + } +} + +func TestComfyMissingModelOn200WithNodeErrors(t *testing.T) { + // Older ComfyUI builds 200 a workflow-validation failure. + f := newFakeComfy(t) + f.promptStatus = http.StatusOK + f.promptBody = []byte(`{"prompt_id":"","node_errors":{"12":{"errors":[{"type":"value_not_in_list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`) + c := newComfy(t, f) + + _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) + if err == nil { + t.Fatal("expected error for node_errors on 200") + } + if !strings.Contains(err.Error(), "docs/setup-comfyui-mrock.md") { + t.Errorf("error should point at the setup doc, got %v", err) + } +} + +func TestComfyHistoryErrorSurfaced(t *testing.T) { + f := newFakeComfy(t) + f.historyError = true + c := newComfy(t, f) + + _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) + if err == nil { + t.Fatal("expected error when history reports execution error") + } + if !strings.Contains(err.Error(), "errored") { + t.Errorf("expected 'errored' in message, got %v", err) + } +} + +func TestComfyViewFailureSurfaced(t *testing.T) { + f := newFakeComfy(t) + f.viewStatus = http.StatusNotFound + f.viewBody = []byte("nope") + c := newComfy(t, f) + + _, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) + if err == nil { + t.Fatal("expected error when /view 404s") + } + if !strings.Contains(err.Error(), "404") { + t.Errorf("expected status code in error, got %v", err) + } +} + +func TestComfyUnreachableHostMentionsBootHelper(t *testing.T) { + be, err := NewComfy("flux-test", map[string]any{ + "base_url": "http://127.0.0.1:1", // closed port; connection refused + "model": "flux1-schnell.safetensors", + }) + if err != nil { + t.Fatalf("NewComfy: %v", err) + } + c := be.(*Comfy) + c.httpClient = &http.Client{Timeout: 500 * time.Millisecond} + + _, err = c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64}) + if err == nil { + t.Fatal("expected error for unreachable host") + } + if !strings.Contains(err.Error(), "boot-whitetower mrock") { + t.Errorf("expected boot-helper hint, got %v", err) + } +} + +func TestComfyContextCancelStopsPolling(t *testing.T) { + f := newFakeComfy(t) + f.historyReadyAfter = 1_000_000 // never finishes + c := newComfy(t, f) + c.pollInterval = 5 * time.Millisecond + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + + _, err := c.Generate(ctx, Request{Prompt: "p", Width: 64, Height: 64}) + if err == nil { + t.Fatal("expected ctx.Err()") + } + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Errorf("expected deadline exceeded, got %v", err) + } +} + +func TestComfyWorkflowReflectsRequest(t *testing.T) { + // Capture the workflow body to assert KSampler + EmptyLatentImage values. + var captured []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/prompt": + captured, _ = io.ReadAll(r.Body) + _, _ = w.Write([]byte(`{"prompt_id":"pid","number":1,"node_errors":{}}`)) + case "/history/pid": + _, _ = w.Write([]byte(`{"pid":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`)) + case "/view": + _, _ = w.Write(mustPNG(t, 8, 8)) + case "/system_stats": + _, _ = w.Write([]byte(`{"devices":[{"vram_total":1,"vram_free":1}]}`)) + default: + http.NotFound(w, r) + } + })) + t.Cleanup(srv.Close) + + be, err := NewComfy("flux-test", map[string]any{ + "base_url": srv.URL, + "model": "custom.safetensors", + "default_steps": 7, + "default_sampler": "dpmpp_2m", + "default_scheduler": "karras", + }) + if err != nil { + t.Fatalf("NewComfy: %v", err) + } + c := be.(*Comfy) + c.pollInterval = time.Millisecond + c.randSeed = func() int64 { return 9999 } + + res, err := c.Generate(context.Background(), Request{ + Prompt: "a cat", + NegativePrompt: "blurry", + Width: 768, + Height: 512, + Steps: 11, + Seed: 555, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + res.ImageReader.Close() + + var sent struct { + Prompt map[string]map[string]any `json:"prompt"` + ClientID string `json:"client_id"` + } + if err := json.Unmarshal(captured, &sent); err != nil { + t.Fatalf("unmarshal captured: %v", err) + } + ks := sent.Prompt["31"]["inputs"].(map[string]any) + if ks["seed"].(float64) != 555 { + t.Errorf("KSampler seed = %v, want 555", ks["seed"]) + } + if ks["steps"].(float64) != 11 { + t.Errorf("KSampler steps = %v, want 11", ks["steps"]) + } + if ks["sampler_name"].(string) != "dpmpp_2m" { + t.Errorf("sampler_name = %v", ks["sampler_name"]) + } + if ks["scheduler"].(string) != "karras" { + t.Errorf("scheduler = %v", ks["scheduler"]) + } + latent := sent.Prompt["27"]["inputs"].(map[string]any) + if latent["width"].(float64) != 768 || latent["height"].(float64) != 512 { + t.Errorf("EmptySD3LatentImage size = %vx%v", latent["width"], latent["height"]) + } + unet := sent.Prompt["12"]["inputs"].(map[string]any) + if unet["unet_name"].(string) != "custom.safetensors" { + t.Errorf("unet_name = %v", unet["unet_name"]) + } + neg := sent.Prompt["13"]["inputs"].(map[string]any) + if neg["text"].(string) != "blurry" { + t.Errorf("negative prompt not threaded: %v", neg["text"]) + } + if !strings.HasPrefix(sent.ClientID, "imagen-") && sent.ClientID == "" { + t.Errorf("client_id should be set: %q", sent.ClientID) + } +} + +func TestComfyTypeIsRegistered(t *testing.T) { + if !Default.Has(ComfyType) { + t.Errorf("comfyui type not registered in Default") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 973b087..64ddafa 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -95,7 +95,7 @@ const Sample = `# imagen.yaml — config for the imagen CLI. # implementing the Backend interface, registering its type name, and listing # an instance here. -default_backend: mock +default_backend: flux-schnell-local output: directory: ~/Pictures/imagen @@ -103,14 +103,18 @@ output: write_metadata_json: true backends: - mock: - type: mock - flux-schnell-local: type: comfyui base_url: http://mrock:8188 + # Filename of the unet checkpoint inside the ComfyUI server's + # models/unet/ directory. See docs/setup-comfyui-mrock.md. model: flux1-schnell.safetensors default_steps: 4 + default_sampler: euler + default_scheduler: simple + + mock: + type: mock flux-dev-replicate: type: replicate diff --git a/internal/config/config_test.go b/internal/config/config_test.go index afba5ee..f500ced 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -16,7 +16,7 @@ func TestLoadAndValidate(t *testing.T) { if err != nil { t.Fatalf("Load: %v", err) } - if cfg.DefaultBackend != "mock" { + if cfg.DefaultBackend != "flux-schnell-local" { t.Errorf("default = %q", cfg.DefaultBackend) } mock, ok := cfg.Backends["mock"] @@ -30,9 +30,15 @@ func TestLoadAndValidate(t *testing.T) { if !ok { t.Fatalf("flux backend missing") } + if flux.Type != "comfyui" { + t.Errorf("flux type = %q", flux.Type) + } if flux.Raw["base_url"] != "http://mrock:8188" { t.Errorf("flux base_url = %v", flux.Raw["base_url"]) } + if flux.Raw["model"] != "flux1-schnell.safetensors" { + t.Errorf("flux model = %v", flux.Raw["model"]) + } } func TestValidateRejectsUnknownDefault(t *testing.T) {