Files
ImaGen/internal/backend/comfyui.go
mAi 127bbf3ed5 mAi: #2 - phase 2 ComfyUI Go adapter, tests, config sample
internal/backend/comfyui.go implements the Backend interface against
ComfyUI's /prompt + /history + /view HTTP API. Workflow is the canonical
FLUX.1 schnell shape — UNETLoader + DualCLIPLoader (clip_l + t5xxl fp8) +
VAELoader + ModelSamplingFlux + KSampler — assembled as a Go map per
request so Width / Height / Seed / Steps / sampler / scheduler all flow
into the right node inputs.

Resilience: one retry on /prompt 5xx responses and transient network
errors; no retry on 4xx. Connection-refused errors and timeouts surface a
'boot-whitetower mrock' hint. node_errors mentioning a missing unet model
point users to docs/setup-comfyui-mrock.md (this matches both the 4xx and
the 200-with-errors shapes that ComfyUI uses across versions).

Result.Metadata carries backend, backend_type, model, seed, latency_ms,
steps, sampler, scheduler, width, height, prompt_id, client_id, plus a
best-effort vram_used_mib pulled from /system_stats after generation.

Tests use httptest with poll interval squashed to 1ms — no real mRock
dependency. Coverage: happy path, defaults, retry-once on 5xx, give-up
after two 5xx, no-retry on 4xx, missing-model hint (both 4xx and
200+node_errors paths), history-error surfaced, /view 4xx, unreachable
host, ctx cancel during poll, workflow-shape assertion, registration.

Config sample: flux-schnell-local is now default_backend; the user-facing
block names the unet file by basename (the mapping into models/unet/ is
the server's convention, captured in docs/setup-comfyui-mrock.md from
phase 1).

Smoke verified end-to-end: imagen generate ... --backend
flux-schnell-local --size 1024x1024 --output /tmp/cat-via-cli.png on
mRock returned a 1024x1024 PNG of a cat in a fishbowl in 10.3s with a
sidecar carrying seed + latency_ms + the rest of the metadata.
2026-05-08 16:59:21 +02:00

558 lines
15 KiB
Go

