package backend

import (
	"bytes"
	"context"
	"crypto/rand"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"
)

// ComfyType is the type-name adapters register under for ComfyUI instances.
const ComfyType = "comfyui"

// Comfy is the ComfyUI adapter. It speaks the public `/prompt` + `/history`
// + `/view` HTTP API and submits a workflow built by substituting Request
// values into a JSON template (bundled under internal/backend/workflows/ or
// loaded from a filesystem path).
//
// Concurrency: a single Comfy is safe to share across goroutines as long as
// the underlying http.Client is. Generate does not hold long-lived state.
type Comfy struct {
	// instance is the instance name reported by Name.
	instance string
	// base is the server base URL with trailing slashes trimmed; endpoint
	// paths such as "/prompt" are appended to it verbatim.
	base string
	// workflow is the template reference handed to LoadWorkflowTemplate.
	workflow string

	// rawCfg keeps the original yaml block as passed to the constructor so
	// we can expose every user-defined string/number as a workflow
	// substitution without enumerating each per-model knob in Go. Framework
	// keys (type, base_url, workflow, default_*) are filtered out when the
	// substitutions are built, not when the block is stored. Per-call
	// values such as the negative prompt always get a substitution entry,
	// even when empty, so a template can reference ${negative} when the
	// request didn't pass one.
	rawCfg map[string]any

	// Per-instance fallbacks used when a request does not specify the value.
	defaultSteps     int
	defaultSampler   string
	defaultScheduler string
	defaultCFG       float64

	// httpClient performs every request against the ComfyUI server.
	httpClient *http.Client
	// pollInterval is the pause between /history polls; pollTimeout bounds
	// the total time spent waiting for one prompt to finish.
	pollInterval time.Duration
	pollTimeout  time.Duration

	// Hooks for tests; production paths use the package-level defaults.
	randSeed   func() int64
	clientIDFn func() string
}

