mAi: #2 - phase 2 ComfyUI Go adapter, tests, config sample
internal/backend/comfyui.go implements the Backend interface against ComfyUI's /prompt + /history + /view HTTP API. Workflow is the canonical FLUX.1 schnell shape — UNETLoader + DualCLIPLoader (clip_l + t5xxl fp8) + VAELoader + ModelSamplingFlux + KSampler — assembled as a Go map per request so Width / Height / Seed / Steps / sampler / scheduler all flow into the right node inputs. Resilience: one retry on /prompt 5xx and transient network errors, no retry on 4xx. Connection-refused / timeouts surface a 'boot-whitetower mrock' hint. node_errors mentioning a missing unet point users at docs/setup-comfyui-mrock.md (matches both the 4xx and 200-with-errors shapes ComfyUI uses across versions). Result.Metadata carries model, seed_used, latency_ms, steps, sampler, scheduler, width, height, prompt_id, client_id, plus best-effort vram_used_mib pulled from /system_stats post-gen. Tests use httptest with poll interval squashed to 1ms — no real mRock dependency. Coverage: happy path, defaults, retry-once on 5xx, give-up after two 5xx, no-retry on 4xx, missing-model hint (both 4xx and 200+node_errors paths), history-error surfaced, /view 4xx, unreachable host, ctx cancel during poll, workflow-shape assertion, registration. Config sample: flux-schnell-local is now default_backend; the user-facing block names the unet file by basename (the mapping into models/unet/ is the server's convention, captured in docs/setup-comfyui-mrock.md from phase 1). Smoke verified end-to-end: imagen generate ... --backend flux-schnell-local --size 1024x1024 --output /tmp/cat-via-cli.png on mRock returned a 1024x1024 PNG of a cat in a fishbowl in 10.3s with a sidecar carrying seed + latency_ms + the rest of the metadata.
This commit is contained in:
557
internal/backend/comfyui.go
Normal file
557
internal/backend/comfyui.go
Normal file
@@ -0,0 +1,557 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ComfyType is the type-name adapters register under for ComfyUI instances.
|
||||||
|
const ComfyType = "comfyui"
|
||||||
|
|
||||||
|
// Comfy is the ComfyUI adapter. It speaks the public `/prompt` + `/history`
|
||||||
|
// + `/view` HTTP API and submits a fixed FLUX.1 schnell workflow built from
|
||||||
|
// the values in Request.
|
||||||
|
//
|
||||||
|
// Concurrency: a single Comfy is safe to share across goroutines as long as
|
||||||
|
// the underlying http.Client is. Generate does not hold long-lived state.
|
||||||
|
type Comfy struct {
|
||||||
|
instance string
|
||||||
|
|
||||||
|
base string
|
||||||
|
model string
|
||||||
|
vae string
|
||||||
|
clipL string
|
||||||
|
clipT5 string
|
||||||
|
dtype string
|
||||||
|
|
||||||
|
defaultSteps int
|
||||||
|
defaultSampler string
|
||||||
|
defaultScheduler string
|
||||||
|
|
||||||
|
httpClient *http.Client
|
||||||
|
pollInterval time.Duration
|
||||||
|
pollTimeout time.Duration
|
||||||
|
|
||||||
|
// Hooks for tests; production paths use the package-level defaults.
|
||||||
|
randSeed func() int64
|
||||||
|
clientIDFn func() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewComfy is the registry constructor. cfg is the adapter's slice of
|
||||||
|
// imagen.yaml. Required keys: base_url, model. The rest have sensible FLUX
|
||||||
|
// schnell defaults.
|
||||||
|
func NewComfy(name string, cfg map[string]any) (Backend, error) {
|
||||||
|
if name == "" {
|
||||||
|
return nil, fmt.Errorf("comfyui: empty instance name")
|
||||||
|
}
|
||||||
|
base := strings.TrimRight(getString(cfg, "base_url", ""), "/")
|
||||||
|
if base == "" {
|
||||||
|
return nil, fmt.Errorf("comfyui[%s]: base_url is required", name)
|
||||||
|
}
|
||||||
|
if _, err := url.Parse(base); err != nil {
|
||||||
|
return nil, fmt.Errorf("comfyui[%s]: base_url %q invalid: %w", name, base, err)
|
||||||
|
}
|
||||||
|
model := getString(cfg, "model", "")
|
||||||
|
if model == "" {
|
||||||
|
return nil, fmt.Errorf("comfyui[%s]: model is required", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &Comfy{
|
||||||
|
instance: name,
|
||||||
|
base: base,
|
||||||
|
model: model,
|
||||||
|
|
||||||
|
vae: getString(cfg, "vae", "ae.safetensors"),
|
||||||
|
clipL: getString(cfg, "clip_l", "clip_l.safetensors"),
|
||||||
|
clipT5: getString(cfg, "clip_t5", "t5xxl_fp8_e4m3fn.safetensors"),
|
||||||
|
dtype: getString(cfg, "weight_dtype", "fp8_e4m3fn"),
|
||||||
|
|
||||||
|
defaultSteps: getInt(cfg, "default_steps", 4),
|
||||||
|
defaultSampler: getString(cfg, "default_sampler", "euler"),
|
||||||
|
defaultScheduler: getString(cfg, "default_scheduler", "simple"),
|
||||||
|
|
||||||
|
httpClient: &http.Client{Timeout: 60 * time.Second},
|
||||||
|
pollInterval: 250 * time.Millisecond,
|
||||||
|
pollTimeout: 120 * time.Second,
|
||||||
|
|
||||||
|
randSeed: cryptoSeed,
|
||||||
|
clientIDFn: randClientID,
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the instance name from imagen.yaml.
|
||||||
|
func (c *Comfy) Name() string { return c.instance }
|
||||||
|
|
||||||
|
// Generate submits one workflow to ComfyUI, waits for it to render, and
|
||||||
|
// returns the resulting PNG.
|
||||||
|
func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) {
|
||||||
|
width := orDefaultInt(req.Width, 1024)
|
||||||
|
height := orDefaultInt(req.Height, 1024)
|
||||||
|
steps := orDefaultInt(req.Steps, c.defaultSteps)
|
||||||
|
|
||||||
|
sampler := c.defaultSampler
|
||||||
|
scheduler := c.defaultScheduler
|
||||||
|
if v, ok := req.BackendOpts["sampler"].(string); ok && v != "" {
|
||||||
|
sampler = v
|
||||||
|
}
|
||||||
|
if v, ok := req.BackendOpts["scheduler"].(string); ok && v != "" {
|
||||||
|
scheduler = v
|
||||||
|
}
|
||||||
|
|
||||||
|
seed := req.Seed
|
||||||
|
if seed == 0 {
|
||||||
|
seed = c.randSeed()
|
||||||
|
}
|
||||||
|
|
||||||
|
workflow := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler)
|
||||||
|
clientID := c.clientIDFn()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
promptID, err := c.submitPrompt(ctx, workflow, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
filename, err := c.waitForCompletion(ctx, promptID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
imgBytes, err := c.fetchImage(ctx, filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
latencyMs := time.Since(start).Milliseconds()
|
||||||
|
|
||||||
|
meta := map[string]any{
|
||||||
|
"backend": c.instance,
|
||||||
|
"backend_type": ComfyType,
|
||||||
|
"model": c.model,
|
||||||
|
"seed": seed,
|
||||||
|
"steps": steps,
|
||||||
|
"sampler": sampler,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"latency_ms": latencyMs,
|
||||||
|
"prompt_id": promptID,
|
||||||
|
"client_id": clientID,
|
||||||
|
}
|
||||||
|
if vram := c.vramUsedMiB(ctx); vram > 0 {
|
||||||
|
meta["vram_used_mib"] = vram
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Result{
|
||||||
|
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
|
||||||
|
MimeType: "image/png",
|
||||||
|
Metadata: meta,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// submitPrompt POSTs the workflow and extracts the prompt_id.
|
||||||
|
//
|
||||||
|
// Retries once on a 5xx or transient network error. 4xx responses are not
|
||||||
|
// retried — they are treated as configuration bugs (missing model, bad
|
||||||
|
// workflow shape, etc.) and surfaced with a hint pointing at the docs when
|
||||||
|
// the body matches a known pattern.
|
||||||
|
func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clientID string) (string, error) {
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"prompt": workflow,
|
||||||
|
"client_id": clientID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("comfyui: marshal workflow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := range 2 {
|
||||||
|
if attempt > 0 {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/prompt", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = c.connError(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
switch {
|
||||||
|
case resp.StatusCode >= 200 && resp.StatusCode < 300:
|
||||||
|
return parsePromptID(respBody, c.model)
|
||||||
|
case resp.StatusCode >= 500:
|
||||||
|
lastErr = fmt.Errorf("comfyui /prompt %d: %s", resp.StatusCode, snip(respBody))
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
return "", c.classifyBadRequest(resp.StatusCode, respBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForCompletion polls /history/{id} until the prompt finishes and
|
||||||
|
// returns the filename of the produced image.
|
||||||
|
func (c *Comfy) waitForCompletion(ctx context.Context, promptID string) (string, error) {
|
||||||
|
deadline := time.Now().Add(c.pollTimeout)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
return "", fmt.Errorf("comfyui: prompt %s did not complete within %s", promptID, c.pollTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/history/"+promptID, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", c.connError(err)
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("comfyui /history/%s %d: %s", promptID, resp.StatusCode, snip(body))
|
||||||
|
}
|
||||||
|
filename, done, err := parseHistory(body, promptID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
return filename, nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
case <-time.After(c.pollInterval):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchImage downloads the produced image bytes via /view.
|
||||||
|
func (c *Comfy) fetchImage(ctx context.Context, filename string) ([]byte, error) {
|
||||||
|
q := url.Values{
|
||||||
|
"filename": {filename},
|
||||||
|
"type": {"output"},
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/view?"+q.Encode(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, c.connError(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("comfyui /view %d: %s", resp.StatusCode, snip(body))
|
||||||
|
}
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// vramUsedMiB returns total - free VRAM on device 0 from /system_stats, or
|
||||||
|
// 0 if the endpoint isn't available. Best-effort, never an error.
|
||||||
|
func (c *Comfy) vramUsedMiB(ctx context.Context) int64 {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/system_stats", nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var s struct {
|
||||||
|
Devices []struct {
|
||||||
|
VRAMTotal int64 `json:"vram_total"`
|
||||||
|
VRAMFree int64 `json:"vram_free"`
|
||||||
|
} `json:"devices"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&s); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if len(s.Devices) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
used := s.Devices[0].VRAMTotal - s.Devices[0].VRAMFree
|
||||||
|
if used < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return used / (1024 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
// connError translates a Go networking error into a user-actionable message,
|
||||||
|
// pointing at the boot-whitetower script when mRock looks asleep.
|
||||||
|
func (c *Comfy) connError(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
msg := err.Error()
|
||||||
|
var opErr *net.OpError
|
||||||
|
asOp := errors.As(err, &opErr)
|
||||||
|
switch {
|
||||||
|
case asOp,
|
||||||
|
strings.Contains(msg, "connection refused"),
|
||||||
|
strings.Contains(msg, "no such host"),
|
||||||
|
strings.Contains(msg, "no route to host"),
|
||||||
|
strings.Contains(msg, "network is unreachable"),
|
||||||
|
strings.Contains(msg, "i/o timeout"):
|
||||||
|
return fmt.Errorf("comfyui at %s unreachable (%v) — if mRock is asleep, run: boot-whitetower mrock", c.base, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("comfyui at %s: %w", c.base, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// classifyBadRequest interprets a 4xx body. Some ComfyUI builds use 400 for
|
||||||
|
// workflow-validation failures and put the diagnostics in node_errors; older
|
||||||
|
// builds use 200 + node_errors. This handles the 4xx flavour.
|
||||||
|
func (c *Comfy) classifyBadRequest(status int, body []byte) error {
|
||||||
|
if hint, ok := missingModelHint(body, c.model); ok {
|
||||||
|
return fmt.Errorf("comfyui /prompt %d: %s — see docs/setup-comfyui-mrock.md", status, hint)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("comfyui /prompt %d: %s", status, snip(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildWorkflow assembles the canonical FLUX.1 schnell ComfyUI workflow,
|
||||||
|
// node-IDs matching the upstream "flux-schnell" template so anyone debugging
|
||||||
|
// in the ComfyUI UI sees a familiar shape.
|
||||||
|
func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"6": map[string]any{
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"text": prompt,
|
||||||
|
"clip": []any{"11", 0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"8": map[string]any{
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"samples": []any{"31", 0},
|
||||||
|
"vae": []any{"10", 0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"9": map[string]any{
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"filename_prefix": "imagen",
|
||||||
|
"images": []any{"8", 0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"10": map[string]any{
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"inputs": map[string]any{"vae_name": c.vae},
|
||||||
|
},
|
||||||
|
"11": map[string]any{
|
||||||
|
"class_type": "DualCLIPLoader",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"clip_name1": c.clipT5,
|
||||||
|
"clip_name2": c.clipL,
|
||||||
|
"type": "flux",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"12": map[string]any{
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"unet_name": c.model,
|
||||||
|
"weight_dtype": c.dtype,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"13": map[string]any{
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"text": negative,
|
||||||
|
"clip": []any{"11", 0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"27": map[string]any{
|
||||||
|
"class_type": "EmptySD3LatentImage",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"width": w,
|
||||||
|
"height": h,
|
||||||
|
"batch_size": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"30": map[string]any{
|
||||||
|
"class_type": "ModelSamplingFlux",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"model": []any{"12", 0},
|
||||||
|
"max_shift": 1.15,
|
||||||
|
"base_shift": 0.5,
|
||||||
|
"width": w,
|
||||||
|
"height": h,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"31": map[string]any{
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"inputs": map[string]any{
|
||||||
|
"model": []any{"30", 0},
|
||||||
|
"seed": seed,
|
||||||
|
"steps": steps,
|
||||||
|
"cfg": 1.0,
|
||||||
|
"sampler_name": sampler,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"denoise": 1.0,
|
||||||
|
"positive": []any{"6", 0},
|
||||||
|
"negative": []any{"13", 0},
|
||||||
|
"latent_image": []any{"27", 0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePromptID handles the 2xx /prompt response. ComfyUI sometimes 200s a
|
||||||
|
// validation failure and stuffs node_errors in the body — this function
|
||||||
|
// turns that into the same user-facing error as a 4xx with the same body.
|
||||||
|
func parsePromptID(body []byte, model string) (string, error) {
|
||||||
|
var resp struct {
|
||||||
|
PromptID string `json:"prompt_id"`
|
||||||
|
NodeErrors map[string]any `json:"node_errors"`
|
||||||
|
Error json.RawMessage `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return "", fmt.Errorf("comfyui /prompt: parse response: %w (body: %s)", err, snip(body))
|
||||||
|
}
|
||||||
|
if len(resp.NodeErrors) > 0 || len(resp.Error) > 0 {
|
||||||
|
if hint, ok := missingModelHint(body, model); ok {
|
||||||
|
return "", fmt.Errorf("comfyui /prompt: %s — see docs/setup-comfyui-mrock.md", hint)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("comfyui /prompt rejected workflow: %s", snip(body))
|
||||||
|
}
|
||||||
|
if resp.PromptID == "" {
|
||||||
|
return "", fmt.Errorf("comfyui /prompt: empty prompt_id (body: %s)", snip(body))
|
||||||
|
}
|
||||||
|
return resp.PromptID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseHistory inspects a /history/{id} body and returns either the produced
|
||||||
|
// filename + done=true, or done=false to signal "keep polling".
|
||||||
|
func parseHistory(body []byte, promptID string) (string, bool, error) {
|
||||||
|
var entries map[string]struct {
|
||||||
|
Status struct {
|
||||||
|
Completed bool `json:"completed"`
|
||||||
|
StatusStr string `json:"status_str"`
|
||||||
|
} `json:"status"`
|
||||||
|
Outputs map[string]struct {
|
||||||
|
Images []struct {
|
||||||
|
Filename string `json:"filename"`
|
||||||
|
Subfolder string `json:"subfolder"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
} `json:"images"`
|
||||||
|
} `json:"outputs"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &entries); err != nil {
|
||||||
|
return "", false, fmt.Errorf("comfyui /history: parse: %w (body: %s)", err, snip(body))
|
||||||
|
}
|
||||||
|
e, ok := entries[promptID]
|
||||||
|
if !ok {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
if e.Status.StatusStr == "error" {
|
||||||
|
return "", false, fmt.Errorf("comfyui prompt %s errored: %s", promptID, snip(body))
|
||||||
|
}
|
||||||
|
if !e.Status.Completed {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
for _, out := range e.Outputs {
|
||||||
|
if len(out.Images) > 0 {
|
||||||
|
return out.Images[0].Filename, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", true, fmt.Errorf("comfyui prompt %s completed but produced no images", promptID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// missingModelHint returns a user-actionable message when the response body
|
||||||
|
// indicates the configured unet model isn't loaded on the server. ComfyUI
|
||||||
|
// uses both the human-readable "Value not in list" message and the enum
|
||||||
|
// "value_not_in_list" type — match either.
|
||||||
|
func missingModelHint(body []byte, model string) (string, bool) {
|
||||||
|
s := string(body)
|
||||||
|
hasMarker := strings.Contains(s, "Value not in list") || strings.Contains(s, "value_not_in_list")
|
||||||
|
if hasMarker && strings.Contains(s, "unet_name") {
|
||||||
|
return fmt.Sprintf("model %q not present in the ComfyUI server's models/unet/", model), true
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func cryptoSeed() int64 {
|
||||||
|
var b [8]byte
|
||||||
|
if _, err := rand.Read(b[:]); err != nil {
|
||||||
|
return time.Now().UnixNano()
|
||||||
|
}
|
||||||
|
return int64(binary.BigEndian.Uint64(b[:]) >> 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func randClientID() string {
|
||||||
|
var b [8]byte
|
||||||
|
_, _ = rand.Read(b[:])
|
||||||
|
return fmt.Sprintf("imagen-%x", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getString(m map[string]any, k, def string) string {
|
||||||
|
if v, ok := m[k].(string); ok && v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInt(m map[string]any, k string, def int) int {
|
||||||
|
if v, ok := m[k]; ok {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case int:
|
||||||
|
return n
|
||||||
|
case int64:
|
||||||
|
return int(n)
|
||||||
|
case float64:
|
||||||
|
return int(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
func orDefaultInt(v, def int) int {
|
||||||
|
if v == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func snip(b []byte) string {
|
||||||
|
const max = 500
|
||||||
|
s := strings.TrimSpace(string(b))
|
||||||
|
if len(s) > max {
|
||||||
|
s = s[:max] + "..."
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Register(ComfyType, NewComfy)
|
||||||
|
}
|
||||||
494
internal/backend/comfyui_test.go
Normal file
494
internal/backend/comfyui_test.go
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/png"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeComfy is a programmable mock of the ComfyUI HTTP API. Tests configure
|
||||||
|
// its behaviour by adjusting the public fields before issuing the request.
|
||||||
|
type fakeComfy struct {
|
||||||
|
t *testing.T
|
||||||
|
|
||||||
|
// /prompt
|
||||||
|
promptStatus int
|
||||||
|
promptBody []byte
|
||||||
|
promptCalls atomic.Int32
|
||||||
|
failPromptUntil int32 // first N /prompt calls return promptFailStatus
|
||||||
|
promptFailStatus int
|
||||||
|
promptFailBody []byte
|
||||||
|
|
||||||
|
// /history — start by returning {} (no entry), flip to completed once
|
||||||
|
// historyReadyAfter polls have happened.
|
||||||
|
historyReadyAfter int32
|
||||||
|
historyCalls atomic.Int32
|
||||||
|
historyError bool
|
||||||
|
|
||||||
|
// /view
|
||||||
|
viewStatus int
|
||||||
|
viewBody []byte
|
||||||
|
viewType string
|
||||||
|
|
||||||
|
// /system_stats
|
||||||
|
statsTotal int64
|
||||||
|
statsFree int64
|
||||||
|
|
||||||
|
server *httptest.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeComfy) handler() http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/prompt" && r.Method == http.MethodPost:
|
||||||
|
n := f.promptCalls.Add(1)
|
||||||
|
if n <= int32(f.failPromptUntil) {
|
||||||
|
w.WriteHeader(f.promptFailStatus)
|
||||||
|
_, _ = w.Write(f.promptFailBody)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(f.promptStatus)
|
||||||
|
_, _ = w.Write(f.promptBody)
|
||||||
|
case strings.HasPrefix(r.URL.Path, "/history/") && r.Method == http.MethodGet:
|
||||||
|
n := f.historyCalls.Add(1)
|
||||||
|
id := strings.TrimPrefix(r.URL.Path, "/history/")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if f.historyError {
|
||||||
|
_, _ = fmt.Fprintf(w, `{"%s":{"status":{"completed":false,"status_str":"error"},"outputs":{}}}`, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n <= f.historyReadyAfter {
|
||||||
|
_, _ = w.Write([]byte(`{}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprintf(w,
|
||||||
|
`{"%s":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`,
|
||||||
|
id,
|
||||||
|
)
|
||||||
|
case r.URL.Path == "/view" && r.Method == http.MethodGet:
|
||||||
|
ct := f.viewType
|
||||||
|
if ct == "" {
|
||||||
|
ct = "image/png"
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", ct)
|
||||||
|
w.WriteHeader(f.viewStatus)
|
||||||
|
_, _ = w.Write(f.viewBody)
|
||||||
|
case r.URL.Path == "/system_stats" && r.Method == http.MethodGet:
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
body := map[string]any{
|
||||||
|
"system": map[string]any{},
|
||||||
|
"devices": []map[string]any{
|
||||||
|
{"vram_total": f.statsTotal, "vram_free": f.statsFree},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(body)
|
||||||
|
default:
|
||||||
|
f.t.Errorf("fakeComfy: unexpected request %s %s", r.Method, r.URL.Path)
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeComfy) start() {
|
||||||
|
f.server = httptest.NewServer(f.handler())
|
||||||
|
f.t.Cleanup(f.server.Close)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newFakeComfy spins up a fakeComfy with happy-path defaults.
|
||||||
|
func newFakeComfy(t *testing.T) *fakeComfy {
|
||||||
|
t.Helper()
|
||||||
|
f := &fakeComfy{
|
||||||
|
t: t,
|
||||||
|
promptStatus: http.StatusOK,
|
||||||
|
promptBody: []byte(`{"prompt_id":"pid-abc","number":1,"node_errors":{}}`),
|
||||||
|
viewStatus: http.StatusOK,
|
||||||
|
viewBody: mustPNG(t, 16, 16),
|
||||||
|
statsTotal: 16 * 1024 * 1024 * 1024,
|
||||||
|
statsFree: 8 * 1024 * 1024 * 1024,
|
||||||
|
}
|
||||||
|
f.start()
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// newComfy returns a Comfy pointed at f, with poll interval squashed for fast
|
||||||
|
// tests and deterministic seed/client_id.
|
||||||
|
func newComfy(t *testing.T, f *fakeComfy) *Comfy {
|
||||||
|
t.Helper()
|
||||||
|
be, err := NewComfy("flux-test", map[string]any{
|
||||||
|
"base_url": f.server.URL,
|
||||||
|
"model": "flux1-schnell.safetensors",
|
||||||
|
"default_steps": 4,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewComfy: %v", err)
|
||||||
|
}
|
||||||
|
c := be.(*Comfy)
|
||||||
|
c.pollInterval = time.Millisecond
|
||||||
|
c.pollTimeout = 5 * time.Second
|
||||||
|
c.randSeed = func() int64 { return 42 }
|
||||||
|
c.clientIDFn = func() string { return "imagen-test" }
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustPNG(t *testing.T, w, h int) []byte {
|
||||||
|
t.Helper()
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||||
|
for y := range h {
|
||||||
|
for x := range w {
|
||||||
|
img.Set(x, y, color.RGBA{R: 200, G: 100, B: 50, A: 255})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := png.Encode(&buf, img); err != nil {
|
||||||
|
t.Fatalf("encode png: %v", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyConstructorRequiresBaseAndModel(t *testing.T) {
|
||||||
|
if _, err := NewComfy("x", map[string]any{}); err == nil {
|
||||||
|
t.Errorf("expected error for missing base_url")
|
||||||
|
}
|
||||||
|
if _, err := NewComfy("x", map[string]any{"base_url": "http://h:1"}); err == nil {
|
||||||
|
t.Errorf("expected error for missing model")
|
||||||
|
}
|
||||||
|
if _, err := NewComfy("", map[string]any{"base_url": "http://h:1", "model": "m"}); err == nil {
|
||||||
|
t.Errorf("expected error for empty instance name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyHappyPath(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.historyReadyAfter = 2 // exercise the polling loop
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
res, err := c.Generate(context.Background(), Request{
|
||||||
|
Prompt: "a small fishbowl with a cat",
|
||||||
|
Width: 512,
|
||||||
|
Height: 512,
|
||||||
|
Steps: 4,
|
||||||
|
Seed: 1234567,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Generate: %v", err)
|
||||||
|
}
|
||||||
|
defer res.ImageReader.Close()
|
||||||
|
|
||||||
|
if res.MimeType != "image/png" {
|
||||||
|
t.Errorf("mime = %q", res.MimeType)
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(res.ImageReader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read body: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(body, f.viewBody) {
|
||||||
|
t.Errorf("image body did not round-trip")
|
||||||
|
}
|
||||||
|
|
||||||
|
if seed, _ := res.Metadata["seed"].(int64); seed != 1234567 {
|
||||||
|
t.Errorf("metadata seed = %v", res.Metadata["seed"])
|
||||||
|
}
|
||||||
|
if model, _ := res.Metadata["model"].(string); model != "flux1-schnell.safetensors" {
|
||||||
|
t.Errorf("metadata model = %v", res.Metadata["model"])
|
||||||
|
}
|
||||||
|
if steps, _ := res.Metadata["steps"].(int); steps != 4 {
|
||||||
|
t.Errorf("metadata steps = %v", res.Metadata["steps"])
|
||||||
|
}
|
||||||
|
if pid, _ := res.Metadata["prompt_id"].(string); pid != "pid-abc" {
|
||||||
|
t.Errorf("metadata prompt_id = %v", res.Metadata["prompt_id"])
|
||||||
|
}
|
||||||
|
if _, ok := res.Metadata["latency_ms"]; !ok {
|
||||||
|
t.Errorf("metadata missing latency_ms")
|
||||||
|
}
|
||||||
|
// vram_used_mib is best-effort but should be present given our mock stats
|
||||||
|
if vram, _ := res.Metadata["vram_used_mib"].(int64); vram != 8192 {
|
||||||
|
t.Errorf("metadata vram_used_mib = %v, want 8192", res.Metadata["vram_used_mib"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := f.historyCalls.Load(); got < 3 {
|
||||||
|
t.Errorf("expected at least 3 /history polls, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyDefaultsAppliedWhenZero(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
res, err := c.Generate(context.Background(), Request{Prompt: "p"}) // all-zero
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Generate: %v", err)
|
||||||
|
}
|
||||||
|
defer res.ImageReader.Close()
|
||||||
|
_, _ = io.ReadAll(res.ImageReader)
|
||||||
|
|
||||||
|
if w, _ := res.Metadata["width"].(int); w != 1024 {
|
||||||
|
t.Errorf("width default = %v", res.Metadata["width"])
|
||||||
|
}
|
||||||
|
if steps, _ := res.Metadata["steps"].(int); steps != 4 {
|
||||||
|
t.Errorf("steps default = %v", res.Metadata["steps"])
|
||||||
|
}
|
||||||
|
if seed, _ := res.Metadata["seed"].(int64); seed != 42 {
|
||||||
|
t.Errorf("seed default (test rand hook) = %v", res.Metadata["seed"])
|
||||||
|
}
|
||||||
|
if s, _ := res.Metadata["sampler"].(string); s != "euler" {
|
||||||
|
t.Errorf("sampler default = %q", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyPromptRetriesOnce5xx(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.failPromptUntil = 1
|
||||||
|
f.promptFailStatus = http.StatusBadGateway
|
||||||
|
f.promptFailBody = []byte("upstream busy")
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
res, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Generate (with one 502 then OK): %v", err)
|
||||||
|
}
|
||||||
|
defer res.ImageReader.Close()
|
||||||
|
_, _ = io.ReadAll(res.ImageReader)
|
||||||
|
|
||||||
|
if got := f.promptCalls.Load(); got != 2 {
|
||||||
|
t.Errorf("expected exactly 2 /prompt calls (1 fail + 1 retry), got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyPromptGivesUpAfterTwo5xx(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.failPromptUntil = 99 // every call fails
|
||||||
|
f.promptFailStatus = http.StatusServiceUnavailable
|
||||||
|
f.promptFailBody = []byte("nope")
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error after sustained 503s")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "503") {
|
||||||
|
t.Errorf("expected error to mention 503, got %v", err)
|
||||||
|
}
|
||||||
|
if got := f.promptCalls.Load(); got != 2 {
|
||||||
|
t.Errorf("expected exactly 2 /prompt calls (no further retries), got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyPromptDoesNotRetryOn4xx(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.failPromptUntil = 99
|
||||||
|
f.promptFailStatus = http.StatusBadRequest
|
||||||
|
f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation"},"node_errors":{"some":"thing"}}`)
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 400")
|
||||||
|
}
|
||||||
|
if got := f.promptCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("expected exactly 1 /prompt call (no retry on 4xx), got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyMissingModelHintsAtSetupDoc(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.failPromptUntil = 99
|
||||||
|
f.promptFailStatus = http.StatusBadRequest
|
||||||
|
f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation","message":"Prompt outputs failed validation"},"node_errors":{"12":{"errors":[{"type":"value_not_in_list","message":"Value not in list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`)
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
msg := err.Error()
|
||||||
|
if !strings.Contains(msg, "docs/setup-comfyui-mrock.md") {
|
||||||
|
t.Errorf("error should point at the setup doc, got %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(msg, "flux1-schnell.safetensors") {
|
||||||
|
t.Errorf("error should name the missing model, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyMissingModelOn200WithNodeErrors(t *testing.T) {
|
||||||
|
// Older ComfyUI builds 200 a workflow-validation failure.
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.promptStatus = http.StatusOK
|
||||||
|
f.promptBody = []byte(`{"prompt_id":"","node_errors":{"12":{"errors":[{"type":"value_not_in_list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`)
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for node_errors on 200")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "docs/setup-comfyui-mrock.md") {
|
||||||
|
t.Errorf("error should point at the setup doc, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyHistoryErrorSurfaced(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.historyError = true
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when history reports execution error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "errored") {
|
||||||
|
t.Errorf("expected 'errored' in message, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyViewFailureSurfaced(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.viewStatus = http.StatusNotFound
|
||||||
|
f.viewBody = []byte("nope")
|
||||||
|
c := newComfy(t, f)
|
||||||
|
|
||||||
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when /view 404s")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "404") {
|
||||||
|
t.Errorf("expected status code in error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyUnreachableHostMentionsBootHelper(t *testing.T) {
|
||||||
|
be, err := NewComfy("flux-test", map[string]any{
|
||||||
|
"base_url": "http://127.0.0.1:1", // closed port; connection refused
|
||||||
|
"model": "flux1-schnell.safetensors",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewComfy: %v", err)
|
||||||
|
}
|
||||||
|
c := be.(*Comfy)
|
||||||
|
c.httpClient = &http.Client{Timeout: 500 * time.Millisecond}
|
||||||
|
|
||||||
|
_, err = c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for unreachable host")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "boot-whitetower mrock") {
|
||||||
|
t.Errorf("expected boot-helper hint, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyContextCancelStopsPolling(t *testing.T) {
|
||||||
|
f := newFakeComfy(t)
|
||||||
|
f.historyReadyAfter = 1_000_000 // never finishes
|
||||||
|
c := newComfy(t, f)
|
||||||
|
c.pollInterval = 5 * time.Millisecond
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := c.Generate(ctx, Request{Prompt: "p", Width: 64, Height: 64})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected ctx.Err()")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "context deadline exceeded") {
|
||||||
|
t.Errorf("expected deadline exceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyWorkflowReflectsRequest(t *testing.T) {
|
||||||
|
// Capture the workflow body to assert KSampler + EmptyLatentImage values.
|
||||||
|
var captured []byte
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/prompt":
|
||||||
|
captured, _ = io.ReadAll(r.Body)
|
||||||
|
_, _ = w.Write([]byte(`{"prompt_id":"pid","number":1,"node_errors":{}}`))
|
||||||
|
case "/history/pid":
|
||||||
|
_, _ = w.Write([]byte(`{"pid":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`))
|
||||||
|
case "/view":
|
||||||
|
_, _ = w.Write(mustPNG(t, 8, 8))
|
||||||
|
case "/system_stats":
|
||||||
|
_, _ = w.Write([]byte(`{"devices":[{"vram_total":1,"vram_free":1}]}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
be, err := NewComfy("flux-test", map[string]any{
|
||||||
|
"base_url": srv.URL,
|
||||||
|
"model": "custom.safetensors",
|
||||||
|
"default_steps": 7,
|
||||||
|
"default_sampler": "dpmpp_2m",
|
||||||
|
"default_scheduler": "karras",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewComfy: %v", err)
|
||||||
|
}
|
||||||
|
c := be.(*Comfy)
|
||||||
|
c.pollInterval = time.Millisecond
|
||||||
|
c.randSeed = func() int64 { return 9999 }
|
||||||
|
|
||||||
|
res, err := c.Generate(context.Background(), Request{
|
||||||
|
Prompt: "a cat",
|
||||||
|
NegativePrompt: "blurry",
|
||||||
|
Width: 768,
|
||||||
|
Height: 512,
|
||||||
|
Steps: 11,
|
||||||
|
Seed: 555,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Generate: %v", err)
|
||||||
|
}
|
||||||
|
res.ImageReader.Close()
|
||||||
|
|
||||||
|
var sent struct {
|
||||||
|
Prompt map[string]map[string]any `json:"prompt"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(captured, &sent); err != nil {
|
||||||
|
t.Fatalf("unmarshal captured: %v", err)
|
||||||
|
}
|
||||||
|
ks := sent.Prompt["31"]["inputs"].(map[string]any)
|
||||||
|
if ks["seed"].(float64) != 555 {
|
||||||
|
t.Errorf("KSampler seed = %v, want 555", ks["seed"])
|
||||||
|
}
|
||||||
|
if ks["steps"].(float64) != 11 {
|
||||||
|
t.Errorf("KSampler steps = %v, want 11", ks["steps"])
|
||||||
|
}
|
||||||
|
if ks["sampler_name"].(string) != "dpmpp_2m" {
|
||||||
|
t.Errorf("sampler_name = %v", ks["sampler_name"])
|
||||||
|
}
|
||||||
|
if ks["scheduler"].(string) != "karras" {
|
||||||
|
t.Errorf("scheduler = %v", ks["scheduler"])
|
||||||
|
}
|
||||||
|
latent := sent.Prompt["27"]["inputs"].(map[string]any)
|
||||||
|
if latent["width"].(float64) != 768 || latent["height"].(float64) != 512 {
|
||||||
|
t.Errorf("EmptySD3LatentImage size = %vx%v", latent["width"], latent["height"])
|
||||||
|
}
|
||||||
|
unet := sent.Prompt["12"]["inputs"].(map[string]any)
|
||||||
|
if unet["unet_name"].(string) != "custom.safetensors" {
|
||||||
|
t.Errorf("unet_name = %v", unet["unet_name"])
|
||||||
|
}
|
||||||
|
neg := sent.Prompt["13"]["inputs"].(map[string]any)
|
||||||
|
if neg["text"].(string) != "blurry" {
|
||||||
|
t.Errorf("negative prompt not threaded: %v", neg["text"])
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(sent.ClientID, "imagen-") && sent.ClientID == "" {
|
||||||
|
t.Errorf("client_id should be set: %q", sent.ClientID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComfyTypeIsRegistered(t *testing.T) {
|
||||||
|
if !Default.Has(ComfyType) {
|
||||||
|
t.Errorf("comfyui type not registered in Default")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -95,7 +95,7 @@ const Sample = `# imagen.yaml — config for the imagen CLI.
|
|||||||
# implementing the Backend interface, registering its type name, and listing
|
# implementing the Backend interface, registering its type name, and listing
|
||||||
# an instance here.
|
# an instance here.
|
||||||
|
|
||||||
default_backend: mock
|
default_backend: flux-schnell-local
|
||||||
|
|
||||||
output:
|
output:
|
||||||
directory: ~/Pictures/imagen
|
directory: ~/Pictures/imagen
|
||||||
@@ -103,14 +103,18 @@ output:
|
|||||||
write_metadata_json: true
|
write_metadata_json: true
|
||||||
|
|
||||||
backends:
|
backends:
|
||||||
mock:
|
|
||||||
type: mock
|
|
||||||
|
|
||||||
flux-schnell-local:
|
flux-schnell-local:
|
||||||
type: comfyui
|
type: comfyui
|
||||||
base_url: http://mrock:8188
|
base_url: http://mrock:8188
|
||||||
|
# Filename of the unet checkpoint inside the ComfyUI server's
|
||||||
|
# models/unet/ directory. See docs/setup-comfyui-mrock.md.
|
||||||
model: flux1-schnell.safetensors
|
model: flux1-schnell.safetensors
|
||||||
default_steps: 4
|
default_steps: 4
|
||||||
|
default_sampler: euler
|
||||||
|
default_scheduler: simple
|
||||||
|
|
||||||
|
mock:
|
||||||
|
type: mock
|
||||||
|
|
||||||
flux-dev-replicate:
|
flux-dev-replicate:
|
||||||
type: replicate
|
type: replicate
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func TestLoadAndValidate(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Load: %v", err)
|
t.Fatalf("Load: %v", err)
|
||||||
}
|
}
|
||||||
if cfg.DefaultBackend != "mock" {
|
if cfg.DefaultBackend != "flux-schnell-local" {
|
||||||
t.Errorf("default = %q", cfg.DefaultBackend)
|
t.Errorf("default = %q", cfg.DefaultBackend)
|
||||||
}
|
}
|
||||||
mock, ok := cfg.Backends["mock"]
|
mock, ok := cfg.Backends["mock"]
|
||||||
@@ -30,9 +30,15 @@ func TestLoadAndValidate(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("flux backend missing")
|
t.Fatalf("flux backend missing")
|
||||||
}
|
}
|
||||||
|
if flux.Type != "comfyui" {
|
||||||
|
t.Errorf("flux type = %q", flux.Type)
|
||||||
|
}
|
||||||
if flux.Raw["base_url"] != "http://mrock:8188" {
|
if flux.Raw["base_url"] != "http://mrock:8188" {
|
||||||
t.Errorf("flux base_url = %v", flux.Raw["base_url"])
|
t.Errorf("flux base_url = %v", flux.Raw["base_url"])
|
||||||
}
|
}
|
||||||
|
if flux.Raw["model"] != "flux1-schnell.safetensors" {
|
||||||
|
t.Errorf("flux model = %v", flux.Raw["model"])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateRejectsUnknownDefault(t *testing.T) {
|
func TestValidateRejectsUnknownDefault(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user