package backend
import (
	"bytes"
	"context"
	"crypto/rand"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"
	"unicode/utf8"
)
const (
	// ComfyType is the adapter type name under which ComfyUI-backed
	// instances are registered and selected in imagen.yaml.
	ComfyType = "comfyui"
)
// Comfy is the ComfyUI adapter. It talks to the public `/prompt`,
// `/history` and `/view` HTTP endpoints and submits a fixed FLUX.1 schnell
// workflow whose parameters come from each Request.
//
// Concurrency: Generate keeps no long-lived state on the struct, so one Comfy
// may be shared between goroutines provided its http.Client may be too.
type Comfy struct {
	// Identity and server location.
	instance string // instance name from imagen.yaml
	base     string // base_url with trailing "/" removed

	// Model files wired into the workflow's loader nodes.
	model  string // UNETLoader unet_name
	vae    string // VAELoader vae_name
	clipL  string // DualCLIPLoader clip_name2
	clipT5 string // DualCLIPLoader clip_name1
	dtype  string // UNETLoader weight_dtype

	// Sampling defaults applied when the Request leaves a value unset.
	defaultSteps     int
	defaultSampler   string
	defaultScheduler string

	// Transport and polling behaviour.
	httpClient   *http.Client
	pollInterval time.Duration // pause between /history polls
	pollTimeout  time.Duration // overall budget for one prompt to finish

	// Test hooks; NewComfy installs the production implementations.
	randSeed   func() int64  // supplies a seed when Request.Seed is 0
	clientIDFn func() string // generates the ComfyUI client_id
}
// NewComfy is the registry constructor. cfg is the adapter's slice of
// imagen.yaml.
//
// Required keys: base_url (an absolute http:// or https:// URL; trailing
// slashes are stripped) and model (the unet file name). Optional keys and
// their defaults: vae ("ae.safetensors"), clip_l ("clip_l.safetensors"),
// clip_t5 ("t5xxl_fp8_e4m3fn.safetensors"), weight_dtype ("fp8_e4m3fn"),
// default_steps (4), default_sampler ("euler"), default_scheduler ("simple").
//
// An error is returned for an empty instance name, a missing or malformed
// base_url, or a missing model.
func NewComfy(name string, cfg map[string]any) (Backend, error) {
	if name == "" {
		return nil, fmt.Errorf("comfyui: empty instance name")
	}
	base := strings.TrimRight(getString(cfg, "base_url", ""), "/")
	if base == "" {
		return nil, fmt.Errorf("comfyui[%s]: base_url is required", name)
	}
	u, err := url.Parse(base)
	if err != nil {
		return nil, fmt.Errorf("comfyui[%s]: base_url %q invalid: %w", name, base, err)
	}
	// url.Parse happily accepts scheme-less values such as "mrock:8188"
	// (parsed as scheme "mrock" with no host), which would otherwise only
	// fail at request time with "unsupported protocol scheme". Catch it here.
	if (u.Scheme != "http" && u.Scheme != "https") || u.Host == "" {
		return nil, fmt.Errorf("comfyui[%s]: base_url %q must be an absolute http:// or https:// URL", name, base)
	}
	model := getString(cfg, "model", "")
	if model == "" {
		return nil, fmt.Errorf("comfyui[%s]: model is required", name)
	}
	c := &Comfy{
		instance:         name,
		base:             base,
		model:            model,
		vae:              getString(cfg, "vae", "ae.safetensors"),
		clipL:            getString(cfg, "clip_l", "clip_l.safetensors"),
		clipT5:           getString(cfg, "clip_t5", "t5xxl_fp8_e4m3fn.safetensors"),
		dtype:            getString(cfg, "weight_dtype", "fp8_e4m3fn"),
		defaultSteps:     getInt(cfg, "default_steps", 4),
		defaultSampler:   getString(cfg, "default_sampler", "euler"),
		defaultScheduler: getString(cfg, "default_scheduler", "simple"),
		// Per-request HTTP timeout; the overall render wait is bounded
		// separately by pollTimeout.
		httpClient:   &http.Client{Timeout: 60 * time.Second},
		pollInterval: 250 * time.Millisecond,
		pollTimeout:  120 * time.Second,
		randSeed:     cryptoSeed,
		clientIDFn:   randClientID,
	}
	return c, nil
}
// Name reports the instance name this adapter was configured under in
// imagen.yaml.
func (c *Comfy) Name() string {
	return c.instance
}
// Generate renders a single image for req on the configured ComfyUI server.
//
// It builds the FLUX workflow, POSTs it to /prompt, polls /history until the
// prompt finishes, then downloads the image via /view. Zero Width/Height fall
// back to 1024 and zero Steps to the instance default. BackendOpts["sampler"]
// and BackendOpts["scheduler"] override the defaults when they are non-empty
// strings. A zero Seed is replaced by a random one, and the seed actually
// used is reported in the metadata, together with the sampling parameters,
// latency, the ComfyUI prompt/client IDs and, when known, VRAM usage.
func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) {
	width, height := orDefaultInt(req.Width, 1024), orDefaultInt(req.Height, 1024)
	steps := orDefaultInt(req.Steps, c.defaultSteps)

	sampler, scheduler := c.defaultSampler, c.defaultScheduler
	if name, ok := req.BackendOpts["sampler"].(string); ok && name != "" {
		sampler = name
	}
	if name, ok := req.BackendOpts["scheduler"].(string); ok && name != "" {
		scheduler = name
	}

	seed := req.Seed
	if seed == 0 {
		seed = c.randSeed()
	}

	wf := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler)
	clientID := c.clientIDFn()

	// Latency spans submit + render + download, not workflow construction.
	t0 := time.Now()
	promptID, err := c.submitPrompt(ctx, wf, clientID)
	if err != nil {
		return nil, err
	}
	outFile, err := c.waitForCompletion(ctx, promptID)
	if err != nil {
		return nil, err
	}
	data, err := c.fetchImage(ctx, outFile)
	if err != nil {
		return nil, err
	}
	elapsedMs := time.Since(t0).Milliseconds()

	metadata := map[string]any{
		"backend":      c.instance,
		"backend_type": ComfyType,
		"model":        c.model,
		"seed":         seed,
		"steps":        steps,
		"sampler":      sampler,
		"scheduler":    scheduler,
		"width":        width,
		"height":       height,
		"latency_ms":   elapsedMs,
		"prompt_id":    promptID,
		"client_id":    clientID,
	}
	// VRAM reporting is best-effort: a zero reading is simply left out.
	if mib := c.vramUsedMiB(ctx); mib > 0 {
		metadata["vram_used_mib"] = mib
	}

	res := &Result{
		ImageReader: io.NopCloser(bytes.NewReader(data)),
		MimeType:    "image/png",
		Metadata:    metadata,
	}
	return res, nil
}
// submitPrompt POSTs the workflow to /prompt and returns the prompt_id the
// server assigned.
//
// A transport error or 5xx response triggers exactly one retry after a
// one-second pause (abandoned if ctx is cancelled). Any other non-2xx status
// is treated as a configuration problem (missing model, malformed workflow,
// ...) and reported immediately via classifyBadRequest. 2xx bodies are
// handed to parsePromptID, which also catches 200-with-node_errors replies.
//
// NOTE(review): if the first POST reached the server but its response was
// lost, the retry may enqueue the same workflow a second time.
func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clientID string) (string, error) {
	payload, err := json.Marshal(map[string]any{
		"prompt":    workflow,
		"client_id": clientID,
	})
	if err != nil {
		return "", fmt.Errorf("comfyui: marshal workflow: %w", err)
	}

	const attempts = 2
	var lastErr error
	for try := 0; try < attempts; try++ {
		if try != 0 {
			select {
			case <-time.After(time.Second):
			case <-ctx.Done():
				return "", ctx.Err()
			}
		}

		httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/prompt", bytes.NewReader(payload))
		if err != nil {
			return "", err
		}
		httpReq.Header.Set("Content-Type", "application/json")

		resp, err := c.httpClient.Do(httpReq)
		if err != nil {
			// Network-level failure: remember it and retry.
			lastErr = c.connError(err)
			continue
		}
		// A read failure just leaves a truncated body; the status code
		// still drives the decision below.
		respBody, _ := io.ReadAll(resp.Body)
		_ = resp.Body.Close()

		code := resp.StatusCode
		if code >= 200 && code < 300 {
			return parsePromptID(respBody, c.model)
		}
		if code < 500 {
			return "", c.classifyBadRequest(code, respBody)
		}
		lastErr = fmt.Errorf("comfyui /prompt %d: %s", code, snip(respBody))
	}
	return "", lastErr
}
// waitForCompletion polls GET /history/{promptID} every pollInterval until
// the prompt is reported complete, and returns the filename of the first
// image it produced.
//
// It gives up with an error when ctx is done (checked before each poll and
// while sleeping), when pollTimeout elapses, on any transport error or
// non-200 response, or when parseHistory reports a failure.
func (c *Comfy) waitForCompletion(ctx context.Context, promptID string) (string, error) {
	historyURL := c.base + "/history/" + promptID
	deadline := time.Now().Add(c.pollTimeout)
	for {
		if err := ctx.Err(); err != nil {
			return "", err
		}
		if time.Now().After(deadline) {
			return "", fmt.Errorf("comfyui: prompt %s did not complete within %s", promptID, c.pollTimeout)
		}

		httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, historyURL, nil)
		if err != nil {
			return "", err
		}
		resp, err := c.httpClient.Do(httpReq)
		if err != nil {
			return "", c.connError(err)
		}
		payload, _ := io.ReadAll(resp.Body)
		_ = resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			return "", fmt.Errorf("comfyui /history/%s %d: %s", promptID, resp.StatusCode, snip(payload))
		}

		filename, finished, err := parseHistory(payload, promptID)
		if err != nil {
			return "", err
		}
		if finished {
			return filename, nil
		}

		select {
		case <-time.After(c.pollInterval):
		case <-ctx.Done():
			return "", ctx.Err()
		}
	}
}
// fetchImage downloads the bytes of an output image through
// GET /view?filename=...&type=output. A non-200 status is turned into an
// error carrying a snippet of the response body.
func (c *Comfy) fetchImage(ctx context.Context, filename string) ([]byte, error) {
	params := url.Values{}
	params.Set("filename", filename)
	params.Set("type", "output")
	endpoint := c.base + "/view?" + params.Encode()

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, c.connError(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		return io.ReadAll(resp.Body)
	}
	errBody, _ := io.ReadAll(resp.Body)
	return nil, fmt.Errorf("comfyui /view %d: %s", resp.StatusCode, snip(errBody))
}
// vramUsedMiB asks /system_stats for the first listed device and returns
// (vram_total - vram_free) / 2^20, i.e. MiB in use assuming the server
// reports bytes. Any failure (transport, non-200, undecodable body, no
// devices, negative difference) yields 0; it never returns an error.
func (c *Comfy) vramUsedMiB(ctx context.Context) int64 {
	const mib = 1 << 20

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/system_stats", nil)
	if err != nil {
		return 0
	}
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return 0
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return 0
	}

	var stats struct {
		Devices []struct {
			VRAMTotal int64 `json:"vram_total"`
			VRAMFree  int64 `json:"vram_free"`
		} `json:"devices"`
	}
	if json.NewDecoder(resp.Body).Decode(&stats) != nil || len(stats.Devices) == 0 {
		return 0
	}
	gpu := stats.Devices[0]
	if used := gpu.VRAMTotal - gpu.VRAMFree; used >= 0 {
		return used / mib
	}
	return 0
}
// connError turns a transport error from the HTTP client into a message the
// user can act on.
//
// nil stays nil, and context cancellation/deadline errors are returned
// unchanged so callers can still match them with errors.Is. Errors that are
// a *net.OpError, or whose text shows the host could not be reached, get a
// hint to wake mRock with boot-whitetower. Everything else is wrapped with
// the server's base URL.
func (c *Comfy) connError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return err
	}

	var opErr *net.OpError
	unreachable := errors.As(err, &opErr)
	if !unreachable {
		text := err.Error()
		for _, symptom := range []string{
			"connection refused",
			"no such host",
			"no route to host",
			"network is unreachable",
			"i/o timeout",
		} {
			if strings.Contains(text, symptom) {
				unreachable = true
				break
			}
		}
	}

	if unreachable {
		return fmt.Errorf("comfyui at %s unreachable (%v) — if mRock is asleep, run: boot-whitetower mrock", c.base, err)
	}
	return fmt.Errorf("comfyui at %s: %w", c.base, err)
}
// classifyBadRequest builds the error for a non-2xx, non-5xx /prompt reply.
// Some ComfyUI builds answer workflow-validation failures with a 400 and put
// the diagnostics in node_errors (others reply 200; see parsePromptID). When
// the body indicates the configured unet model is missing, the error points
// at the setup docs; otherwise it includes a snippet of the body.
func (c *Comfy) classifyBadRequest(status int, body []byte) error {
	hint, isMissingModel := missingModelHint(body, c.model)
	if !isMissingModel {
		return fmt.Errorf("comfyui /prompt %d: %s", status, snip(body))
	}
	return fmt.Errorf("comfyui /prompt %d: %s — see docs/setup-comfyui-mrock.md", status, hint)
}
// buildWorkflow returns the FLUX.1 schnell graph in ComfyUI's API format: a
// map from node ID to {"class_type", "inputs"}, where an input of the form
// [nodeID, slot] links to another node's output. Node IDs follow the
// upstream "flux-schnell" template so the graph looks familiar when debugged
// in the ComfyUI UI.
//
// Pipeline: UNETLoader(12) -> ModelSamplingFlux(30) -> KSampler(31), with
// DualCLIPLoader(11) feeding the positive (6) and negative (13) text
// encoders, EmptySD3LatentImage(27) as the starting latent, and
// VAEDecode(8) + SaveImage(9) writing the result with prefix "imagen".
func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string) map[string]any {
	// link references output slot `slot` of node `id`.
	link := func(id string, slot int) []any { return []any{id, slot} }
	node := func(class string, inputs map[string]any) map[string]any {
		return map[string]any{"class_type": class, "inputs": inputs}
	}

	wf := make(map[string]any, 10)

	// Model files.
	wf["12"] = node("UNETLoader", map[string]any{
		"unet_name":    c.model,
		"weight_dtype": c.dtype,
	})
	wf["11"] = node("DualCLIPLoader", map[string]any{
		"clip_name1": c.clipT5,
		"clip_name2": c.clipL,
		"type":       "flux",
	})
	wf["10"] = node("VAELoader", map[string]any{"vae_name": c.vae})

	// Conditioning.
	wf["6"] = node("CLIPTextEncode", map[string]any{"text": prompt, "clip": link("11", 0)})
	wf["13"] = node("CLIPTextEncode", map[string]any{"text": negative, "clip": link("11", 0)})

	// Latent canvas and FLUX-specific model sampling.
	wf["27"] = node("EmptySD3LatentImage", map[string]any{
		"width":      w,
		"height":     h,
		"batch_size": 1,
	})
	wf["30"] = node("ModelSamplingFlux", map[string]any{
		"model":      link("12", 0),
		"max_shift":  1.15,
		"base_shift": 0.5,
		"width":      w,
		"height":     h,
	})

	// Denoising.
	wf["31"] = node("KSampler", map[string]any{
		"model":        link("30", 0),
		"seed":         seed,
		"steps":        steps,
		"cfg":          1.0,
		"sampler_name": sampler,
		"scheduler":    scheduler,
		"denoise":      1.0,
		"positive":     link("6", 0),
		"negative":     link("13", 0),
		"latent_image": link("27", 0),
	})

	// Decode and save.
	wf["8"] = node("VAEDecode", map[string]any{
		"samples": link("31", 0),
		"vae":     link("10", 0),
	})
	wf["9"] = node("SaveImage", map[string]any{
		"filename_prefix": "imagen",
		"images":          link("8", 0),
	})

	return wf
}
// parsePromptID decodes a 2xx /prompt response and returns its prompt_id.
//
// Some ComfyUI builds answer a validation failure with 200 and put the
// diagnostics in node_errors and/or error. Such bodies are turned into the
// same user-facing error a 4xx would produce, including the missing-model
// docs hint when it applies. An undecodable body or an empty prompt_id is
// also an error.
func parsePromptID(body []byte, model string) (string, error) {
	var resp struct {
		PromptID   string          `json:"prompt_id"`
		NodeErrors map[string]any  `json:"node_errors"`
		Error      json.RawMessage `json:"error"`
	}
	if err := json.Unmarshal(body, &resp); err != nil {
		return "", fmt.Errorf("comfyui /prompt: parse response: %w (body: %s)", err, snip(body))
	}
	// json.RawMessage stores a JSON null as the literal bytes "null", so an
	// explicit `"error": null` must not be mistaken for a rejection.
	hasError := len(resp.Error) > 0 && string(resp.Error) != "null"
	if len(resp.NodeErrors) > 0 || hasError {
		if hint, ok := missingModelHint(body, model); ok {
			return "", fmt.Errorf("comfyui /prompt: %s — see docs/setup-comfyui-mrock.md", hint)
		}
		return "", fmt.Errorf("comfyui /prompt rejected workflow: %s", snip(body))
	}
	if resp.PromptID == "" {
		return "", fmt.Errorf("comfyui /prompt: empty prompt_id (body: %s)", snip(body))
	}
	return resp.PromptID, nil
}
// parseHistory examines a /history/{id} body for promptID.
//
// Return values:
//   - ("", false, nil): the prompt is not in the body yet, or not completed
//     (keep polling);
//   - (filename, true, nil): completed, filename is the first image found
//     across the output nodes;
//   - ("", false, err): the body is not valid JSON, or status_str is "error";
//   - ("", true, err): completed without producing any image.
func parseHistory(body []byte, promptID string) (string, bool, error) {
	type image struct {
		Filename  string `json:"filename"`
		Subfolder string `json:"subfolder"`
		Type      string `json:"type"`
	}
	type nodeOutput struct {
		Images []image `json:"images"`
	}
	type historyEntry struct {
		Status struct {
			Completed bool   `json:"completed"`
			StatusStr string `json:"status_str"`
		} `json:"status"`
		Outputs map[string]nodeOutput `json:"outputs"`
	}

	var history map[string]historyEntry
	if err := json.Unmarshal(body, &history); err != nil {
		return "", false, fmt.Errorf("comfyui /history: parse: %w (body: %s)", err, snip(body))
	}

	entry, found := history[promptID]
	switch {
	case !found:
		return "", false, nil
	case entry.Status.StatusStr == "error":
		return "", false, fmt.Errorf("comfyui prompt %s errored: %s", promptID, snip(body))
	case !entry.Status.Completed:
		return "", false, nil
	}

	for _, out := range entry.Outputs {
		if imgs := out.Images; len(imgs) > 0 {
			return imgs[0].Filename, true, nil
		}
	}
	return "", true, fmt.Errorf("comfyui prompt %s completed but produced no images", promptID)
}
// missingModelHint detects a ComfyUI validation error saying the configured
// unet model is not available. It requires the body to mention "unet_name"
// together with either the human-readable "Value not in list" message or
// the "value_not_in_list" error type. On a match it returns a hint naming
// the model and true; otherwise "" and false.
func missingModelHint(body []byte, model string) (string, bool) {
	text := string(body)
	if !strings.Contains(text, "unet_name") {
		return "", false
	}
	for _, marker := range []string{"Value not in list", "value_not_in_list"} {
		if strings.Contains(text, marker) {
			return fmt.Sprintf("model %q not present in the ComfyUI server's models/unet/", model), true
		}
	}
	return "", false
}
// cryptoSeed draws a non-negative random int64 from crypto/rand. If the
// system source fails, it falls back to the current time in nanoseconds.
func cryptoSeed() int64 {
	buf := make([]byte, 8)
	if _, err := rand.Read(buf); err != nil {
		return time.Now().UnixNano()
	}
	// Dropping the top bit keeps the value within the non-negative int64 range.
	return int64(binary.BigEndian.Uint64(buf) >> 1)
}
// randClientID returns a ComfyUI client_id of the form "imagen-" followed by
// 16 lowercase hex digits drawn from crypto/rand.
func randClientID() string {
	raw := make([]byte, 8)
	// Error deliberately ignored: on failure raw stays zero-filled, which
	// still yields a well-formed (if predictable) ID.
	_, _ = rand.Read(raw)
	return fmt.Sprintf("imagen-%x", raw)
}
// getString returns m[k] when it is a non-empty string, and def otherwise
// (missing key, nil map, empty string, or a non-string value).
func getString(m map[string]any, k, def string) string {
	s, isString := m[k].(string)
	if !isString || s == "" {
		return def
	}
	return s
}
// getInt reads an integer option from m, returning def when the key is
// missing (or m is nil) or holds a non-numeric value.
//
// Config decoders disagree on numeric types (e.g. int, int64, float64 or
// json.Number depending on the format and library), so every built-in
// integer and float kind is accepted. Floats are truncated toward zero.
// Unsigned values, and json.Number values that are not valid int64, fall
// back to def when they cannot be represented.
func getInt(m map[string]any, k string, def int) int {
	const maxInt = uint64(^uint(0) >> 1)
	switch n := m[k].(type) {
	case int:
		return n
	case int8:
		return int(n)
	case int16:
		return int(n)
	case int32:
		return int(n)
	case int64:
		return int(n)
	case uint8:
		return int(n)
	case uint16:
		return int(n)
	case uint32:
		if uint64(n) <= maxInt {
			return int(n)
		}
	case uint:
		if uint64(n) <= maxInt {
			return int(n)
		}
	case uint64:
		if n <= maxInt {
			return int(n)
		}
	case float32:
		return int(n)
	case float64:
		return int(n)
	case json.Number:
		if i, err := n.Int64(); err == nil {
			return int(i)
		}
	}
	return def
}
// orDefaultInt treats zero as "unset": it returns v unless v is 0, in which
// case def is returned. Negative values pass through unchanged.
func orDefaultInt(v, def int) int {
	if v != 0 {
		return v
	}
	return def
}
// snip prepares a response body for inclusion in an error message. It trims
// surrounding whitespace and, when more than 500 bytes remain, keeps the
// first ~500 bytes followed by "...".
//
// The cut is moved back to the start of the rune that straddles the limit,
// so a multi-byte UTF-8 character is never split and the message stays
// valid UTF-8. The back-off is bounded so arbitrary binary bodies cannot
// shrink the snippet further.
func snip(b []byte) string {
	const limit = 500 // named to avoid shadowing the built-in max
	s := strings.TrimSpace(string(b))
	if len(s) <= limit {
		return s
	}
	cut := limit
	for cut > limit-utf8.UTFMax+1 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut] + "..."
}
// init registers NewComfy as the constructor for the ComfyType ("comfyui")
// adapter type when the package is loaded.
func init() {
Register(ComfyType, NewComfy)
}