// NewComfy is the registry constructor. cfg is the adapter's slice of
// imagen.yaml.
//
// Required keys: base_url, model.
// Optional keys: workflow (defaults to "flux1-schnell" for back-compat with
// existing configs), default_steps, default_sampler, default_scheduler,
// default_cfg, plus any template-specific knobs (vae, clip, clip_l,
// clip_t5, dtype, shift, guidance, …) the chosen workflow references.
func NewComfy(name string, cfg map[string]any) (Backend, error) { if name == "" { return nil, fmt.Errorf("comfyui: empty instance name") } if cfg == nil { cfg = map[string]any{} } base := strings.TrimRight(getString(cfg, "base_url", ""), "/") if base == "" { return nil, fmt.Errorf("comfyui[%s]: base_url is required", name) } if _, err := url.Parse(base); err != nil { return nil, fmt.Errorf("comfyui[%s]: base_url %q invalid: %w", name, base, err) } model := getString(cfg, "model", "") if model == "" { return nil, fmt.Errorf("comfyui[%s]: model is required", name) } workflow := getString(cfg, "workflow", "flux1-schnell") // Fail fast on a bad workflow ref so users see the error at startup, // not on first /prompt submission. if _, err := LoadWorkflowTemplate(workflow); err != nil { return nil, fmt.Errorf("comfyui[%s]: %w", name, err) } c := &Comfy{ instance: name, base: base, workflow: workflow, rawCfg: cfg, defaultSteps: getInt(cfg, "default_steps", 4), defaultSampler: getString(cfg, "default_sampler", "euler"), defaultScheduler: getString(cfg, "default_scheduler", "simple"), defaultCFG: getFloat(cfg, "default_cfg", 1.0), httpClient: &http.Client{Timeout: 60 * time.Second}, pollInterval: 250 * time.Millisecond, pollTimeout: 300 * time.Second, randSeed: cryptoSeed, clientIDFn: randClientID, } return c, nil } // Name returns the instance name from imagen.yaml. func (c *Comfy) Name() string { return c.instance } // Generate submits one workflow to ComfyUI, waits for it to render, and // returns the resulting PNG. 
func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) { width := orDefaultInt(req.Width, 1024) height := orDefaultInt(req.Height, 1024) steps := orDefaultInt(req.Steps, c.defaultSteps) sampler := c.defaultSampler scheduler := c.defaultScheduler cfg := c.defaultCFG if v, ok := req.BackendOpts["sampler"].(string); ok && v != "" { sampler = v } if v, ok := req.BackendOpts["scheduler"].(string); ok && v != "" { scheduler = v } if v, ok := req.BackendOpts["cfg"].(float64); ok && v > 0 { cfg = v } seed := req.Seed if seed == 0 { seed = c.randSeed() } workflow, err := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler, cfg) if err != nil { return nil, fmt.Errorf("comfyui[%s]: build workflow: %w", c.instance, err) } clientID := c.clientIDFn() start := time.Now() promptID, err := c.submitPrompt(ctx, workflow, clientID) if err != nil { return nil, err } filename, err := c.waitForCompletion(ctx, promptID) if err != nil { return nil, err } imgBytes, err := c.fetchImage(ctx, filename) if err != nil { return nil, err } latencyMs := time.Since(start).Milliseconds() model := getString(c.rawCfg, "model", "") meta := map[string]any{ "backend": c.instance, "backend_type": ComfyType, "workflow": c.workflow, "model": model, "seed": seed, "steps": steps, "sampler": sampler, "scheduler": scheduler, "cfg": cfg, "width": width, "height": height, "latency_ms": latencyMs, "prompt_id": promptID, "client_id": clientID, } if vram := c.vramUsedMiB(ctx); vram > 0 { meta["vram_used_mib"] = vram } return &Result{ ImageReader: io.NopCloser(bytes.NewReader(imgBytes)), MimeType: "image/png", Metadata: meta, }, nil } // submitPrompt POSTs the workflow and extracts the prompt_id. // // Retries once on a 5xx or transient network error. 4xx responses are not // retried — they are treated as configuration bugs (missing model, bad // workflow shape, etc.) and surfaced with a hint pointing at the docs when // the body matches a known pattern. 
func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clientID string) (string, error) { body, err := json.Marshal(map[string]any{ "prompt": workflow, "client_id": clientID, }) if err != nil { return "", fmt.Errorf("comfyui: marshal workflow: %w", err) } model := getString(c.rawCfg, "model", "") var lastErr error for attempt := range 2 { if attempt > 0 { select { case <-ctx.Done(): return "", ctx.Err() case <-time.After(time.Second): } } req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/prompt", bytes.NewReader(body)) if err != nil { return "", err } req.Header.Set("Content-Type", "application/json") resp, err := c.httpClient.Do(req) if err != nil { lastErr = c.connError(err) continue } respBody, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() switch { case resp.StatusCode >= 200 && resp.StatusCode < 300: return parsePromptID(respBody, model) case resp.StatusCode >= 500: lastErr = fmt.Errorf("comfyui /prompt %d: %s", resp.StatusCode, snip(respBody)) continue default: return "", c.classifyBadRequest(resp.StatusCode, respBody) } } return "", lastErr } // waitForCompletion polls /history/{id} until the prompt finishes and // returns the filename of the produced image. 
func (c *Comfy) waitForCompletion(ctx context.Context, promptID string) (string, error) { deadline := time.Now().Add(c.pollTimeout) for { select { case <-ctx.Done(): return "", ctx.Err() default: } if time.Now().After(deadline) { return "", fmt.Errorf("comfyui: prompt %s did not complete within %s", promptID, c.pollTimeout) } req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/history/"+promptID, nil) if err != nil { return "", err } resp, err := c.httpClient.Do(req) if err != nil { return "", c.connError(err) } body, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("comfyui /history/%s %d: %s", promptID, resp.StatusCode, snip(body)) } filename, done, err := parseHistory(body, promptID) if err != nil { return "", err } if done { return filename, nil } select { case <-ctx.Done(): return "", ctx.Err() case <-time.After(c.pollInterval): } } } // fetchImage downloads the produced image bytes via /view. func (c *Comfy) fetchImage(ctx context.Context, filename string) ([]byte, error) { q := url.Values{ "filename": {filename}, "type": {"output"}, } req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/view?"+q.Encode(), nil) if err != nil { return nil, err } resp, err := c.httpClient.Do(req) if err != nil { return nil, c.connError(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("comfyui /view %d: %s", resp.StatusCode, snip(body)) } return io.ReadAll(resp.Body) } // vramUsedMiB returns total - free VRAM on device 0 from /system_stats, or // 0 if the endpoint isn't available. Best-effort, never an error. 
func (c *Comfy) vramUsedMiB(ctx context.Context) int64 { req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/system_stats", nil) if err != nil { return 0 } resp, err := c.httpClient.Do(req) if err != nil { return 0 } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return 0 } var s struct { Devices []struct { VRAMTotal int64 `json:"vram_total"` VRAMFree int64 `json:"vram_free"` } `json:"devices"` } if err := json.NewDecoder(resp.Body).Decode(&s); err != nil { return 0 } if len(s.Devices) == 0 { return 0 } used := s.Devices[0].VRAMTotal - s.Devices[0].VRAMFree if used < 0 { return 0 } return used / (1024 * 1024) } // connError translates a Go networking error into a user-actionable message, // pointing at the boot-whitetower script when mRock looks asleep. func (c *Comfy) connError(err error) error { if err == nil { return nil } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return err } msg := err.Error() var opErr *net.OpError asOp := errors.As(err, &opErr) switch { case asOp, strings.Contains(msg, "connection refused"), strings.Contains(msg, "no such host"), strings.Contains(msg, "no route to host"), strings.Contains(msg, "network is unreachable"), strings.Contains(msg, "i/o timeout"): return fmt.Errorf("comfyui at %s unreachable (%v) — if mRock is asleep, run: boot-whitetower mrock", c.base, err) } return fmt.Errorf("comfyui at %s: %w", c.base, err) } // classifyBadRequest interprets a 4xx body. Some ComfyUI builds use 400 for // workflow-validation failures and put the diagnostics in node_errors; older // builds use 200 + node_errors. This handles the 4xx flavour. 
func (c *Comfy) classifyBadRequest(status int, body []byte) error { model := getString(c.rawCfg, "model", "") if hint, ok := missingModelHint(body, model); ok { return fmt.Errorf("comfyui /prompt %d: %s — see docs/backends.md", status, hint) } return fmt.Errorf("comfyui /prompt %d: %s", status, snip(body)) } // buildWorkflow loads the configured workflow template and substitutes the // per-call placeholders (prompt, seed, sampler, …) plus any string/number // fields the user defined in the yaml block. The set of placeholder keys // that aren't in `subs` produces an error from SubstituteWorkflow. func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string, cfg float64) (map[string]any, error) { wf, err := LoadWorkflowTemplate(c.workflow) if err != nil { return nil, err } subs := map[string]any{ "prompt": prompt, "negative": negative, "width": w, "height": h, "seed": seed, "steps": steps, "sampler": sampler, "scheduler": scheduler, "cfg": cfg, } // Surface every scalar field from the yaml block so per-template knobs // (vae, clip, clip_l, clip_t5, dtype, shift, guidance, …) work without // adapter-code changes. Framework keys are excluded. for k, v := range c.rawCfg { switch k { case "type", "base_url", "workflow", "default_steps", "default_sampler", "default_scheduler", "default_cfg": continue } if _, alreadySet := subs[k]; alreadySet { // A per-call var (e.g. ${prompt}) beats anything yaml put under // the same key — yaml can't shadow request-derived values. continue } switch v := v.(type) { case string, int, int64, float64, bool: subs[k] = v } } // Provide sensible defaults for common optional knobs so a workflow that // references one of these doesn't fail substitution when the user // didn't override it in yaml. Extra keys are ignored if the workflow // doesn't reference them, so it's safe to always set the lot. 
defaults := map[string]any{ "vae": "ae.safetensors", "clip_l": "clip_l.safetensors", "clip_t5": "t5xxl_fp8_e4m3fn.safetensors", "clip": "qwen_3_4b.safetensors", "dtype": "fp8_e4m3fn", "guidance": 4.0, "shift": 3.0, } for k, v := range defaults { if _, ok := subs[k]; !ok { subs[k] = v } } if _, err := SubstituteWorkflow(wf, subs); err != nil { return nil, err } return wf, nil } // parsePromptID handles the 2xx /prompt response. ComfyUI sometimes 200s a // validation failure and stuffs node_errors in the body — this function // turns that into the same user-facing error as a 4xx with the same body. func parsePromptID(body []byte, model string) (string, error) { var resp struct { PromptID string `json:"prompt_id"` NodeErrors map[string]any `json:"node_errors"` Error json.RawMessage `json:"error"` } if err := json.Unmarshal(body, &resp); err != nil { return "", fmt.Errorf("comfyui /prompt: parse response: %w (body: %s)", err, snip(body)) } if len(resp.NodeErrors) > 0 || len(resp.Error) > 0 { if hint, ok := missingModelHint(body, model); ok { return "", fmt.Errorf("comfyui /prompt: %s — see docs/backends.md", hint) } return "", fmt.Errorf("comfyui /prompt rejected workflow: %s", snip(body)) } if resp.PromptID == "" { return "", fmt.Errorf("comfyui /prompt: empty prompt_id (body: %s)", snip(body)) } return resp.PromptID, nil } // parseHistory inspects a /history/{id} body and returns either the produced // filename + done=true, or done=false to signal "keep polling". 
func parseHistory(body []byte, promptID string) (string, bool, error) { var entries map[string]struct { Status struct { Completed bool `json:"completed"` StatusStr string `json:"status_str"` } `json:"status"` Outputs map[string]struct { Images []struct { Filename string `json:"filename"` Subfolder string `json:"subfolder"` Type string `json:"type"` } `json:"images"` } `json:"outputs"` } if err := json.Unmarshal(body, &entries); err != nil { return "", false, fmt.Errorf("comfyui /history: parse: %w (body: %s)", err, snip(body)) } e, ok := entries[promptID] if !ok { return "", false, nil } if e.Status.StatusStr == "error" { return "", false, fmt.Errorf("comfyui prompt %s errored: %s", promptID, snip(body)) } if !e.Status.Completed { return "", false, nil } for _, out := range e.Outputs { if len(out.Images) > 0 { return out.Images[0].Filename, true, nil } } return "", true, fmt.Errorf("comfyui prompt %s completed but produced no images", promptID) } // missingModelHint returns a user-actionable message when the response body // indicates the configured unet/checkpoint model isn't loaded on the server. // ComfyUI uses both the human-readable "Value not in list" message and the // enum "value_not_in_list" type — match either. 
func missingModelHint(body []byte, model string) (string, bool) { s := string(body) hasMarker := strings.Contains(s, "Value not in list") || strings.Contains(s, "value_not_in_list") if !hasMarker { return "", false } if strings.Contains(s, "unet_name") { return fmt.Sprintf("model %q not present in the ComfyUI server's models/unet/", model), true } if strings.Contains(s, "ckpt_name") { return fmt.Sprintf("checkpoint %q not present in the ComfyUI server's models/checkpoints/", model), true } return "", false } func cryptoSeed() int64 { var b [8]byte if _, err := rand.Read(b[:]); err != nil { return time.Now().UnixNano() } return int64(binary.BigEndian.Uint64(b[:]) >> 1) } func randClientID() string { var b [8]byte _, _ = rand.Read(b[:]) return fmt.Sprintf("imagen-%x", b) } func getString(m map[string]any, k, def string) string { if v, ok := m[k].(string); ok && v != "" { return v } return def } func getInt(m map[string]any, k string, def int) int { if v, ok := m[k]; ok { switch n := v.(type) { case int: return n case int64: return int(n) case float64: return int(n) } } return def } func getFloat(m map[string]any, k string, def float64) float64 { if v, ok := m[k]; ok { switch n := v.(type) { case float64: return n case float32: return float64(n) case int: return float64(n) case int64: return float64(n) } } return def } func orDefaultInt(v, def int) int { if v == 0 { return def } return v } func snip(b []byte) string { const max = 500 s := strings.TrimSpace(string(b)) if len(s) > max { s = s[:max] + "..." } return s } func init() { Register(ComfyType, NewComfy) }