Compare commits
7 Commits
mai/bohr/i
...
mai/hermes
| Author | SHA1 | Date | |
|---|---|---|---|
| b282325663 | |||
| a1d0165445 | |||
| 2a8bd4313b | |||
| 4183d4c55a | |||
| 127bbf3ed5 | |||
| a24ac2826f | |||
| 20490913c1 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,3 +7,4 @@
|
||||
.env.local
|
||||
/imagen
|
||||
/coverage.txt
|
||||
/.m/
|
||||
|
||||
@@ -10,6 +10,27 @@ import (
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
)
|
||||
|
||||
// instanceStatus checks adapter-specific preconditions (e.g. the
|
||||
// Replicate API token env var being set) and returns a short
|
||||
// user-facing status string.
|
||||
func instanceStatus(spec config.BackendSpec) string {
|
||||
if !backend.Default.Has(spec.Type) {
|
||||
return fmt.Sprintf("type %q not compiled in", spec.Type)
|
||||
}
|
||||
switch spec.Type {
|
||||
case backend.ReplicateType:
|
||||
envName, _ := spec.Raw["api_token_env"].(string)
|
||||
if envName == "" {
|
||||
envName = "REPLICATE_API_TOKEN"
|
||||
}
|
||||
if os.Getenv(envName) == "" {
|
||||
return fmt.Sprintf("not configured (set %s)", envName)
|
||||
}
|
||||
return "ok"
|
||||
}
|
||||
return "registered"
|
||||
}
|
||||
|
||||
func runBackends(args []string) error {
|
||||
fs := flag.NewFlagSet("backends", flag.ContinueOnError)
|
||||
var configPath string
|
||||
@@ -27,10 +48,7 @@ func runBackends(args []string) error {
|
||||
fmt.Fprintln(tw, "INSTANCE\tTYPE\tSTATUS")
|
||||
if cfg != nil {
|
||||
for name, spec := range cfg.Backends {
|
||||
status := "registered"
|
||||
if !backend.Default.Has(spec.Type) {
|
||||
status = fmt.Sprintf("type %q not compiled in", spec.Type)
|
||||
}
|
||||
status := instanceStatus(spec)
|
||||
marker := ""
|
||||
if name == cfg.DefaultBackend {
|
||||
marker = " (default)"
|
||||
|
||||
@@ -11,21 +11,25 @@ import (
|
||||
"mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
"mgit.msbls.de/m/ImaGen/internal/output"
|
||||
"mgit.msbls.de/m/ImaGen/internal/preview"
|
||||
"mgit.msbls.de/m/ImaGen/internal/prompt"
|
||||
"mgit.msbls.de/m/ImaGen/internal/usage"
|
||||
)
|
||||
|
||||
func runGenerate(ctx context.Context, args []string) error {
|
||||
fs := flag.NewFlagSet("generate", flag.ContinueOnError)
|
||||
var (
|
||||
backendName string
|
||||
size string
|
||||
outPath string
|
||||
seed int64
|
||||
steps int
|
||||
style string
|
||||
negative string
|
||||
configPath string
|
||||
noSidecar bool
|
||||
backendName string
|
||||
size string
|
||||
outPath string
|
||||
seed int64
|
||||
steps int
|
||||
style string
|
||||
negative string
|
||||
configPath string
|
||||
noSidecar bool
|
||||
previewOn bool
|
||||
previewOff bool
|
||||
)
|
||||
fs.StringVar(&backendName, "backend", "", "backend instance name (default: config.default_backend)")
|
||||
fs.StringVar(&size, "size", "1024x1024", "WxH, e.g. 1024x1024")
|
||||
@@ -36,6 +40,8 @@ func runGenerate(ctx context.Context, args []string) error {
|
||||
fs.StringVar(&negative, "negative", "", "negative prompt (ignored by backends that don't support it)")
|
||||
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
|
||||
fs.BoolVar(&noSidecar, "no-sidecar", false, "skip the JSON sidecar even if config enables it")
|
||||
fs.BoolVar(&previewOn, "preview", false, "force tmux preview window on (errors outside $TMUX)")
|
||||
fs.BoolVar(&previewOff, "no-preview", false, "skip the tmux preview window")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintln(fs.Output(), `Usage: imagen generate "<prompt>" [flags]`)
|
||||
fs.PrintDefaults()
|
||||
@@ -76,6 +82,7 @@ func runGenerate(ctx context.Context, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
attachUsageSink(be)
|
||||
|
||||
finalPrompt, err := prompt.Apply(rawPrompt, style)
|
||||
if err != nil {
|
||||
@@ -118,9 +125,70 @@ func runGenerate(ctx context.Context, args []string) error {
|
||||
if paths.SidecarPath != "" {
|
||||
fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath)
|
||||
}
|
||||
|
||||
if err := maybePreview(cfg, previewOn, previewOff, paths.ImagePath, rawPrompt); err != nil {
|
||||
// preview failures are warnings — the image already wrote.
|
||||
fmt.Fprintln(os.Stderr, "imagen: preview:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolvePreviewMode applies the precedence chain config -> env -> flag.
|
||||
// Flags win, env beats config, config beats the implicit auto default.
|
||||
func resolvePreviewMode(cfg *config.Config, flagOn, flagOff bool, env string) (preview.Mode, error) {
|
||||
mode := preview.ModeAuto
|
||||
if cfg != nil && cfg.Output.Preview != "" {
|
||||
m, err := preview.ParseMode(cfg.Output.Preview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("config output.preview: %w", err)
|
||||
}
|
||||
mode = m
|
||||
}
|
||||
if env != "" {
|
||||
m, err := preview.ParseMode(env)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("$IMAGEN_PREVIEW: %w", err)
|
||||
}
|
||||
mode = m
|
||||
}
|
||||
if flagOn && flagOff {
|
||||
return "", userErr("--preview and --no-preview are mutually exclusive")
|
||||
}
|
||||
if flagOn {
|
||||
mode = preview.ModeOn
|
||||
}
|
||||
if flagOff {
|
||||
mode = preview.ModeOff
|
||||
}
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
// maybePreview resolves the effective preview mode and, if it says yes,
|
||||
// spawns a tmux window via tmux-img. Always non-fatal.
|
||||
func maybePreview(cfg *config.Config, flagOn, flagOff bool, imagePath, rawPrompt string) error {
|
||||
mode, err := resolvePreviewMode(cfg, flagOn, flagOff, os.Getenv("IMAGEN_PREVIEW"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decision, err := preview.Resolve(mode, os.Getenv("TMUX") != "", stdoutIsTTY())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !decision.ShouldPreview {
|
||||
return nil
|
||||
}
|
||||
spawner := &preview.Spawner{}
|
||||
return spawner.Spawn(imagePath, output.Slug(rawPrompt))
|
||||
}
|
||||
|
||||
func stdoutIsTTY() bool {
|
||||
fi, err := os.Stdout.Stat()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return fi.Mode()&os.ModeCharDevice != 0
|
||||
}
|
||||
|
||||
// splitLeadingPositional separates the positional args at the start of args
|
||||
// from the rest (which begins with the first flag). A literal "--" terminator
|
||||
// pushes everything after it into the positional list and out of flag parsing.
|
||||
@@ -153,6 +221,21 @@ func parseSize(s string) (int, int, error) {
|
||||
return w, h, nil
|
||||
}
|
||||
|
||||
// attachUsageSink wires a Supabase cost-tracking sink into the backend
|
||||
// when it accepts one and the env is configured. Adapters that record
|
||||
// usage expose a public Sink field of type backend.UsageSink.
|
||||
func attachUsageSink(be backend.Backend) {
|
||||
r, ok := be.(*backend.Replicate)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sink, ok := usage.NewSupabaseSinkFromEnv()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
r.Sink = sink
|
||||
}
|
||||
|
||||
func buildBackend(cfg *config.Config, name string) (backend.Backend, error) {
|
||||
if cfg != nil {
|
||||
spec, ok := cfg.Backends[name]
|
||||
|
||||
50
cmd/imagen/generate_test.go
Normal file
50
cmd/imagen/generate_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
"mgit.msbls.de/m/ImaGen/internal/preview"
|
||||
)
|
||||
|
||||
func TestResolvePreviewMode(t *testing.T) {
|
||||
type tc struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
flagOn bool
|
||||
flagOff bool
|
||||
env string
|
||||
want preview.Mode
|
||||
wantError bool
|
||||
}
|
||||
cases := []tc{
|
||||
{name: "all-empty-defaults-to-auto", want: preview.ModeAuto},
|
||||
{name: "config-on", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, want: preview.ModeOn},
|
||||
{name: "config-off", cfg: &config.Config{Output: config.OutputConfig{Preview: "off"}}, want: preview.ModeOff},
|
||||
{name: "config-auto-explicit", cfg: &config.Config{Output: config.OutputConfig{Preview: "auto"}}, want: preview.ModeAuto},
|
||||
{name: "env-overrides-config", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, env: "off", want: preview.ModeOff},
|
||||
{name: "flag-on-overrides-env-off", env: "off", flagOn: true, want: preview.ModeOn},
|
||||
{name: "flag-off-overrides-env-on", env: "on", flagOff: true, want: preview.ModeOff},
|
||||
{name: "flag-off-overrides-config-on", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, flagOff: true, want: preview.ModeOff},
|
||||
{name: "both-flags-error", flagOn: true, flagOff: true, wantError: true},
|
||||
{name: "bad-env-errors", env: "yes", wantError: true},
|
||||
{name: "bad-config-errors", cfg: &config.Config{Output: config.OutputConfig{Preview: "yes"}}, wantError: true},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
got, err := resolvePreviewMode(c.cfg, c.flagOn, c.flagOff, c.env)
|
||||
if c.wantError {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got mode %q", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != c.want {
|
||||
t.Errorf("mode = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
_ "mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
)
|
||||
|
||||
const usage = `imagen — model-agnostic image generation
|
||||
const helpText = `imagen — model-agnostic image generation
|
||||
|
||||
Usage:
|
||||
imagen generate <prompt> [flags] generate one image
|
||||
@@ -22,6 +22,7 @@ Usage:
|
||||
imagen config init print a sample imagen.yaml on stdout
|
||||
imagen config validate validate the active config
|
||||
imagen serve [--addr :8080] (stub) start the HTTP server
|
||||
imagen usage [--since DATE] show cost-tracking rows
|
||||
imagen version print version
|
||||
imagen help show this help
|
||||
|
||||
@@ -33,7 +34,7 @@ var Version = "dev"
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprint(os.Stderr, usage)
|
||||
fmt.Fprint(os.Stderr, helpText)
|
||||
os.Exit(2)
|
||||
}
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
@@ -50,12 +51,14 @@ func main() {
|
||||
err = runConfig(args)
|
||||
case "serve":
|
||||
err = runServe(args)
|
||||
case "usage":
|
||||
err = runUsage(ctx, args)
|
||||
case "version", "-v", "--version":
|
||||
fmt.Println(Version)
|
||||
case "help", "-h", "--help":
|
||||
fmt.Print(usage)
|
||||
fmt.Print(helpText)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], usage)
|
||||
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], helpText)
|
||||
os.Exit(2)
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
189
cmd/imagen/usage.go
Normal file
189
cmd/imagen/usage.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/usage"
|
||||
)
|
||||
|
||||
// runUsage handles `imagen usage [--since DATE]`. Reads mai.imagen_usage
|
||||
// via Supabase REST and prints a tab-aligned table grouped by week +
|
||||
// backend + model + caller, with totals at the bottom.
|
||||
func runUsage(ctx context.Context, args []string) error {
|
||||
fs := flag.NewFlagSet("usage", flag.ContinueOnError)
|
||||
var (
|
||||
since string
|
||||
raw bool
|
||||
)
|
||||
fs.StringVar(&since, "since", "", "ISO date (YYYY-MM-DD) — only rows on/after this UTC date")
|
||||
fs.BoolVar(&raw, "raw", false, "print one line per row instead of grouped")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintln(fs.Output(), "Usage: imagen usage [--since YYYY-MM-DD] [--raw]")
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sinceT time.Time
|
||||
if since != "" {
|
||||
t, err := time.Parse("2006-01-02", since)
|
||||
if err != nil {
|
||||
return userErr("--since must be YYYY-MM-DD: %v", err)
|
||||
}
|
||||
sinceT = t
|
||||
}
|
||||
|
||||
sink, ok := usage.NewSupabaseSinkFromEnv()
|
||||
if !ok {
|
||||
return userErr("SUPABASE_URL and SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) must be set to read mai.imagen_usage")
|
||||
}
|
||||
rows, err := sink.Query(ctx, sinceT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if raw {
|
||||
printRawRows(rows)
|
||||
return nil
|
||||
}
|
||||
printGroupedRows(rows)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printRawRows(rows []usage.Row) {
|
||||
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(tw, "TIME\tBACKEND\tMODEL\tCALLER\tLATENCY_MS\tCOST_USD")
|
||||
var totalCost float64
|
||||
for _, r := range rows {
|
||||
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
r.CreatedAt.Local().Format("2006-01-02 15:04"),
|
||||
r.Backend,
|
||||
r.Model,
|
||||
derefString(r.Caller),
|
||||
intOrDash(r.LatencyMs),
|
||||
costOrDash(r.CostUSDEstimate),
|
||||
)
|
||||
if r.CostUSDEstimate != nil {
|
||||
totalCost += *r.CostUSDEstimate
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(tw, "\t\t\t\t%d rows\t%.4f USD\n", len(rows), totalCost)
|
||||
_ = tw.Flush()
|
||||
}
|
||||
|
||||
type group struct {
|
||||
week string
|
||||
backend string
|
||||
model string
|
||||
caller string
|
||||
count int
|
||||
cost float64
|
||||
costSet bool
|
||||
}
|
||||
|
||||
type groupKey struct {
|
||||
week, backend, model, caller string
|
||||
}
|
||||
|
||||
func printGroupedRows(rows []usage.Row) {
|
||||
groups := map[groupKey]*group{}
|
||||
for _, r := range rows {
|
||||
caller := derefString(r.Caller)
|
||||
k := groupKey{
|
||||
week: weekStart(r.CreatedAt).Format("2006-01-02"),
|
||||
backend: r.Backend,
|
||||
model: r.Model,
|
||||
caller: caller,
|
||||
}
|
||||
g, ok := groups[k]
|
||||
if !ok {
|
||||
g = &group{week: k.week, backend: r.Backend, model: r.Model, caller: caller}
|
||||
groups[k] = g
|
||||
}
|
||||
g.count++
|
||||
if r.CostUSDEstimate != nil {
|
||||
g.cost += *r.CostUSDEstimate
|
||||
g.costSet = true
|
||||
}
|
||||
}
|
||||
|
||||
keys := make([]groupKey, 0, len(groups))
|
||||
for k := range groups {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
if keys[i].week != keys[j].week {
|
||||
return keys[i].week > keys[j].week // newest first
|
||||
}
|
||||
if keys[i].backend != keys[j].backend {
|
||||
return keys[i].backend < keys[j].backend
|
||||
}
|
||||
if keys[i].model != keys[j].model {
|
||||
return keys[i].model < keys[j].model
|
||||
}
|
||||
return keys[i].caller < keys[j].caller
|
||||
})
|
||||
|
||||
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(tw, "WEEK_OF\tBACKEND\tMODEL\tCALLER\tCOUNT\tCOST_USD")
|
||||
var totalCount int
|
||||
var totalCost float64
|
||||
for _, k := range keys {
|
||||
g := groups[k]
|
||||
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%s\n",
|
||||
g.week, g.backend, g.model, g.caller, g.count, costStr(g.cost, g.costSet),
|
||||
)
|
||||
totalCount += g.count
|
||||
totalCost += g.cost
|
||||
}
|
||||
fmt.Fprintf(tw, "\t\t\tTOTAL\t%d\t%.4f USD\n", totalCount, totalCost)
|
||||
_ = tw.Flush()
|
||||
}
|
||||
|
||||
// weekStart returns the Monday of the week containing t (UTC).
|
||||
func weekStart(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
wd := int(t.Weekday())
|
||||
if wd == 0 {
|
||||
wd = 7 // shift Sunday to end-of-week
|
||||
}
|
||||
delta := time.Duration(wd-1) * -24 * time.Hour
|
||||
d := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
return d.Add(delta)
|
||||
}
|
||||
|
||||
func derefString(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
func intOrDash(p *int) string {
|
||||
if p == nil {
|
||||
return "-"
|
||||
}
|
||||
return fmt.Sprintf("%d", *p)
|
||||
}
|
||||
|
||||
func costOrDash(p *float64) string {
|
||||
if p == nil {
|
||||
return "-"
|
||||
}
|
||||
return fmt.Sprintf("%.4f", *p)
|
||||
}
|
||||
|
||||
func costStr(v float64, set bool) string {
|
||||
if !set {
|
||||
return "-"
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%.4f", v))
|
||||
}
|
||||
@@ -15,6 +15,7 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
|
||||
│ internal/prompt │ style preset → prompt suffix
|
||||
│ internal/output │ filename templating, sidecar
|
||||
│ internal/config │ YAML loader, validation
|
||||
│ internal/preview │ tmux-img window spawner
|
||||
└──────────┬────────────┘
|
||||
│
|
||||
┌──────────▼────────────┐
|
||||
|
||||
181
docs/setup-comfyui-mrock.md
Normal file
181
docs/setup-comfyui-mrock.md
Normal file
@@ -0,0 +1,181 @@
|
||||
# ComfyUI on mRock — install + ops
|
||||
|
||||
ImaGen's `flux-schnell-local` backend talks to ComfyUI on mRock at
|
||||
`http://mrock:8188` (Tailscale-internal). This document is the reproducible
|
||||
install path from a clean mRock state.
|
||||
|
||||
mRock runs Arch Linux + systemd with an NVIDIA RTX 4070 Ti SUPER (16 GB
|
||||
VRAM). Ollama is already a native systemd service, so ComfyUI follows the
|
||||
same pattern (native Python venv + systemd unit) instead of Docker — Docker
|
||||
on mRock has no `nvidia` runtime configured, and adding one is more invasive
|
||||
than another systemd unit.
|
||||
|
||||
## Prerequisites on mRock
|
||||
|
||||
- Python via `uv` (already installed).
|
||||
- NVIDIA driver new enough for CUDA 12.4. `nvidia-smi --query-gpu=driver_version`
|
||||
should show >= 550. Driver 595 is what mRock has today.
|
||||
- ~35 GB free on `/home` for the model files.
|
||||
- `ollama.service` running on port 11434 — coexistence notes below.
|
||||
|
||||
## 1. Clone ComfyUI + Python venv
|
||||
|
||||
```bash
|
||||
mkdir -p ~/dev && cd ~/dev
|
||||
git clone --depth 1 https://github.com/comfyanonymous/ComfyUI.git comfyui
|
||||
cd comfyui
|
||||
uv venv --python 3.12 .venv
|
||||
source .venv/bin/activate.fish
|
||||
|
||||
# PyTorch CUDA 12.4 wheels — match the system driver
|
||||
uv pip install --no-cache torch torchvision torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
uv pip install --no-cache -r requirements.txt
|
||||
```
|
||||
|
||||
Verify CUDA is wired up:
|
||||
|
||||
```bash
|
||||
.venv/bin/python -c \
|
||||
"import torch; print(torch.__version__, torch.cuda.is_available(), torch.cuda.get_device_name(0))"
|
||||
# expected: 2.6.0+cu124 True NVIDIA GeForce RTX 4070 Ti SUPER
|
||||
```
|
||||
|
||||
## 2. Models — FLUX.1 schnell
|
||||
|
||||
The Black-Forest-Labs primary repo (`black-forest-labs/FLUX.1-schnell`) is
|
||||
**gated** — `curl` against it without an HF token returns HTTP 401. We pull
|
||||
the weights from ungated mirrors of the same Apache-2.0 release.
|
||||
|
||||
| File | Where it goes | Source |
|
||||
|------|---------------|--------|
|
||||
| `flux1-schnell.safetensors` (~23.8 GB, fp16) | `models/unet/` | `Comfy-Org/flux1-schnell` |
|
||||
| `ae.safetensors` (~335 MB) | `models/vae/` | `sirorable/flux-ae-vae` |
|
||||
| `clip_l.safetensors` (~246 MB) | `models/clip/` | `comfyanonymous/flux_text_encoders` |
|
||||
| `t5xxl_fp8_e4m3fn.safetensors` (~4.9 GB) | `models/clip/` | `comfyanonymous/flux_text_encoders` |
|
||||
|
||||
```bash
|
||||
cd ~/dev/comfyui/models
|
||||
|
||||
curl -L -o unet/flux1-schnell.safetensors \
|
||||
https://huggingface.co/Comfy-Org/flux1-schnell/resolve/main/flux1-schnell.safetensors
|
||||
curl -L -o vae/ae.safetensors \
|
||||
https://huggingface.co/sirorable/flux-ae-vae/resolve/main/ae.safetensors
|
||||
curl -L -o clip/clip_l.safetensors \
|
||||
https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors
|
||||
curl -L -o clip/t5xxl_fp8_e4m3fn.safetensors \
|
||||
https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors
|
||||
```
|
||||
|
||||
If a new HF token is configured later (`~/.cache/huggingface/token`), the
|
||||
official `black-forest-labs/FLUX.1-schnell` URL is byte-identical and can be
|
||||
swapped in.
|
||||
|
||||
## 3. systemd unit
|
||||
|
||||
Drop `/etc/systemd/system/comfyui.service`:
|
||||
|
||||
```ini
|
||||
[Unit]
|
||||
Description=ComfyUI image generation server
|
||||
Documentation=https://github.com/comfyanonymous/ComfyUI
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=m
|
||||
Group=m
|
||||
WorkingDirectory=/home/m/dev/comfyui
|
||||
ExecStart=/home/m/dev/comfyui/.venv/bin/python /home/m/dev/comfyui/main.py \
|
||||
--listen 0.0.0.0 --port 8188 \
|
||||
--output-directory /home/m/dev/comfyui/output \
|
||||
--temp-directory /home/m/dev/comfyui/temp
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
TimeoutStopSec=30
|
||||
NoNewPrivileges=true
|
||||
PrivateTmp=true
|
||||
LimitNOFILE=65535
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
Then:
|
||||
|
||||
```bash
|
||||
sudo systemctl daemon-reload
|
||||
sudo systemctl enable --now comfyui.service
|
||||
systemctl status comfyui.service
|
||||
```
|
||||
|
||||
The service binds `0.0.0.0:8188`. Tailscale's wireguard fence is the only
|
||||
auth — do **not** expose port 8188 to the public internet.
|
||||
|
||||
## 4. Health check
|
||||
|
||||
```bash
|
||||
curl -fsS --max-time 5 http://mrock:8188/system_stats | jq '.devices[0]'
|
||||
# expected: name "cuda:0 NVIDIA GeForce RTX 4070 Ti SUPER ...", vram_total ~16 GB
|
||||
```
|
||||
|
||||
`imagen backends` (from a host with the ImaGen CLI installed) should also
|
||||
report `flux-schnell-local: ok`.
|
||||
|
||||
## 5. VRAM coexistence with Ollama
|
||||
|
||||
mRock has 16 GB VRAM total. Ollama parks ~8 GB resident for its current
|
||||
model. FLUX schnell at fp16 weights with `weight_dtype=fp8_e4m3fn` (the
|
||||
default the adapter requests) needs roughly 10–12 GB peak for a 1024×1024
|
||||
generation, so concurrent Ollama + FLUX on mRock will OOM.
|
||||
|
||||
Two practical options:
|
||||
|
||||
- **Stop Ollama before generating** — `sudo systemctl stop ollama` frees
|
||||
the GPU, run the generation, `sudo systemctl start ollama` afterwards.
|
||||
Adequate while we don't have many concurrent users.
|
||||
- **Move Ollama off mRock** — when ImaGen is in regular use, push Ollama to
|
||||
another host so the GPU is dedicated. Tracked separately.
|
||||
|
||||
Both decisions live with whoever operates the box; the adapter does not try
|
||||
to manage Ollama.
|
||||
|
||||
## 6. Smoke test (direct, without the imagen CLI)
|
||||
|
||||
```bash
|
||||
# 1) Submit a workflow
|
||||
curl -fsS --max-time 30 -X POST -H 'Content-Type: application/json' \
|
||||
-d @flux-schnell-workflow.json \
|
||||
http://mrock:8188/prompt
|
||||
# returns: {"prompt_id": "...", "number": ..., "node_errors": {}}
|
||||
|
||||
# 2) Poll history until the prompt completes
|
||||
PID=... # from above
|
||||
until curl -fsS http://mrock:8188/history/$PID | jq -e ".\"$PID\".status.completed == true" >/dev/null; do
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# 3) Pull the image
|
||||
NAME=$(curl -fsS http://mrock:8188/history/$PID \
|
||||
| jq -r ".\"$PID\".outputs[\"9\"].images[0].filename")
|
||||
curl -fsS "http://mrock:8188/view?filename=$NAME&type=output" -o /tmp/cat.png
|
||||
file /tmp/cat.png # PNG image data, 1024 x 1024
|
||||
```
|
||||
|
||||
The full ImaGen smoke test is in [usage.md](usage.md) once the Go adapter
|
||||
ships.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **`vram_free` < 6 GB in `/system_stats`**: another GPU process is holding
|
||||
memory. Usually Ollama (`sudo systemctl stop ollama`).
|
||||
- **Workflow returns `node_errors` with `Required input is missing` for
|
||||
CLIPLoader**: text encoder filenames don't match step 2 — check that
|
||||
`clip_l.safetensors` and `t5xxl_fp8_e4m3fn.safetensors` are in
|
||||
`models/clip/`, not `models/text_encoders/`.
|
||||
- **`Access to model … is restricted`** during a model pull: the script is
|
||||
hitting a gated mirror. Use the ungated URLs from step 2.
|
||||
- **Service won't start**: check `journalctl -u comfyui --since '5 min ago'`.
|
||||
Common cause is a stale `pip` install — re-run step 1.
|
||||
@@ -24,8 +24,28 @@ imagen version print version
|
||||
| `--negative` | empty | Negative prompt (ignored by some adapters) |
|
||||
| `--output` | empty (= use naming template) | Explicit path |
|
||||
| `--no-sidecar` | `false` | Skip the JSON sidecar even if config enables it |
|
||||
| `--preview` | (auto) | Force open a tmux preview window via `tmux-img` |
|
||||
| `--no-preview` | (auto) | Suppress the preview window (use for batch / CI callers) |
|
||||
| `--config` | `~/.config/imagen.yaml` | Override config path |
|
||||
|
||||
### Preview window
|
||||
|
||||
After a successful generate, imagen optionally opens a sibling tmux window
|
||||
named `img:<slug>` running `tmux-img --hold <path>`. The new window is
|
||||
spawned in the background (`tmux new-window -d`) so the generating pane
|
||||
keeps focus and its terminal output.
|
||||
|
||||
Resolution order is **config → `$IMAGEN_PREVIEW` → flag** (later wins):
|
||||
|
||||
- `output.preview` in `imagen.yaml`: `auto` (default) | `on` | `off`
|
||||
- `IMAGEN_PREVIEW=auto|on|off` overrides config
|
||||
- `--preview` / `--no-preview` override env
|
||||
|
||||
`auto` previews iff stdout is a TTY *and* `$TMUX` is set. `on` previews
|
||||
unconditionally and errors outside a tmux session. `off` never previews.
|
||||
|
||||
Preview failures are non-fatal — the image already wrote.
|
||||
|
||||
## Examples
|
||||
|
||||
```sh
|
||||
@@ -71,3 +91,26 @@ API-backed adapters read tokens from env vars referenced by the config
|
||||
export REPLICATE_API_TOKEN=...
|
||||
imagen generate "a cat" --backend flux-dev-replicate
|
||||
```
|
||||
|
||||
## Cost-tracking (Replicate)
|
||||
|
||||
Successful generations through the Replicate adapter write one row to
|
||||
`mai.imagen_usage` on Supabase: backend, model, latency, per-image cost
|
||||
estimate, prompt sha256 hash (never the prompt itself), and the caller
|
||||
identity (resolved from `MAI_FROM_ID` or the tmux pane's `@mai-name`).
|
||||
|
||||
The writer is best-effort. If `SUPABASE_URL` / `SUPABASE_SERVICE_KEY` are
|
||||
unset, or the database write fails, the image still lands and the CLI
|
||||
prints a warning to stderr.
|
||||
|
||||
Inspect spend:
|
||||
|
||||
```sh
|
||||
imagen usage # all rows, grouped by week + backend + model + caller
|
||||
imagen usage --since 2026-05-01 # only rows on/after a UTC date
|
||||
imagen usage --since 2026-05-01 --raw
|
||||
```
|
||||
|
||||
Per-model rates live in `internal/backend/replicate_pricing.go` — they
|
||||
are snapshotted from <https://replicate.com/pricing> and refreshed on a
|
||||
quarterly cadence.
|
||||
|
||||
557
internal/backend/comfyui.go
Normal file
557
internal/backend/comfyui.go
Normal file
@@ -0,0 +1,557 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ComfyType is the type-name adapters register under for ComfyUI instances.
|
||||
const ComfyType = "comfyui"
|
||||
|
||||
// Comfy is the ComfyUI adapter. It speaks the public `/prompt` + `/history`
|
||||
// + `/view` HTTP API and submits a fixed FLUX.1 schnell workflow built from
|
||||
// the values in Request.
|
||||
//
|
||||
// Concurrency: a single Comfy is safe to share across goroutines as long as
|
||||
// the underlying http.Client is. Generate does not hold long-lived state.
|
||||
type Comfy struct {
|
||||
instance string
|
||||
|
||||
base string
|
||||
model string
|
||||
vae string
|
||||
clipL string
|
||||
clipT5 string
|
||||
dtype string
|
||||
|
||||
defaultSteps int
|
||||
defaultSampler string
|
||||
defaultScheduler string
|
||||
|
||||
httpClient *http.Client
|
||||
pollInterval time.Duration
|
||||
pollTimeout time.Duration
|
||||
|
||||
// Hooks for tests; production paths use the package-level defaults.
|
||||
randSeed func() int64
|
||||
clientIDFn func() string
|
||||
}
|
||||
|
||||
// NewComfy is the registry constructor. cfg is the adapter's slice of
|
||||
// imagen.yaml. Required keys: base_url, model. The rest have sensible FLUX
|
||||
// schnell defaults.
|
||||
func NewComfy(name string, cfg map[string]any) (Backend, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("comfyui: empty instance name")
|
||||
}
|
||||
base := strings.TrimRight(getString(cfg, "base_url", ""), "/")
|
||||
if base == "" {
|
||||
return nil, fmt.Errorf("comfyui[%s]: base_url is required", name)
|
||||
}
|
||||
if _, err := url.Parse(base); err != nil {
|
||||
return nil, fmt.Errorf("comfyui[%s]: base_url %q invalid: %w", name, base, err)
|
||||
}
|
||||
model := getString(cfg, "model", "")
|
||||
if model == "" {
|
||||
return nil, fmt.Errorf("comfyui[%s]: model is required", name)
|
||||
}
|
||||
|
||||
c := &Comfy{
|
||||
instance: name,
|
||||
base: base,
|
||||
model: model,
|
||||
|
||||
vae: getString(cfg, "vae", "ae.safetensors"),
|
||||
clipL: getString(cfg, "clip_l", "clip_l.safetensors"),
|
||||
clipT5: getString(cfg, "clip_t5", "t5xxl_fp8_e4m3fn.safetensors"),
|
||||
dtype: getString(cfg, "weight_dtype", "fp8_e4m3fn"),
|
||||
|
||||
defaultSteps: getInt(cfg, "default_steps", 4),
|
||||
defaultSampler: getString(cfg, "default_sampler", "euler"),
|
||||
defaultScheduler: getString(cfg, "default_scheduler", "simple"),
|
||||
|
||||
httpClient: &http.Client{Timeout: 60 * time.Second},
|
||||
pollInterval: 250 * time.Millisecond,
|
||||
pollTimeout: 120 * time.Second,
|
||||
|
||||
randSeed: cryptoSeed,
|
||||
clientIDFn: randClientID,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Name returns the instance name from imagen.yaml.
|
||||
func (c *Comfy) Name() string { return c.instance }
|
||||
|
||||
// Generate submits one workflow to ComfyUI, waits for it to render, and
|
||||
// returns the resulting PNG.
|
||||
func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) {
|
||||
width := orDefaultInt(req.Width, 1024)
|
||||
height := orDefaultInt(req.Height, 1024)
|
||||
steps := orDefaultInt(req.Steps, c.defaultSteps)
|
||||
|
||||
sampler := c.defaultSampler
|
||||
scheduler := c.defaultScheduler
|
||||
if v, ok := req.BackendOpts["sampler"].(string); ok && v != "" {
|
||||
sampler = v
|
||||
}
|
||||
if v, ok := req.BackendOpts["scheduler"].(string); ok && v != "" {
|
||||
scheduler = v
|
||||
}
|
||||
|
||||
seed := req.Seed
|
||||
if seed == 0 {
|
||||
seed = c.randSeed()
|
||||
}
|
||||
|
||||
workflow := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler)
|
||||
clientID := c.clientIDFn()
|
||||
|
||||
start := time.Now()
|
||||
promptID, err := c.submitPrompt(ctx, workflow, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filename, err := c.waitForCompletion(ctx, promptID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
imgBytes, err := c.fetchImage(ctx, filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
latencyMs := time.Since(start).Milliseconds()
|
||||
|
||||
meta := map[string]any{
|
||||
"backend": c.instance,
|
||||
"backend_type": ComfyType,
|
||||
"model": c.model,
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"sampler": sampler,
|
||||
"scheduler": scheduler,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"latency_ms": latencyMs,
|
||||
"prompt_id": promptID,
|
||||
"client_id": clientID,
|
||||
}
|
||||
if vram := c.vramUsedMiB(ctx); vram > 0 {
|
||||
meta["vram_used_mib"] = vram
|
||||
}
|
||||
|
||||
return &Result{
|
||||
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
|
||||
MimeType: "image/png",
|
||||
Metadata: meta,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// submitPrompt POSTs the workflow and extracts the prompt_id.
|
||||
//
|
||||
// Retries once on a 5xx or transient network error. 4xx responses are not
|
||||
// retried — they are treated as configuration bugs (missing model, bad
|
||||
// workflow shape, etc.) and surfaced with a hint pointing at the docs when
|
||||
// the body matches a known pattern.
|
||||
func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clientID string) (string, error) {
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"prompt": workflow,
|
||||
"client_id": clientID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("comfyui: marshal workflow: %w", err)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := range 2 {
|
||||
if attempt > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/prompt", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = c.connError(err)
|
||||
continue
|
||||
}
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
switch {
|
||||
case resp.StatusCode >= 200 && resp.StatusCode < 300:
|
||||
return parsePromptID(respBody, c.model)
|
||||
case resp.StatusCode >= 500:
|
||||
lastErr = fmt.Errorf("comfyui /prompt %d: %s", resp.StatusCode, snip(respBody))
|
||||
continue
|
||||
default:
|
||||
return "", c.classifyBadRequest(resp.StatusCode, respBody)
|
||||
}
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
// waitForCompletion polls /history/{id} until the prompt finishes and
|
||||
// returns the filename of the produced image.
|
||||
func (c *Comfy) waitForCompletion(ctx context.Context, promptID string) (string, error) {
|
||||
deadline := time.Now().Add(c.pollTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return "", fmt.Errorf("comfyui: prompt %s did not complete within %s", promptID, c.pollTimeout)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/history/"+promptID, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", c.connError(err)
|
||||
}
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("comfyui /history/%s %d: %s", promptID, resp.StatusCode, snip(body))
|
||||
}
|
||||
filename, done, err := parseHistory(body, promptID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if done {
|
||||
return filename, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(c.pollInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fetchImage downloads the produced image bytes via /view.
|
||||
func (c *Comfy) fetchImage(ctx context.Context, filename string) ([]byte, error) {
|
||||
q := url.Values{
|
||||
"filename": {filename},
|
||||
"type": {"output"},
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/view?"+q.Encode(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, c.connError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("comfyui /view %d: %s", resp.StatusCode, snip(body))
|
||||
}
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// vramUsedMiB returns total - free VRAM on device 0 from /system_stats, or
|
||||
// 0 if the endpoint isn't available. Best-effort, never an error.
|
||||
func (c *Comfy) vramUsedMiB(ctx context.Context) int64 {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/system_stats", nil)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return 0
|
||||
}
|
||||
var s struct {
|
||||
Devices []struct {
|
||||
VRAMTotal int64 `json:"vram_total"`
|
||||
VRAMFree int64 `json:"vram_free"`
|
||||
} `json:"devices"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&s); err != nil {
|
||||
return 0
|
||||
}
|
||||
if len(s.Devices) == 0 {
|
||||
return 0
|
||||
}
|
||||
used := s.Devices[0].VRAMTotal - s.Devices[0].VRAMFree
|
||||
if used < 0 {
|
||||
return 0
|
||||
}
|
||||
return used / (1024 * 1024)
|
||||
}
|
||||
|
||||
// connError translates a Go networking error into a user-actionable message,
|
||||
// pointing at the boot-whitetower script when mRock looks asleep.
|
||||
func (c *Comfy) connError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
msg := err.Error()
|
||||
var opErr *net.OpError
|
||||
asOp := errors.As(err, &opErr)
|
||||
switch {
|
||||
case asOp,
|
||||
strings.Contains(msg, "connection refused"),
|
||||
strings.Contains(msg, "no such host"),
|
||||
strings.Contains(msg, "no route to host"),
|
||||
strings.Contains(msg, "network is unreachable"),
|
||||
strings.Contains(msg, "i/o timeout"):
|
||||
return fmt.Errorf("comfyui at %s unreachable (%v) — if mRock is asleep, run: boot-whitetower mrock", c.base, err)
|
||||
}
|
||||
return fmt.Errorf("comfyui at %s: %w", c.base, err)
|
||||
}
|
||||
|
||||
// classifyBadRequest interprets a 4xx body. Some ComfyUI builds use 400 for
|
||||
// workflow-validation failures and put the diagnostics in node_errors; older
|
||||
// builds use 200 + node_errors. This handles the 4xx flavour.
|
||||
func (c *Comfy) classifyBadRequest(status int, body []byte) error {
|
||||
if hint, ok := missingModelHint(body, c.model); ok {
|
||||
return fmt.Errorf("comfyui /prompt %d: %s — see docs/setup-comfyui-mrock.md", status, hint)
|
||||
}
|
||||
return fmt.Errorf("comfyui /prompt %d: %s", status, snip(body))
|
||||
}
|
||||
|
||||
// buildWorkflow assembles the canonical FLUX.1 schnell ComfyUI workflow,
|
||||
// node-IDs matching the upstream "flux-schnell" template so anyone debugging
|
||||
// in the ComfyUI UI sees a familiar shape.
|
||||
func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string) map[string]any {
|
||||
return map[string]any{
|
||||
"6": map[string]any{
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": map[string]any{
|
||||
"text": prompt,
|
||||
"clip": []any{"11", 0},
|
||||
},
|
||||
},
|
||||
"8": map[string]any{
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": map[string]any{
|
||||
"samples": []any{"31", 0},
|
||||
"vae": []any{"10", 0},
|
||||
},
|
||||
},
|
||||
"9": map[string]any{
|
||||
"class_type": "SaveImage",
|
||||
"inputs": map[string]any{
|
||||
"filename_prefix": "imagen",
|
||||
"images": []any{"8", 0},
|
||||
},
|
||||
},
|
||||
"10": map[string]any{
|
||||
"class_type": "VAELoader",
|
||||
"inputs": map[string]any{"vae_name": c.vae},
|
||||
},
|
||||
"11": map[string]any{
|
||||
"class_type": "DualCLIPLoader",
|
||||
"inputs": map[string]any{
|
||||
"clip_name1": c.clipT5,
|
||||
"clip_name2": c.clipL,
|
||||
"type": "flux",
|
||||
},
|
||||
},
|
||||
"12": map[string]any{
|
||||
"class_type": "UNETLoader",
|
||||
"inputs": map[string]any{
|
||||
"unet_name": c.model,
|
||||
"weight_dtype": c.dtype,
|
||||
},
|
||||
},
|
||||
"13": map[string]any{
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": map[string]any{
|
||||
"text": negative,
|
||||
"clip": []any{"11", 0},
|
||||
},
|
||||
},
|
||||
"27": map[string]any{
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"inputs": map[string]any{
|
||||
"width": w,
|
||||
"height": h,
|
||||
"batch_size": 1,
|
||||
},
|
||||
},
|
||||
"30": map[string]any{
|
||||
"class_type": "ModelSamplingFlux",
|
||||
"inputs": map[string]any{
|
||||
"model": []any{"12", 0},
|
||||
"max_shift": 1.15,
|
||||
"base_shift": 0.5,
|
||||
"width": w,
|
||||
"height": h,
|
||||
},
|
||||
},
|
||||
"31": map[string]any{
|
||||
"class_type": "KSampler",
|
||||
"inputs": map[string]any{
|
||||
"model": []any{"30", 0},
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"cfg": 1.0,
|
||||
"sampler_name": sampler,
|
||||
"scheduler": scheduler,
|
||||
"denoise": 1.0,
|
||||
"positive": []any{"6", 0},
|
||||
"negative": []any{"13", 0},
|
||||
"latent_image": []any{"27", 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// parsePromptID handles the 2xx /prompt response. ComfyUI sometimes 200s a
|
||||
// validation failure and stuffs node_errors in the body — this function
|
||||
// turns that into the same user-facing error as a 4xx with the same body.
|
||||
func parsePromptID(body []byte, model string) (string, error) {
|
||||
var resp struct {
|
||||
PromptID string `json:"prompt_id"`
|
||||
NodeErrors map[string]any `json:"node_errors"`
|
||||
Error json.RawMessage `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return "", fmt.Errorf("comfyui /prompt: parse response: %w (body: %s)", err, snip(body))
|
||||
}
|
||||
if len(resp.NodeErrors) > 0 || len(resp.Error) > 0 {
|
||||
if hint, ok := missingModelHint(body, model); ok {
|
||||
return "", fmt.Errorf("comfyui /prompt: %s — see docs/setup-comfyui-mrock.md", hint)
|
||||
}
|
||||
return "", fmt.Errorf("comfyui /prompt rejected workflow: %s", snip(body))
|
||||
}
|
||||
if resp.PromptID == "" {
|
||||
return "", fmt.Errorf("comfyui /prompt: empty prompt_id (body: %s)", snip(body))
|
||||
}
|
||||
return resp.PromptID, nil
|
||||
}
|
||||
|
||||
// parseHistory inspects a /history/{id} body and returns either the produced
|
||||
// filename + done=true, or done=false to signal "keep polling".
|
||||
func parseHistory(body []byte, promptID string) (string, bool, error) {
|
||||
var entries map[string]struct {
|
||||
Status struct {
|
||||
Completed bool `json:"completed"`
|
||||
StatusStr string `json:"status_str"`
|
||||
} `json:"status"`
|
||||
Outputs map[string]struct {
|
||||
Images []struct {
|
||||
Filename string `json:"filename"`
|
||||
Subfolder string `json:"subfolder"`
|
||||
Type string `json:"type"`
|
||||
} `json:"images"`
|
||||
} `json:"outputs"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &entries); err != nil {
|
||||
return "", false, fmt.Errorf("comfyui /history: parse: %w (body: %s)", err, snip(body))
|
||||
}
|
||||
e, ok := entries[promptID]
|
||||
if !ok {
|
||||
return "", false, nil
|
||||
}
|
||||
if e.Status.StatusStr == "error" {
|
||||
return "", false, fmt.Errorf("comfyui prompt %s errored: %s", promptID, snip(body))
|
||||
}
|
||||
if !e.Status.Completed {
|
||||
return "", false, nil
|
||||
}
|
||||
for _, out := range e.Outputs {
|
||||
if len(out.Images) > 0 {
|
||||
return out.Images[0].Filename, true, nil
|
||||
}
|
||||
}
|
||||
return "", true, fmt.Errorf("comfyui prompt %s completed but produced no images", promptID)
|
||||
}
|
||||
|
||||
// missingModelHint returns a user-actionable message when the response body
|
||||
// indicates the configured unet model isn't loaded on the server. ComfyUI
|
||||
// uses both the human-readable "Value not in list" message and the enum
|
||||
// "value_not_in_list" type — match either.
|
||||
func missingModelHint(body []byte, model string) (string, bool) {
|
||||
s := string(body)
|
||||
hasMarker := strings.Contains(s, "Value not in list") || strings.Contains(s, "value_not_in_list")
|
||||
if hasMarker && strings.Contains(s, "unet_name") {
|
||||
return fmt.Sprintf("model %q not present in the ComfyUI server's models/unet/", model), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func cryptoSeed() int64 {
|
||||
var b [8]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return time.Now().UnixNano()
|
||||
}
|
||||
return int64(binary.BigEndian.Uint64(b[:]) >> 1)
|
||||
}
|
||||
|
||||
func randClientID() string {
|
||||
var b [8]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
return fmt.Sprintf("imagen-%x", b)
|
||||
}
|
||||
|
||||
func getString(m map[string]any, k, def string) string {
|
||||
if v, ok := m[k].(string); ok && v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func getInt(m map[string]any, k string, def int) int {
|
||||
if v, ok := m[k]; ok {
|
||||
switch n := v.(type) {
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
case float64:
|
||||
return int(n)
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func orDefaultInt(v, def int) int {
|
||||
if v == 0 {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func snip(b []byte) string {
|
||||
const max = 500
|
||||
s := strings.TrimSpace(string(b))
|
||||
if len(s) > max {
|
||||
s = s[:max] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register(ComfyType, NewComfy)
|
||||
}
|
||||
494
internal/backend/comfyui_test.go
Normal file
494
internal/backend/comfyui_test.go
Normal file
@@ -0,0 +1,494 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeComfy is a programmable mock of the ComfyUI HTTP API. Tests configure
|
||||
// its behaviour by adjusting the public fields before issuing the request.
|
||||
type fakeComfy struct {
|
||||
t *testing.T
|
||||
|
||||
// /prompt
|
||||
promptStatus int
|
||||
promptBody []byte
|
||||
promptCalls atomic.Int32
|
||||
failPromptUntil int32 // first N /prompt calls return promptFailStatus
|
||||
promptFailStatus int
|
||||
promptFailBody []byte
|
||||
|
||||
// /history — start by returning {} (no entry), flip to completed once
|
||||
// historyReadyAfter polls have happened.
|
||||
historyReadyAfter int32
|
||||
historyCalls atomic.Int32
|
||||
historyError bool
|
||||
|
||||
// /view
|
||||
viewStatus int
|
||||
viewBody []byte
|
||||
viewType string
|
||||
|
||||
// /system_stats
|
||||
statsTotal int64
|
||||
statsFree int64
|
||||
|
||||
server *httptest.Server
|
||||
}
|
||||
|
||||
func (f *fakeComfy) handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/prompt" && r.Method == http.MethodPost:
|
||||
n := f.promptCalls.Add(1)
|
||||
if n <= int32(f.failPromptUntil) {
|
||||
w.WriteHeader(f.promptFailStatus)
|
||||
_, _ = w.Write(f.promptFailBody)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(f.promptStatus)
|
||||
_, _ = w.Write(f.promptBody)
|
||||
case strings.HasPrefix(r.URL.Path, "/history/") && r.Method == http.MethodGet:
|
||||
n := f.historyCalls.Add(1)
|
||||
id := strings.TrimPrefix(r.URL.Path, "/history/")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f.historyError {
|
||||
_, _ = fmt.Fprintf(w, `{"%s":{"status":{"completed":false,"status_str":"error"},"outputs":{}}}`, id)
|
||||
return
|
||||
}
|
||||
if n <= f.historyReadyAfter {
|
||||
_, _ = w.Write([]byte(`{}`))
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(w,
|
||||
`{"%s":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`,
|
||||
id,
|
||||
)
|
||||
case r.URL.Path == "/view" && r.Method == http.MethodGet:
|
||||
ct := f.viewType
|
||||
if ct == "" {
|
||||
ct = "image/png"
|
||||
}
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(f.viewStatus)
|
||||
_, _ = w.Write(f.viewBody)
|
||||
case r.URL.Path == "/system_stats" && r.Method == http.MethodGet:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
body := map[string]any{
|
||||
"system": map[string]any{},
|
||||
"devices": []map[string]any{
|
||||
{"vram_total": f.statsTotal, "vram_free": f.statsFree},
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
default:
|
||||
f.t.Errorf("fakeComfy: unexpected request %s %s", r.Method, r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (f *fakeComfy) start() {
|
||||
f.server = httptest.NewServer(f.handler())
|
||||
f.t.Cleanup(f.server.Close)
|
||||
}
|
||||
|
||||
// newFakeComfy spins up a fakeComfy with happy-path defaults.
|
||||
func newFakeComfy(t *testing.T) *fakeComfy {
|
||||
t.Helper()
|
||||
f := &fakeComfy{
|
||||
t: t,
|
||||
promptStatus: http.StatusOK,
|
||||
promptBody: []byte(`{"prompt_id":"pid-abc","number":1,"node_errors":{}}`),
|
||||
viewStatus: http.StatusOK,
|
||||
viewBody: mustPNG(t, 16, 16),
|
||||
statsTotal: 16 * 1024 * 1024 * 1024,
|
||||
statsFree: 8 * 1024 * 1024 * 1024,
|
||||
}
|
||||
f.start()
|
||||
return f
|
||||
}
|
||||
|
||||
// newComfy returns a Comfy pointed at f, with poll interval squashed for fast
|
||||
// tests and deterministic seed/client_id.
|
||||
func newComfy(t *testing.T, f *fakeComfy) *Comfy {
|
||||
t.Helper()
|
||||
be, err := NewComfy("flux-test", map[string]any{
|
||||
"base_url": f.server.URL,
|
||||
"model": "flux1-schnell.safetensors",
|
||||
"default_steps": 4,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewComfy: %v", err)
|
||||
}
|
||||
c := be.(*Comfy)
|
||||
c.pollInterval = time.Millisecond
|
||||
c.pollTimeout = 5 * time.Second
|
||||
c.randSeed = func() int64 { return 42 }
|
||||
c.clientIDFn = func() string { return "imagen-test" }
|
||||
return c
|
||||
}
|
||||
|
||||
func mustPNG(t *testing.T, w, h int) []byte {
|
||||
t.Helper()
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for y := range h {
|
||||
for x := range w {
|
||||
img.Set(x, y, color.RGBA{R: 200, G: 100, B: 50, A: 255})
|
||||
}
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("encode png: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestComfyConstructorRequiresBaseAndModel(t *testing.T) {
|
||||
if _, err := NewComfy("x", map[string]any{}); err == nil {
|
||||
t.Errorf("expected error for missing base_url")
|
||||
}
|
||||
if _, err := NewComfy("x", map[string]any{"base_url": "http://h:1"}); err == nil {
|
||||
t.Errorf("expected error for missing model")
|
||||
}
|
||||
if _, err := NewComfy("", map[string]any{"base_url": "http://h:1", "model": "m"}); err == nil {
|
||||
t.Errorf("expected error for empty instance name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyHappyPath(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
f.historyReadyAfter = 2 // exercise the polling loop
|
||||
c := newComfy(t, f)
|
||||
|
||||
res, err := c.Generate(context.Background(), Request{
|
||||
Prompt: "a small fishbowl with a cat",
|
||||
Width: 512,
|
||||
Height: 512,
|
||||
Steps: 4,
|
||||
Seed: 1234567,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
defer res.ImageReader.Close()
|
||||
|
||||
if res.MimeType != "image/png" {
|
||||
t.Errorf("mime = %q", res.MimeType)
|
||||
}
|
||||
body, err := io.ReadAll(res.ImageReader)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
if !bytes.Equal(body, f.viewBody) {
|
||||
t.Errorf("image body did not round-trip")
|
||||
}
|
||||
|
||||
if seed, _ := res.Metadata["seed"].(int64); seed != 1234567 {
|
||||
t.Errorf("metadata seed = %v", res.Metadata["seed"])
|
||||
}
|
||||
if model, _ := res.Metadata["model"].(string); model != "flux1-schnell.safetensors" {
|
||||
t.Errorf("metadata model = %v", res.Metadata["model"])
|
||||
}
|
||||
if steps, _ := res.Metadata["steps"].(int); steps != 4 {
|
||||
t.Errorf("metadata steps = %v", res.Metadata["steps"])
|
||||
}
|
||||
if pid, _ := res.Metadata["prompt_id"].(string); pid != "pid-abc" {
|
||||
t.Errorf("metadata prompt_id = %v", res.Metadata["prompt_id"])
|
||||
}
|
||||
if _, ok := res.Metadata["latency_ms"]; !ok {
|
||||
t.Errorf("metadata missing latency_ms")
|
||||
}
|
||||
// vram_used_mib is best-effort but should be present given our mock stats
|
||||
if vram, _ := res.Metadata["vram_used_mib"].(int64); vram != 8192 {
|
||||
t.Errorf("metadata vram_used_mib = %v, want 8192", res.Metadata["vram_used_mib"])
|
||||
}
|
||||
|
||||
if got := f.historyCalls.Load(); got < 3 {
|
||||
t.Errorf("expected at least 3 /history polls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyDefaultsAppliedWhenZero(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
c := newComfy(t, f)
|
||||
|
||||
res, err := c.Generate(context.Background(), Request{Prompt: "p"}) // all-zero
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
defer res.ImageReader.Close()
|
||||
_, _ = io.ReadAll(res.ImageReader)
|
||||
|
||||
if w, _ := res.Metadata["width"].(int); w != 1024 {
|
||||
t.Errorf("width default = %v", res.Metadata["width"])
|
||||
}
|
||||
if steps, _ := res.Metadata["steps"].(int); steps != 4 {
|
||||
t.Errorf("steps default = %v", res.Metadata["steps"])
|
||||
}
|
||||
if seed, _ := res.Metadata["seed"].(int64); seed != 42 {
|
||||
t.Errorf("seed default (test rand hook) = %v", res.Metadata["seed"])
|
||||
}
|
||||
if s, _ := res.Metadata["sampler"].(string); s != "euler" {
|
||||
t.Errorf("sampler default = %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyPromptRetriesOnce5xx(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
f.failPromptUntil = 1
|
||||
f.promptFailStatus = http.StatusBadGateway
|
||||
f.promptFailBody = []byte("upstream busy")
|
||||
c := newComfy(t, f)
|
||||
|
||||
res, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate (with one 502 then OK): %v", err)
|
||||
}
|
||||
defer res.ImageReader.Close()
|
||||
_, _ = io.ReadAll(res.ImageReader)
|
||||
|
||||
if got := f.promptCalls.Load(); got != 2 {
|
||||
t.Errorf("expected exactly 2 /prompt calls (1 fail + 1 retry), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyPromptGivesUpAfterTwo5xx(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
f.failPromptUntil = 99 // every call fails
|
||||
f.promptFailStatus = http.StatusServiceUnavailable
|
||||
f.promptFailBody = []byte("nope")
|
||||
c := newComfy(t, f)
|
||||
|
||||
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err == nil {
|
||||
t.Fatal("expected error after sustained 503s")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "503") {
|
||||
t.Errorf("expected error to mention 503, got %v", err)
|
||||
}
|
||||
if got := f.promptCalls.Load(); got != 2 {
|
||||
t.Errorf("expected exactly 2 /prompt calls (no further retries), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyPromptDoesNotRetryOn4xx(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
f.failPromptUntil = 99
|
||||
f.promptFailStatus = http.StatusBadRequest
|
||||
f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation"},"node_errors":{"some":"thing"}}`)
|
||||
c := newComfy(t, f)
|
||||
|
||||
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 400")
|
||||
}
|
||||
if got := f.promptCalls.Load(); got != 1 {
|
||||
t.Errorf("expected exactly 1 /prompt call (no retry on 4xx), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyMissingModelHintsAtSetupDoc(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
f.failPromptUntil = 99
|
||||
f.promptFailStatus = http.StatusBadRequest
|
||||
f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation","message":"Prompt outputs failed validation"},"node_errors":{"12":{"errors":[{"type":"value_not_in_list","message":"Value not in list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`)
|
||||
c := newComfy(t, f)
|
||||
|
||||
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "docs/setup-comfyui-mrock.md") {
|
||||
t.Errorf("error should point at the setup doc, got %v", err)
|
||||
}
|
||||
if !strings.Contains(msg, "flux1-schnell.safetensors") {
|
||||
t.Errorf("error should name the missing model, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyMissingModelOn200WithNodeErrors(t *testing.T) {
|
||||
// Older ComfyUI builds 200 a workflow-validation failure.
|
||||
f := newFakeComfy(t)
|
||||
f.promptStatus = http.StatusOK
|
||||
f.promptBody = []byte(`{"prompt_id":"","node_errors":{"12":{"errors":[{"type":"value_not_in_list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`)
|
||||
c := newComfy(t, f)
|
||||
|
||||
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for node_errors on 200")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "docs/setup-comfyui-mrock.md") {
|
||||
t.Errorf("error should point at the setup doc, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyHistoryErrorSurfaced(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
f.historyError = true
|
||||
c := newComfy(t, f)
|
||||
|
||||
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when history reports execution error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "errored") {
|
||||
t.Errorf("expected 'errored' in message, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyViewFailureSurfaced(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
f.viewStatus = http.StatusNotFound
|
||||
f.viewBody = []byte("nope")
|
||||
c := newComfy(t, f)
|
||||
|
||||
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when /view 404s")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "404") {
|
||||
t.Errorf("expected status code in error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyUnreachableHostMentionsBootHelper(t *testing.T) {
|
||||
be, err := NewComfy("flux-test", map[string]any{
|
||||
"base_url": "http://127.0.0.1:1", // closed port; connection refused
|
||||
"model": "flux1-schnell.safetensors",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewComfy: %v", err)
|
||||
}
|
||||
c := be.(*Comfy)
|
||||
c.httpClient = &http.Client{Timeout: 500 * time.Millisecond}
|
||||
|
||||
_, err = c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unreachable host")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "boot-whitetower mrock") {
|
||||
t.Errorf("expected boot-helper hint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyContextCancelStopsPolling(t *testing.T) {
|
||||
f := newFakeComfy(t)
|
||||
f.historyReadyAfter = 1_000_000 // never finishes
|
||||
c := newComfy(t, f)
|
||||
c.pollInterval = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.Generate(ctx, Request{Prompt: "p", Width: 64, Height: 64})
|
||||
if err == nil {
|
||||
t.Fatal("expected ctx.Err()")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "context deadline exceeded") {
|
||||
t.Errorf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyWorkflowReflectsRequest(t *testing.T) {
|
||||
// Capture the workflow body to assert KSampler + EmptyLatentImage values.
|
||||
var captured []byte
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/prompt":
|
||||
captured, _ = io.ReadAll(r.Body)
|
||||
_, _ = w.Write([]byte(`{"prompt_id":"pid","number":1,"node_errors":{}}`))
|
||||
case "/history/pid":
|
||||
_, _ = w.Write([]byte(`{"pid":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`))
|
||||
case "/view":
|
||||
_, _ = w.Write(mustPNG(t, 8, 8))
|
||||
case "/system_stats":
|
||||
_, _ = w.Write([]byte(`{"devices":[{"vram_total":1,"vram_free":1}]}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
be, err := NewComfy("flux-test", map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"model": "custom.safetensors",
|
||||
"default_steps": 7,
|
||||
"default_sampler": "dpmpp_2m",
|
||||
"default_scheduler": "karras",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewComfy: %v", err)
|
||||
}
|
||||
c := be.(*Comfy)
|
||||
c.pollInterval = time.Millisecond
|
||||
c.randSeed = func() int64 { return 9999 }
|
||||
|
||||
res, err := c.Generate(context.Background(), Request{
|
||||
Prompt: "a cat",
|
||||
NegativePrompt: "blurry",
|
||||
Width: 768,
|
||||
Height: 512,
|
||||
Steps: 11,
|
||||
Seed: 555,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
|
||||
var sent struct {
|
||||
Prompt map[string]map[string]any `json:"prompt"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
if err := json.Unmarshal(captured, &sent); err != nil {
|
||||
t.Fatalf("unmarshal captured: %v", err)
|
||||
}
|
||||
ks := sent.Prompt["31"]["inputs"].(map[string]any)
|
||||
if ks["seed"].(float64) != 555 {
|
||||
t.Errorf("KSampler seed = %v, want 555", ks["seed"])
|
||||
}
|
||||
if ks["steps"].(float64) != 11 {
|
||||
t.Errorf("KSampler steps = %v, want 11", ks["steps"])
|
||||
}
|
||||
if ks["sampler_name"].(string) != "dpmpp_2m" {
|
||||
t.Errorf("sampler_name = %v", ks["sampler_name"])
|
||||
}
|
||||
if ks["scheduler"].(string) != "karras" {
|
||||
t.Errorf("scheduler = %v", ks["scheduler"])
|
||||
}
|
||||
latent := sent.Prompt["27"]["inputs"].(map[string]any)
|
||||
if latent["width"].(float64) != 768 || latent["height"].(float64) != 512 {
|
||||
t.Errorf("EmptySD3LatentImage size = %vx%v", latent["width"], latent["height"])
|
||||
}
|
||||
unet := sent.Prompt["12"]["inputs"].(map[string]any)
|
||||
if unet["unet_name"].(string) != "custom.safetensors" {
|
||||
t.Errorf("unet_name = %v", unet["unet_name"])
|
||||
}
|
||||
neg := sent.Prompt["13"]["inputs"].(map[string]any)
|
||||
if neg["text"].(string) != "blurry" {
|
||||
t.Errorf("negative prompt not threaded: %v", neg["text"])
|
||||
}
|
||||
if !strings.HasPrefix(sent.ClientID, "imagen-") && sent.ClientID == "" {
|
||||
t.Errorf("client_id should be set: %q", sent.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComfyTypeIsRegistered(t *testing.T) {
|
||||
if !Default.Has(ComfyType) {
|
||||
t.Errorf("comfyui type not registered in Default")
|
||||
}
|
||||
}
|
||||
567
internal/backend/replicate.go
Normal file
567
internal/backend/replicate.go
Normal file
@@ -0,0 +1,567 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ReplicateType is the type-name adapters register under for Replicate
|
||||
// instances.
|
||||
const ReplicateType = "replicate"
|
||||
|
||||
// Replicate is the Replicate API adapter. It speaks the public REST API
|
||||
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
|
||||
// to submit, then polls /v1/predictions/{id} until the prediction
|
||||
// settles, then downloads the produced image.
|
||||
//
|
||||
// Concurrency: a single Replicate is safe to share across goroutines.
|
||||
type Replicate struct {
|
||||
instance string
|
||||
|
||||
apiBase string
|
||||
apiToken string
|
||||
tokenEnv string
|
||||
model string // "owner/name" or "owner/name:version-hash"
|
||||
owner string
|
||||
name string
|
||||
version string // optional; empty means "use the model-based predictions endpoint"
|
||||
defaultSteps int
|
||||
defaultAspect string
|
||||
|
||||
httpClient *http.Client
|
||||
pollInterval time.Duration
|
||||
pollTimeout time.Duration
|
||||
|
||||
// Hooks for tests; production paths use the package-level defaults.
|
||||
randSeed func() int64
|
||||
initialBackoff time.Duration
|
||||
|
||||
// Sink is where successful generations are recorded for cost-tracking.
|
||||
// nil means "do not record". The framework wires this up in the CLI;
|
||||
// adapter tests inject a fake.
|
||||
Sink UsageSink
|
||||
}
|
||||
|
||||
// UsageSink writes one row per successful generation. Implementations
|
||||
// should treat write failures as warnings, not errors — the image has
|
||||
// already landed on disk; failing the call would lose the artefact.
|
||||
type UsageSink interface {
|
||||
Record(ctx context.Context, row UsageRow) error
|
||||
}
|
||||
|
||||
// UsageRow is the cost-tracking row stored in mai.imagen_usage. Note the
|
||||
// prompt itself is intentionally NOT included — only the sha256 hash.
|
||||
type UsageRow struct {
|
||||
Backend string
|
||||
Model string
|
||||
Seed *int64
|
||||
PromptHash string
|
||||
LatencyMs int
|
||||
CostUSDEstimate *float64
|
||||
Caller string
|
||||
}
|
||||
|
||||
// NewReplicate is the registry constructor. cfg is the adapter's slice
|
||||
// of imagen.yaml.
|
||||
func NewReplicate(name string, cfg map[string]any) (Backend, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("replicate: empty instance name")
|
||||
}
|
||||
tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN")
|
||||
model := getString(cfg, "model", "")
|
||||
if model == "" {
|
||||
return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name)
|
||||
}
|
||||
owner, modelName, version, err := parseModelRef(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate[%s]: %w", name, err)
|
||||
}
|
||||
apiBase := strings.TrimRight(getString(cfg, "api_base", "https://api.replicate.com"), "/")
|
||||
pollTimeout := timeoutForModel(modelName)
|
||||
r := &Replicate{
|
||||
instance: name,
|
||||
apiBase: apiBase,
|
||||
tokenEnv: tokenEnv,
|
||||
apiToken: os.Getenv(tokenEnv),
|
||||
model: model,
|
||||
owner: owner,
|
||||
name: modelName,
|
||||
version: version,
|
||||
defaultSteps: getInt(cfg, "default_steps", 0),
|
||||
defaultAspect: getString(cfg, "default_aspect_ratio", "1:1"),
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
||||
pollInterval: 500 * time.Millisecond,
|
||||
pollTimeout: pollTimeout,
|
||||
randSeed: cryptoSeed,
|
||||
initialBackoff: time.Second,
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Name returns the instance name from imagen.yaml.
|
||||
func (r *Replicate) Name() string { return r.instance }
|
||||
|
||||
// Generate submits one prediction and returns the resulting PNG.
|
||||
func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) {
|
||||
if r.apiToken == "" {
|
||||
return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv)
|
||||
}
|
||||
|
||||
width := orDefaultInt(req.Width, 1024)
|
||||
height := orDefaultInt(req.Height, 1024)
|
||||
aspect := computeAspectRatio(width, height, r.defaultAspect)
|
||||
|
||||
steps := orDefaultInt(req.Steps, r.defaultSteps)
|
||||
|
||||
seed := req.Seed
|
||||
if seed == 0 {
|
||||
seed = r.randSeed()
|
||||
}
|
||||
|
||||
input := map[string]any{
|
||||
"prompt": req.Prompt,
|
||||
"aspect_ratio": aspect,
|
||||
"num_outputs": 1,
|
||||
"output_format": "png",
|
||||
"seed": seed,
|
||||
}
|
||||
if steps > 0 {
|
||||
input["num_inference_steps"] = steps
|
||||
}
|
||||
if req.NegativePrompt != "" {
|
||||
input["negative_prompt"] = req.NegativePrompt
|
||||
}
|
||||
maps.Copy(input, req.BackendOpts)
|
||||
|
||||
start := time.Now()
|
||||
pred, err := r.submitPrediction(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final, err := r.waitForCompletion(ctx, pred.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
|
||||
}
|
||||
imgURL, err := pickFirstOutputURL(final.Output)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
|
||||
}
|
||||
imgBytes, mime, err := r.fetchImage(ctx, imgURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
latencyMs := int(time.Since(start).Milliseconds())
|
||||
|
||||
predictTime := final.Metrics.PredictTime
|
||||
costEst, costKnown := replicatePerImageUSD(r.model)
|
||||
|
||||
meta := map[string]any{
|
||||
"backend": r.instance,
|
||||
"backend_type": ReplicateType,
|
||||
"model": r.model,
|
||||
"model_version": final.Version,
|
||||
"prediction_id": pred.ID,
|
||||
"seed": seed,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"aspect_ratio": aspect,
|
||||
"latency_ms": int64(latencyMs),
|
||||
}
|
||||
if predictTime > 0 {
|
||||
meta["predict_time_seconds"] = predictTime
|
||||
}
|
||||
if costKnown {
|
||||
meta["cost_usd_estimate"] = costEst
|
||||
}
|
||||
|
||||
if r.Sink != nil {
|
||||
row := UsageRow{
|
||||
Backend: r.instance,
|
||||
Model: r.model,
|
||||
PromptHash: hashPrompt(req.Prompt),
|
||||
LatencyMs: latencyMs,
|
||||
Caller: ResolveCaller(),
|
||||
}
|
||||
row.Seed = new(int64)
|
||||
*row.Seed = seed
|
||||
if costKnown {
|
||||
c := costEst
|
||||
row.CostUSDEstimate = &c
|
||||
}
|
||||
if err := r.Sink.Record(ctx, row); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &Result{
|
||||
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
|
||||
MimeType: mime,
|
||||
Metadata: meta,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// replicatePrediction is what the REST API returns under /v1/predictions
|
||||
// and /v1/models/{owner}/{name}/predictions. Only the fields we use are
|
||||
// declared.
|
||||
type replicatePrediction struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Version string `json:"version"`
|
||||
Error json.RawMessage `json:"error"`
|
||||
Output json.RawMessage `json:"output"`
|
||||
Metrics struct {
|
||||
PredictTime float64 `json:"predict_time"`
|
||||
} `json:"metrics"`
|
||||
}
|
||||
|
||||
// submitPrediction creates a prediction and returns it. Picks the
|
||||
// model-based endpoint when no version was given (recommended for
|
||||
// Replicate's official models), and the legacy /v1/predictions otherwise.
|
||||
func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) {
|
||||
var (
|
||||
reqURL string
|
||||
body []byte
|
||||
err error
|
||||
)
|
||||
if r.version == "" {
|
||||
reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions", r.apiBase, r.owner, r.name)
|
||||
body, err = json.Marshal(map[string]any{"input": input})
|
||||
} else {
|
||||
reqURL = r.apiBase + "/v1/predictions"
|
||||
body, err = json.Marshal(map[string]any{
|
||||
"version": r.version,
|
||||
"input": input,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate: marshal prediction body: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pred replicatePrediction
|
||||
if err := json.Unmarshal(respBody, &pred); err != nil {
|
||||
return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody))
|
||||
}
|
||||
if pred.ID == "" {
|
||||
return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody))
|
||||
}
|
||||
return &pred, nil
|
||||
}
|
||||
|
||||
// waitForCompletion polls /v1/predictions/{id} until the status is a
|
||||
// terminal value (succeeded, failed, canceled) or the timeout fires.
|
||||
func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) {
|
||||
deadline := time.Now().Add(r.pollTimeout)
|
||||
getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id)
|
||||
start := time.Now()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond))
|
||||
}
|
||||
body, err := r.doWithRetry(ctx, http.MethodGet, getURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pred replicatePrediction
|
||||
if err := json.Unmarshal(body, &pred); err != nil {
|
||||
return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body))
|
||||
}
|
||||
switch pred.Status {
|
||||
case "succeeded":
|
||||
return &pred, nil
|
||||
case "failed":
|
||||
return nil, fmt.Errorf("status=failed: %s", snip(pred.Error))
|
||||
case "canceled":
|
||||
return nil, fmt.Errorf("status=canceled")
|
||||
case "starting", "processing", "":
|
||||
default:
|
||||
return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body))
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(r.pollInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry executes a Replicate API request and applies the resilience
|
||||
// policy: 401 surfaces a clean message naming the env var; 429 retries
|
||||
// with exponential backoff up to three times; 5xx retries once; other
|
||||
// errors surface unchanged.
|
||||
func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
|
||||
const max429Retries = 3
|
||||
backoff := r.initialBackoff
|
||||
if backoff <= 0 {
|
||||
backoff = time.Second
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; ; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Token "+r.apiToken)
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, err
|
||||
}
|
||||
if attempt >= 1 {
|
||||
return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err)
|
||||
}
|
||||
lastErr = err
|
||||
if !sleepCtx(ctx, backoff) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
switch {
|
||||
case resp.StatusCode >= 200 && resp.StatusCode < 300:
|
||||
return respBody, nil
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, resp.StatusCode, r.tokenEnv)
|
||||
case resp.StatusCode == http.StatusTooManyRequests:
|
||||
if attempt >= max429Retries {
|
||||
return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody))
|
||||
}
|
||||
wait := backoffFor429(resp.Header.Get("Retry-After"), backoff)
|
||||
lastErr = fmt.Errorf("429 (retry %d): %s", attempt+1, snip(respBody))
|
||||
if !sleepCtx(ctx, wait) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
backoff *= 2
|
||||
continue
|
||||
case resp.StatusCode >= 500:
|
||||
if attempt >= 1 {
|
||||
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
|
||||
}
|
||||
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(respBody))
|
||||
if !sleepCtx(ctx, backoff) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
default:
|
||||
_ = lastErr
|
||||
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fetchImage downloads the rendered image from the Replicate-provided
|
||||
// CDN URL. One retry on a generic network error.
|
||||
func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) {
|
||||
var lastErr error
|
||||
for attempt := range 2 {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, "", err
|
||||
}
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
lastErr = readErr
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
if resp.StatusCode >= 500 && attempt == 0 {
|
||||
lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body))
|
||||
continue
|
||||
}
|
||||
return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body))
|
||||
}
|
||||
mime := resp.Header.Get("Content-Type")
|
||||
if mime == "" {
|
||||
mime = "image/png"
|
||||
}
|
||||
return body, mime, nil
|
||||
}
|
||||
return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr)
|
||||
}
|
||||
|
||||
// parseModelRef accepts "owner/name" or "owner/name:version-hash".
|
||||
func parseModelRef(ref string) (owner, name, version string, err error) {
|
||||
rest := ref
|
||||
if i := strings.IndexByte(rest, ':'); i >= 0 {
|
||||
version = rest[i+1:]
|
||||
rest = rest[:i]
|
||||
}
|
||||
parts := strings.SplitN(rest, "/", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
|
||||
}
|
||||
return parts[0], parts[1], version, nil
|
||||
}
|
||||
|
||||
// timeoutForModel picks the polling timeout. FLUX dev takes notably longer
|
||||
// than schnell; everything else gets the dev timeout for safety.
|
||||
func timeoutForModel(name string) time.Duration {
|
||||
switch strings.ToLower(name) {
|
||||
case "flux-schnell":
|
||||
return 60 * time.Second
|
||||
default:
|
||||
return 120 * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
// pickFirstOutputURL extracts the first image URL from the output field.
|
||||
// Replicate returns either a single string or an array of strings for image
|
||||
// models — we accept both.
|
||||
func pickFirstOutputURL(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", fmt.Errorf("output is empty")
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil && s != "" {
|
||||
return s, nil
|
||||
}
|
||||
var arr []string
|
||||
if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 && arr[0] != "" {
|
||||
return arr[0], nil
|
||||
}
|
||||
return "", fmt.Errorf("output is not a string or non-empty string array (got: %s)", snip(raw))
|
||||
}
|
||||
|
||||
// computeAspectRatio reduces width:height to a Replicate-supported aspect
|
||||
// ratio when the reduction lands on one of the canonical values; otherwise
|
||||
// returns fallback.
|
||||
func computeAspectRatio(w, h int, fallback string) string {
|
||||
if w <= 0 || h <= 0 {
|
||||
return fallback
|
||||
}
|
||||
g := gcd(w, h)
|
||||
a, b := w/g, h/g
|
||||
s := fmt.Sprintf("%d:%d", a, b)
|
||||
if isReplicateAspectRatio(s) {
|
||||
return s
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func isReplicateAspectRatio(s string) bool {
|
||||
switch s {
|
||||
case "1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func gcd(a, b int) int {
|
||||
for b != 0 {
|
||||
a, b = b, a%b
|
||||
}
|
||||
if a < 0 {
|
||||
return -a
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// hashPrompt returns the sha256 hex digest of the prompt. The raw prompt
|
||||
// is intentionally never written to the cost-tracking table.
|
||||
func hashPrompt(p string) string {
|
||||
sum := sha256.Sum256([]byte(p))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// ResolveCaller returns the agent identity the cost-tracking row is
|
||||
// attributed to. Order of resolution mirrors the maimcp identity logic:
|
||||
// MAI_FROM_ID env var first, then the tmux pane's @mai-name option, then
|
||||
// "unknown".
|
||||
func ResolveCaller() string {
|
||||
if v := strings.TrimSpace(os.Getenv("MAI_FROM_ID")); v != "" {
|
||||
return v
|
||||
}
|
||||
if pane := os.Getenv("TMUX_PANE"); pane != "" {
|
||||
out, err := exec.Command("tmux", "display-message", "-p", "-t", pane, "#{@mai-name}").Output()
|
||||
if err == nil {
|
||||
if name := strings.TrimSpace(string(out)); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// backoffFor429 honours a Retry-After header (in seconds) when present
|
||||
// and within reason, otherwise falls back to the caller's backoff.
|
||||
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
|
||||
if retryAfter == "" {
|
||||
return fallback
|
||||
}
|
||||
d, err := time.ParseDuration(retryAfter + "s")
|
||||
if err != nil || d <= 0 {
|
||||
return fallback
|
||||
}
|
||||
if d > 30*time.Second {
|
||||
return 30 * time.Second
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func sleepCtx(ctx context.Context, d time.Duration) bool {
|
||||
if d <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(d):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func bytesReader(b []byte) io.Reader {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
return bytes.NewReader(b)
|
||||
}
|
||||
|
||||
// shortPath strips the host so error messages don't leak the API base.
|
||||
func shortPath(u string) string {
|
||||
if i := strings.Index(u, "/v1/"); i >= 0 {
|
||||
return u[i:]
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register(ReplicateType, NewReplicate)
|
||||
}
|
||||
42
internal/backend/replicate_pricing.go
Normal file
42
internal/backend/replicate_pricing.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
import "strings"
|
||||
|
||||
// Replicate pricing snapshot.
|
||||
//
|
||||
// Source: https://replicate.com/pricing and the per-model "Run" tab on
|
||||
// each model page. Replicate bills per second of GPU time, but the
|
||||
// black-forest-labs FLUX models also publish a flat per-image price for
|
||||
// the typical settings — that flat number is what we hardcode here.
|
||||
//
|
||||
// Snapshot date: 2026-05-08. TODO(refresh): re-check quarterly. If the
|
||||
// rates drift more than ~10%, update the table and bump snapshotDate.
|
||||
const replicatePricingSnapshotDate = "2026-05-08"
|
||||
|
||||
// replicatePerImageUSD is the per-image cost estimate keyed by Replicate
|
||||
// model identifier ("owner/name", with any ":version" trimmed). Returns
|
||||
// the rate and true if the model is known, 0 and false otherwise — an
|
||||
// unknown model writes a row with NULL cost rather than a wrong number.
|
||||
func replicatePerImageUSD(model string) (float64, bool) {
|
||||
key := normalisePricingKey(model)
|
||||
switch key {
|
||||
case "black-forest-labs/flux-schnell":
|
||||
return 0.003, true
|
||||
case "black-forest-labs/flux-dev":
|
||||
return 0.025, true
|
||||
case "black-forest-labs/flux-pro":
|
||||
return 0.055, true
|
||||
case "black-forest-labs/flux-1.1-pro":
|
||||
return 0.040, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// normalisePricingKey strips the optional ":version" suffix and lowercases
|
||||
// the owner/name pair. "Owner/Name:hash" → "owner/name".
|
||||
func normalisePricingKey(model string) string {
|
||||
if i := strings.IndexByte(model, ':'); i >= 0 {
|
||||
model = model[:i]
|
||||
}
|
||||
return strings.ToLower(model)
|
||||
}
|
||||
675
internal/backend/replicate_test.go
Normal file
675
internal/backend/replicate_test.go
Normal file
@@ -0,0 +1,675 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeReplicate is a programmable mock of the Replicate REST API.
|
||||
type fakeReplicate struct {
|
||||
t *testing.T
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
// Responses for the predictions submission. If the request matches the
|
||||
// model-based path, modelEndpointHits increments; for /v1/predictions,
|
||||
// versionEndpointHits increments.
|
||||
createStatus int
|
||||
createBody []byte
|
||||
createCalls atomic.Int32
|
||||
|
||||
// Auth-fail policy: if true, every request to the prediction endpoints
|
||||
// returns 401 with a stock body.
|
||||
auth401 bool
|
||||
|
||||
// 429-then-OK policy: first N create calls return 429.
|
||||
create429Until int32
|
||||
retryAfter string
|
||||
|
||||
// Sequence of /predictions/{id} responses, walked in order; once
|
||||
// exhausted, the last entry is returned indefinitely.
|
||||
pollResponses []string
|
||||
pollIdx atomic.Int32
|
||||
pollCalls atomic.Int32
|
||||
|
||||
// Image download policy.
|
||||
imageStatus int
|
||||
imageBody []byte
|
||||
image5xxFirst int32
|
||||
imageCalls atomic.Int32
|
||||
|
||||
server *httptest.Server
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"):
|
||||
f.handleCreate(w, r)
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
|
||||
f.handleCreate(w, r)
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
|
||||
f.handlePoll(w, r)
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/img":
|
||||
f.handleImage(w, r)
|
||||
default:
|
||||
f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) {
|
||||
n := f.createCalls.Add(1)
|
||||
if f.auth401 {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"detail":"Invalid token"}`))
|
||||
return
|
||||
}
|
||||
if n <= f.create429Until {
|
||||
if f.retryAfter != "" {
|
||||
w.Header().Set("Retry-After", f.retryAfter)
|
||||
}
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"detail":"too many requests"}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(f.createStatus)
|
||||
_, _ = w.Write(f.createBody)
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) {
|
||||
f.pollCalls.Add(1)
|
||||
idx := int(f.pollIdx.Add(1)) - 1
|
||||
if idx >= len(f.pollResponses) {
|
||||
idx = len(f.pollResponses) - 1
|
||||
}
|
||||
if idx < 0 {
|
||||
http.Error(w, "no poll response configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(f.pollResponses[idx]))
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) {
|
||||
n := f.imageCalls.Add(1)
|
||||
if n <= f.image5xxFirst {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = w.Write([]byte("upstream unavailable"))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.WriteHeader(f.imageStatus)
|
||||
_, _ = w.Write(f.imageBody)
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) start() {
|
||||
f.server = httptest.NewServer(f.handler())
|
||||
f.t.Cleanup(f.server.Close)
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) imageURL() string { return f.server.URL + "/img" }
|
||||
|
||||
func newFakeReplicate(t *testing.T) *fakeReplicate {
|
||||
t.Helper()
|
||||
f := &fakeReplicate{
|
||||
t: t,
|
||||
createStatus: http.StatusCreated,
|
||||
imageStatus: http.StatusOK,
|
||||
imageBody: mustPNG(t, 8, 8),
|
||||
}
|
||||
f.start()
|
||||
// Default happy-path responses now that the server URL is known.
|
||||
f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`)
|
||||
f.pollResponses = []string{
|
||||
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
|
||||
`{"id":"pred-abc","status":"processing","version":"v1","output":null}`,
|
||||
fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()),
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate {
|
||||
t.Helper()
|
||||
be, err := NewReplicate("flux-test", map[string]any{
|
||||
"api_token_env": "TEST_REPLICATE_TOKEN",
|
||||
"model": model,
|
||||
"api_base": f.server.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewReplicate: %v", err)
|
||||
}
|
||||
r := be.(*Replicate)
|
||||
r.apiToken = "fake-token"
|
||||
r.pollInterval = time.Millisecond
|
||||
r.pollTimeout = 5 * time.Second
|
||||
r.initialBackoff = time.Millisecond
|
||||
r.randSeed = func() int64 { return 42 }
|
||||
return r
|
||||
}
|
||||
|
||||
func TestReplicateConstructorRejectsBadInputs(t *testing.T) {
|
||||
if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil {
|
||||
t.Errorf("expected error for empty instance name")
|
||||
}
|
||||
if _, err := NewReplicate("x", map[string]any{}); err == nil {
|
||||
t.Errorf("expected error for missing model")
|
||||
}
|
||||
if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil {
|
||||
t.Errorf("expected error for malformed model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.apiToken = "" // simulate the env var being unset
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when token is missing")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") {
|
||||
t.Errorf("error should name the env var: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
sink := &recordingSink{}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.Sink = sink
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{
|
||||
Prompt: "a tiny dragon",
|
||||
Width: 1024, Height: 1024,
|
||||
Seed: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
defer res.ImageReader.Close()
|
||||
body, _ := io.ReadAll(res.ImageReader)
|
||||
if !bytes.Equal(body, f.imageBody) {
|
||||
t.Errorf("image body did not round-trip")
|
||||
}
|
||||
|
||||
if mime := res.MimeType; mime != "image/png" {
|
||||
t.Errorf("mime = %q", mime)
|
||||
}
|
||||
if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" {
|
||||
t.Errorf("metadata model = %v", got)
|
||||
}
|
||||
if got, _ := res.Metadata["model_version"].(string); got != "v1" {
|
||||
t.Errorf("metadata model_version = %v", got)
|
||||
}
|
||||
if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 {
|
||||
t.Errorf("metadata predict_time_seconds = %v", got)
|
||||
}
|
||||
if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 {
|
||||
t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok)
|
||||
}
|
||||
if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" {
|
||||
t.Errorf("aspect_ratio = %q", got)
|
||||
}
|
||||
|
||||
if got := f.pollCalls.Load(); got < 3 {
|
||||
t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got)
|
||||
}
|
||||
|
||||
if len(sink.rows) != 1 {
|
||||
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
|
||||
}
|
||||
row := sink.rows[0]
|
||||
if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" {
|
||||
t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model)
|
||||
}
|
||||
if row.PromptHash == "" || row.PromptHash == "a tiny dragon" {
|
||||
t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash)
|
||||
}
|
||||
if len(row.PromptHash) != 64 {
|
||||
t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash))
|
||||
}
|
||||
if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 {
|
||||
t.Errorf("sink row cost = %v", row.CostUSDEstimate)
|
||||
}
|
||||
if row.LatencyMs <= 0 {
|
||||
t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
|
||||
var sentBody []byte
|
||||
var sentPath string
|
||||
mu := sync.Mutex{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
|
||||
mu.Lock()
|
||||
sentBody, _ = io.ReadAll(r.Body)
|
||||
sentPath = r.URL.Path
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`))
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
|
||||
_, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL())
|
||||
default:
|
||||
f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
be, err := NewReplicate("flux-test", map[string]any{
|
||||
"model": "black-forest-labs/flux-dev:abc123",
|
||||
"api_base": srv.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewReplicate: %v", err)
|
||||
}
|
||||
r := be.(*Replicate)
|
||||
r.apiToken = "fake"
|
||||
r.pollInterval = time.Millisecond
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "x"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if sentPath != "/v1/predictions" {
|
||||
t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath)
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(sentBody, &body); err != nil {
|
||||
t.Fatalf("unmarshal sent body: %v", err)
|
||||
}
|
||||
if body["version"] != "abc123" {
|
||||
t.Errorf("expected version=abc123 in body, got %v", body["version"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicate401SurfacesEnvHint(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.auth401 = true
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected 401 to surface as error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") {
|
||||
t.Errorf("error should name env var: %v", err)
|
||||
}
|
||||
if !strings.Contains(msg, "401") {
|
||||
t.Errorf("error should mention 401: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicate429RetriesThenSucceeds(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.create429Until = 2 // first two calls 429 then succeed
|
||||
f.retryAfter = "" // force the adapter's exp-backoff path
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.pollInterval = time.Millisecond
|
||||
|
||||
// Squash the backoff so the test is fast.
|
||||
r.httpClient = &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
res, err := r.Generate(ctx, Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate after 429s: %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
if got := f.createCalls.Load(); got != 3 {
|
||||
t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.create429Until = 99 // every call 429
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
_, err := r.Generate(ctx, Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error after sustained 429s")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "429") {
|
||||
t.Errorf("expected 429 in error: %v", err)
|
||||
}
|
||||
// max429Retries=3 → 1 initial + 3 retries = 4 total
|
||||
if got := f.createCalls.Load(); got != 4 {
|
||||
t.Errorf("expected 4 create calls (1+3 retries), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateFailedPredictionSurfacesError(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.pollResponses = []string{
|
||||
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
|
||||
`{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`,
|
||||
}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for failed prediction")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed") {
|
||||
t.Errorf("error should mention failure: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "NSFW") {
|
||||
t.Errorf("error should include the API error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
// Always return processing → adapter times out.
|
||||
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.pollInterval = 5 * time.Millisecond
|
||||
r.pollTimeout = 30 * time.Millisecond
|
||||
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "did not complete") {
|
||||
t.Errorf("expected 'did not complete' in error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "waited") {
|
||||
t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.image5xxFirst = 1 // first download 502, second OK
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate (download retry): %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
if got := f.imageCalls.Load(); got != 2 {
|
||||
t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.image5xxFirst = 99 // every download fails
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error after sustained image-download 5xx")
|
||||
}
|
||||
if got := f.imageCalls.Load(); got != 2 {
|
||||
t.Errorf("expected 2 image fetches (no further retries), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateContextCancelStopsPolling(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.pollInterval = 5 * time.Millisecond
|
||||
r.pollTimeout = 5 * time.Second
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
||||
defer cancel()
|
||||
_, err := r.Generate(ctx, Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected ctx error")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") {
|
||||
t.Errorf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateBackendOptsMergedIntoInput(t *testing.T) {
|
||||
var captured map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var top struct {
|
||||
Input map[string]any `json:"input"`
|
||||
}
|
||||
_ = json.Unmarshal(body, &top)
|
||||
captured = top.Input
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
be, err := NewReplicate("flux-test", map[string]any{
|
||||
"model": "black-forest-labs/flux-schnell",
|
||||
"api_base": srv.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewReplicate: %v", err)
|
||||
}
|
||||
r := be.(*Replicate)
|
||||
r.apiToken = "fake"
|
||||
r.pollInterval = time.Millisecond
|
||||
|
||||
// We expect this to fail at "pickFirstOutputURL" (empty output) but the
|
||||
// captured input should still have been recorded.
|
||||
_, _ = r.Generate(context.Background(), Request{
|
||||
Prompt: "p",
|
||||
Width: 1024, Height: 1024,
|
||||
BackendOpts: map[string]any{
|
||||
"output_quality": 90,
|
||||
"go_fast": true,
|
||||
},
|
||||
})
|
||||
if captured == nil {
|
||||
t.Fatal("create endpoint not hit")
|
||||
}
|
||||
if captured["output_quality"] != float64(90) {
|
||||
t.Errorf("output_quality not threaded: %v", captured["output_quality"])
|
||||
}
|
||||
if captured["go_fast"] != true {
|
||||
t.Errorf("go_fast not threaded: %v", captured["go_fast"])
|
||||
}
|
||||
if captured["prompt"] != "p" {
|
||||
t.Errorf("prompt not threaded: %v", captured["prompt"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateOutputArrayAccepted(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.pollResponses = []string{
|
||||
fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()),
|
||||
}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate (output as array): %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
}
|
||||
|
||||
func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
r := newReplicate(t, f, "stability-ai/sdxl")
|
||||
sink := &recordingSink{}
|
||||
r.Sink = sink
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate (unknown model): %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
if _, present := res.Metadata["cost_usd_estimate"]; present {
|
||||
t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"])
|
||||
}
|
||||
if len(sink.rows) != 1 {
|
||||
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
|
||||
}
|
||||
if sink.rows[0].CostUSDEstimate != nil {
|
||||
t.Errorf("expected nil cost in sink row for unknown model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateSinkFailureIsWarningNotError(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") })
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("sink failure should not fail Generate: %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
}
|
||||
|
||||
func TestReplicateDefaultStepsApplied(t *testing.T) {
|
||||
var captured map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var top struct {
|
||||
Input map[string]any `json:"input"`
|
||||
}
|
||||
_ = json.Unmarshal(body, &top)
|
||||
captured = top.Input
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
be, err := NewReplicate("flux-test", map[string]any{
|
||||
"model": "black-forest-labs/flux-dev",
|
||||
"api_base": srv.URL,
|
||||
"default_steps": 28,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewReplicate: %v", err)
|
||||
}
|
||||
r := be.(*Replicate)
|
||||
r.apiToken = "fake"
|
||||
r.pollInterval = time.Millisecond
|
||||
_, _ = r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if captured == nil {
|
||||
t.Fatal("create endpoint not hit")
|
||||
}
|
||||
if captured["num_inference_steps"] != float64(28) {
|
||||
t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeAspectRatio(t *testing.T) {
|
||||
cases := []struct {
|
||||
w, h int
|
||||
fallback string
|
||||
want string
|
||||
}{
|
||||
{1024, 1024, "1:1", "1:1"},
|
||||
{1920, 1080, "1:1", "16:9"},
|
||||
{2560, 1440, "1:1", "16:9"},
|
||||
{1024, 768, "1:1", "4:3"},
|
||||
{1024, 1280, "1:1", "4:5"},
|
||||
{1000, 1234, "1:1", "1:1"}, // weird ratio falls back
|
||||
{0, 1024, "1:1", "1:1"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := computeAspectRatio(c.w, c.h, c.fallback)
|
||||
if got != c.want {
|
||||
t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef(t *testing.T) {
|
||||
owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell")
|
||||
if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" {
|
||||
t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err)
|
||||
}
|
||||
owner, name, ver, err = parseModelRef("owner/name:hash123")
|
||||
if err != nil || owner != "owner" || name != "name" || ver != "hash123" {
|
||||
t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err)
|
||||
}
|
||||
if _, _, _, err := parseModelRef("noslash"); err == nil {
|
||||
t.Errorf("expected error for malformed ref")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPromptStable(t *testing.T) {
|
||||
a := hashPrompt("hello")
|
||||
b := hashPrompt("hello")
|
||||
c := hashPrompt("hello!")
|
||||
if a != b {
|
||||
t.Errorf("hashPrompt should be deterministic")
|
||||
}
|
||||
if a == c {
|
||||
t.Errorf("different prompts should hash differently")
|
||||
}
|
||||
if len(a) != 64 {
|
||||
t.Errorf("sha256 hex should be 64 chars, got %d", len(a))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicatePricingKnownModels(t *testing.T) {
|
||||
if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 {
|
||||
t.Errorf("schnell rate = %v (ok=%v)", v, ok)
|
||||
}
|
||||
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 {
|
||||
t.Errorf("dev rate = %v (ok=%v)", v, ok)
|
||||
}
|
||||
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 {
|
||||
t.Errorf("versioned ref should resolve to base price: %v %v", v, ok)
|
||||
}
|
||||
if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok {
|
||||
t.Errorf("unknown model should report ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateTypeIsRegistered(t *testing.T) {
|
||||
if !Default.Has(ReplicateType) {
|
||||
t.Errorf("replicate type not registered in Default")
|
||||
}
|
||||
}
|
||||
|
||||
// recordingSink captures rows for assertion.
|
||||
type recordingSink struct {
|
||||
mu sync.Mutex
|
||||
rows []UsageRow
|
||||
}
|
||||
|
||||
func (s *recordingSink) Record(_ context.Context, row UsageRow) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.rows = append(s.rows, row)
|
||||
return nil
|
||||
}
|
||||
|
||||
type sinkFunc func(context.Context, UsageRow) error
|
||||
|
||||
func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) }
|
||||
@@ -19,11 +19,16 @@ type Config struct {
|
||||
Backends map[string]BackendSpec `yaml:"backends"`
|
||||
}
|
||||
|
||||
// OutputConfig controls where generated images and metadata sidecars land.
|
||||
// OutputConfig controls where generated images and metadata sidecars land,
|
||||
// and whether `imagen generate` opens a tmux preview window.
|
||||
type OutputConfig struct {
|
||||
Directory string `yaml:"directory"`
|
||||
Naming string `yaml:"naming"`
|
||||
WriteMetadataJSON bool `yaml:"write_metadata_json"`
|
||||
// Preview is the tri-state preview mode: "auto" (default), "on", "off".
|
||||
// Empty / unset is treated as "auto". $IMAGEN_PREVIEW and the
|
||||
// --preview/--no-preview flags override this in turn.
|
||||
Preview string `yaml:"preview"`
|
||||
}
|
||||
|
||||
// BackendSpec is one entry under `backends:`. Type identifies the adapter;
|
||||
@@ -78,6 +83,11 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("default_backend %q is not defined under backends:", c.DefaultBackend)
|
||||
}
|
||||
}
|
||||
switch c.Output.Preview {
|
||||
case "", "auto", "on", "off":
|
||||
default:
|
||||
return fmt.Errorf("output.preview = %q (must be auto|on|off)", c.Output.Preview)
|
||||
}
|
||||
for name, spec := range c.Backends {
|
||||
if name == "" {
|
||||
return errors.New("empty backend name")
|
||||
@@ -95,28 +105,45 @@ const Sample = `# imagen.yaml — config for the imagen CLI.
|
||||
# implementing the Backend interface, registering its type name, and listing
|
||||
# an instance here.
|
||||
|
||||
default_backend: mock
|
||||
default_backend: flux-schnell-local
|
||||
|
||||
output:
|
||||
directory: ~/Pictures/imagen
|
||||
naming: "{date}-{slug}-{seed}.png"
|
||||
write_metadata_json: true
|
||||
# Open a tmux window with tmux-img after a successful generation.
|
||||
# auto (default): preview iff stdout is a TTY and $TMUX is set.
|
||||
# on: always preview (errors outside a tmux session).
|
||||
# off: never preview (use this for batch / CI callers).
|
||||
preview: auto
|
||||
|
||||
backends:
|
||||
mock:
|
||||
type: mock
|
||||
|
||||
flux-schnell-local:
|
||||
type: comfyui
|
||||
base_url: http://mrock:8188
|
||||
# Filename of the unet checkpoint inside the ComfyUI server's
|
||||
# models/unet/ directory. See docs/setup-comfyui-mrock.md.
|
||||
model: flux1-schnell.safetensors
|
||||
default_steps: 4
|
||||
default_sampler: euler
|
||||
default_scheduler: simple
|
||||
|
||||
mock:
|
||||
type: mock
|
||||
|
||||
flux-schnell-replicate:
|
||||
type: replicate
|
||||
api_token_env: REPLICATE_API_TOKEN
|
||||
model: black-forest-labs/flux-schnell
|
||||
default_steps: 4
|
||||
default_aspect_ratio: "1:1"
|
||||
|
||||
flux-dev-replicate:
|
||||
type: replicate
|
||||
api_token_env: REPLICATE_API_TOKEN
|
||||
model: black-forest-labs/flux-dev
|
||||
default_steps: 28
|
||||
default_aspect_ratio: "1:1"
|
||||
|
||||
dalle3:
|
||||
type: openai
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestLoadAndValidate(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if cfg.DefaultBackend != "mock" {
|
||||
if cfg.DefaultBackend != "flux-schnell-local" {
|
||||
t.Errorf("default = %q", cfg.DefaultBackend)
|
||||
}
|
||||
mock, ok := cfg.Backends["mock"]
|
||||
@@ -30,9 +30,15 @@ func TestLoadAndValidate(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("flux backend missing")
|
||||
}
|
||||
if flux.Type != "comfyui" {
|
||||
t.Errorf("flux type = %q", flux.Type)
|
||||
}
|
||||
if flux.Raw["base_url"] != "http://mrock:8188" {
|
||||
t.Errorf("flux base_url = %v", flux.Raw["base_url"])
|
||||
}
|
||||
if flux.Raw["model"] != "flux1-schnell.safetensors" {
|
||||
t.Errorf("flux model = %v", flux.Raw["model"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsUnknownDefault(t *testing.T) {
|
||||
@@ -54,6 +60,34 @@ func TestValidateRejectsMissingType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePreviewMode(t *testing.T) {
|
||||
for _, mode := range []string{"", "auto", "on", "off"} {
|
||||
c := &Config{Output: OutputConfig{Preview: mode}}
|
||||
if err := c.Validate(); err != nil {
|
||||
t.Errorf("preview=%q: unexpected error %v", mode, err)
|
||||
}
|
||||
}
|
||||
bad := &Config{Output: OutputConfig{Preview: "yes"}}
|
||||
if err := bad.Validate(); err == nil {
|
||||
t.Errorf("expected error for invalid preview value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleParsesPreviewAuto(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "imagen.yaml")
|
||||
if err := os.WriteFile(path, []byte(Sample), 0o644); err != nil {
|
||||
t.Fatalf("write sample: %v", err)
|
||||
}
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if cfg.Output.Preview != "auto" {
|
||||
t.Errorf("Output.Preview = %q, want auto", cfg.Output.Preview)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandPath(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
cases := map[string]string{
|
||||
|
||||
119
internal/preview/tmux.go
Normal file
119
internal/preview/tmux.go
Normal file
@@ -0,0 +1,119 @@
|
||||
// Package preview opens a tmux window showing a generated image via tmux-img.
|
||||
// Mode resolution and the actual spawn are kept separate so the CLI can
|
||||
// decide-then-act and tests can drive each half independently.
|
||||
package preview
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Mode is the tri-state preview setting: auto (default), on (force), off.
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
ModeAuto Mode = "auto"
|
||||
ModeOn Mode = "on"
|
||||
ModeOff Mode = "off"
|
||||
)
|
||||
|
||||
// ParseMode normalises a string into a Mode. Empty parses to ModeAuto so
|
||||
// callers can pass through unset config / env values.
|
||||
func ParseMode(s string) (Mode, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "", "auto":
|
||||
return ModeAuto, nil
|
||||
case "on":
|
||||
return ModeOn, nil
|
||||
case "off":
|
||||
return ModeOff, nil
|
||||
}
|
||||
return "", fmt.Errorf("invalid preview mode %q (auto|on|off)", s)
|
||||
}
|
||||
|
||||
// Decision is the answer to "should we preview, and why".
|
||||
type Decision struct {
|
||||
ShouldPreview bool
|
||||
Reason string
|
||||
}
|
||||
|
||||
// Resolve maps (mode, runtime context) to a Decision.
|
||||
//
|
||||
// - off -> never preview
|
||||
// - on -> preview, but error if not in tmux (forced on outside tmux)
|
||||
// - auto -> preview iff inTmux && stdoutTTY
|
||||
func Resolve(mode Mode, inTmux, stdoutTTY bool) (Decision, error) {
|
||||
switch mode {
|
||||
case ModeOff:
|
||||
return Decision{ShouldPreview: false, Reason: "preview=off"}, nil
|
||||
case ModeOn:
|
||||
if !inTmux {
|
||||
return Decision{}, ErrNoTmuxForced
|
||||
}
|
||||
return Decision{ShouldPreview: true, Reason: "preview=on"}, nil
|
||||
case ModeAuto, "":
|
||||
if !inTmux {
|
||||
return Decision{ShouldPreview: false, Reason: "auto: $TMUX unset"}, nil
|
||||
}
|
||||
if !stdoutTTY {
|
||||
return Decision{ShouldPreview: false, Reason: "auto: stdout not a tty"}, nil
|
||||
}
|
||||
return Decision{ShouldPreview: true, Reason: "auto"}, nil
|
||||
}
|
||||
return Decision{}, fmt.Errorf("invalid preview mode %q", mode)
|
||||
}
|
||||
|
||||
// Errors returned by Spawn and Resolve. Each names the missing piece and,
|
||||
// where relevant, where to install it.
|
||||
var (
|
||||
ErrTmuxMissing = errors.New("tmux: binary not found on $PATH (required for image preview)")
|
||||
ErrTmuxImgMissing = errors.New("tmux-img: binary not found on $PATH (install at ~/.local/bin/tmux-img)")
|
||||
ErrNoTmuxForced = errors.New("--preview requires $TMUX (are you in a tmux session?)")
|
||||
)
|
||||
|
||||
// Spawner spawns the tmux preview window. The exec.LookPath / cmd.Run hooks
|
||||
// exist so tests can inject fakes without touching $PATH.
|
||||
type Spawner struct {
|
||||
LookPath func(string) (string, error)
|
||||
Run func(*exec.Cmd) error
|
||||
}
|
||||
|
||||
// Spawn opens a new tmux window named img:<slug> running tmux-img --hold
|
||||
// <imagePath>. -d keeps focus in the current pane. Caller is expected to
|
||||
// have already verified that we are inside a tmux session.
|
||||
func (s *Spawner) Spawn(imagePath, slug string) error {
|
||||
look := s.LookPath
|
||||
if look == nil {
|
||||
look = exec.LookPath
|
||||
}
|
||||
run := s.Run
|
||||
if run == nil {
|
||||
run = func(c *exec.Cmd) error { return c.Run() }
|
||||
}
|
||||
|
||||
tmuxBin, err := look("tmux")
|
||||
if err != nil {
|
||||
return ErrTmuxMissing
|
||||
}
|
||||
tmuxImgBin, err := look("tmux-img")
|
||||
if err != nil {
|
||||
return ErrTmuxImgMissing
|
||||
}
|
||||
|
||||
name := "img:" + slug
|
||||
shellCmd := fmt.Sprintf("%s --hold %s",
|
||||
shellQuote(tmuxImgBin), shellQuote(imagePath))
|
||||
cmd := exec.Command(tmuxBin, "new-window", "-d", "-n", name, shellCmd)
|
||||
if err := run(cmd); err != nil {
|
||||
return fmt.Errorf("tmux new-window: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shellQuote single-quotes s for /bin/sh — tmux passes the trailing arg of
|
||||
// new-window through a shell.
|
||||
func shellQuote(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
|
||||
}
|
||||
170
internal/preview/tmux_test.go
Normal file
170
internal/preview/tmux_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package preview
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseMode(t *testing.T) {
|
||||
cases := map[string]Mode{
|
||||
"": ModeAuto,
|
||||
"auto": ModeAuto,
|
||||
"AUTO": ModeAuto,
|
||||
"on": ModeOn,
|
||||
" on ": ModeOn,
|
||||
"off": ModeOff,
|
||||
}
|
||||
for in, want := range cases {
|
||||
got, err := ParseMode(in)
|
||||
if err != nil {
|
||||
t.Errorf("ParseMode(%q) err = %v", in, err)
|
||||
continue
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("ParseMode(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
if _, err := ParseMode("nope"); err == nil {
|
||||
t.Errorf("ParseMode(nope) should have errored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve(t *testing.T) {
|
||||
type tc struct {
|
||||
mode Mode
|
||||
inTmux bool
|
||||
stdoutTTY bool
|
||||
want bool
|
||||
wantErr error
|
||||
}
|
||||
cases := map[string]tc{
|
||||
"off-anywhere": {ModeOff, false, false, false, nil},
|
||||
"off-in-tmux-tty": {ModeOff, true, true, false, nil},
|
||||
"on-in-tmux": {ModeOn, true, false, true, nil},
|
||||
"on-outside-tmux-errs": {ModeOn, false, true, false, ErrNoTmuxForced},
|
||||
"auto-no-tmux": {ModeAuto, false, true, false, nil},
|
||||
"auto-tmux-no-tty": {ModeAuto, true, false, false, nil},
|
||||
"auto-tmux-and-tty": {ModeAuto, true, true, true, nil},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d, err := Resolve(c.mode, c.inTmux, c.stdoutTTY)
|
||||
if c.wantErr != nil {
|
||||
if !errors.Is(err, c.wantErr) {
|
||||
t.Fatalf("err = %v, want %v", err, c.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v", err)
|
||||
}
|
||||
if d.ShouldPreview != c.want {
|
||||
t.Errorf("ShouldPreview = %v, want %v (reason: %s)", d.ShouldPreview, c.want, d.Reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawn_BuildsCorrectCommand(t *testing.T) {
|
||||
var captured *exec.Cmd
|
||||
s := &Spawner{
|
||||
LookPath: func(name string) (string, error) {
|
||||
switch name {
|
||||
case "tmux":
|
||||
return "/usr/bin/tmux", nil
|
||||
case "tmux-img":
|
||||
return "/home/m/.local/bin/tmux-img", nil
|
||||
}
|
||||
return "", exec.ErrNotFound
|
||||
},
|
||||
Run: func(c *exec.Cmd) error {
|
||||
captured = c
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := s.Spawn("/tmp/imagen/cat.png", "cat-in-a-fishbowl"); err != nil {
|
||||
t.Fatalf("Spawn: %v", err)
|
||||
}
|
||||
if captured == nil {
|
||||
t.Fatal("Run was not called")
|
||||
}
|
||||
if captured.Path != "/usr/bin/tmux" {
|
||||
t.Errorf("Path = %q, want /usr/bin/tmux", captured.Path)
|
||||
}
|
||||
args := captured.Args
|
||||
if len(args) < 6 {
|
||||
t.Fatalf("args = %v (need at least 6)", args)
|
||||
}
|
||||
// tmux new-window -d -n img:<slug> '<shell-cmd>'
|
||||
if args[1] != "new-window" {
|
||||
t.Errorf("args[1] = %q, want new-window", args[1])
|
||||
}
|
||||
if args[2] != "-d" {
|
||||
t.Errorf("args[2] = %q, want -d", args[2])
|
||||
}
|
||||
if args[3] != "-n" {
|
||||
t.Errorf("args[3] = %q, want -n", args[3])
|
||||
}
|
||||
if args[4] != "img:cat-in-a-fishbowl" {
|
||||
t.Errorf("args[4] = %q, want img:cat-in-a-fishbowl", args[4])
|
||||
}
|
||||
shellCmd := args[5]
|
||||
if !strings.Contains(shellCmd, "tmux-img") || !strings.Contains(shellCmd, "--hold") || !strings.Contains(shellCmd, "/tmp/imagen/cat.png") {
|
||||
t.Errorf("shell cmd %q missing expected pieces", shellCmd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawn_PathWithSpacesAndQuotes(t *testing.T) {
|
||||
var captured *exec.Cmd
|
||||
s := &Spawner{
|
||||
LookPath: func(name string) (string, error) {
|
||||
if name == "tmux" {
|
||||
return "/usr/bin/tmux", nil
|
||||
}
|
||||
if name == "tmux-img" {
|
||||
return "/usr/local/bin/tmux-img", nil
|
||||
}
|
||||
return "", exec.ErrNotFound
|
||||
},
|
||||
Run: func(c *exec.Cmd) error { captured = c; return nil },
|
||||
}
|
||||
weird := "/tmp/imagen/o'malley's cat.png"
|
||||
if err := s.Spawn(weird, "slug"); err != nil {
|
||||
t.Fatalf("Spawn: %v", err)
|
||||
}
|
||||
shellCmd := captured.Args[5]
|
||||
// Single-quoted with the embedded apostrophe escaped via the
|
||||
// '\'' shell idiom — confirm we did not just splice the raw path.
|
||||
if strings.Contains(shellCmd, "o'malley's") {
|
||||
t.Errorf("shell cmd %q contains unescaped apostrophes", shellCmd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawn_MissingTmux(t *testing.T) {
|
||||
s := &Spawner{
|
||||
LookPath: func(string) (string, error) { return "", exec.ErrNotFound },
|
||||
Run: func(*exec.Cmd) error { return nil },
|
||||
}
|
||||
err := s.Spawn("/x.png", "s")
|
||||
if !errors.Is(err, ErrTmuxMissing) {
|
||||
t.Errorf("err = %v, want ErrTmuxMissing", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawn_MissingTmuxImg(t *testing.T) {
|
||||
s := &Spawner{
|
||||
LookPath: func(name string) (string, error) {
|
||||
if name == "tmux" {
|
||||
return "/usr/bin/tmux", nil
|
||||
}
|
||||
return "", exec.ErrNotFound
|
||||
},
|
||||
Run: func(*exec.Cmd) error { return nil },
|
||||
}
|
||||
err := s.Spawn("/x.png", "s")
|
||||
if !errors.Is(err, ErrTmuxImgMissing) {
|
||||
t.Errorf("err = %v, want ErrTmuxImgMissing", err)
|
||||
}
|
||||
}
|
||||
160
internal/usage/usage.go
Normal file
160
internal/usage/usage.go
Normal file
@@ -0,0 +1,160 @@
|
||||
// Package usage records per-call cost-tracking rows for the imagen CLI
|
||||
// to mai.imagen_usage on Supabase. The writer is best-effort by design —
|
||||
// the calling adapter logs failures and proceeds, because the image
|
||||
// itself has already landed on disk by the time we record.
|
||||
package usage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
)
|
||||
|
||||
// Default REST schema is the mai schema where mai.imagen_usage lives.
|
||||
const supabaseSchema = "mai"
|
||||
|
||||
// SupabaseSink writes rows via PostgREST. It uses Accept-Profile/
|
||||
// Content-Profile headers to target the mai schema instead of public.
|
||||
type SupabaseSink struct {
|
||||
URL string // SUPABASE_URL — e.g. https://msup.msbls.de
|
||||
APIKey string // SUPABASE_SERVICE_KEY
|
||||
HTTP *http.Client
|
||||
}
|
||||
|
||||
// NewSupabaseSinkFromEnv reads SUPABASE_URL and SUPABASE_SERVICE_KEY
|
||||
// (falling back to MAI_SUPABASE_KEY) and returns a sink ready to use.
|
||||
// Returns nil + ok=false if the env vars are not configured — the CLI
|
||||
// uses that to skip cost-tracking gracefully.
|
||||
func NewSupabaseSinkFromEnv() (*SupabaseSink, bool) {
|
||||
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
|
||||
if u == "" {
|
||||
return nil, false
|
||||
}
|
||||
key := os.Getenv("SUPABASE_SERVICE_KEY")
|
||||
if key == "" {
|
||||
key = os.Getenv("MAI_SUPABASE_KEY")
|
||||
}
|
||||
if key == "" {
|
||||
return nil, false
|
||||
}
|
||||
return &SupabaseSink{
|
||||
URL: u,
|
||||
APIKey: key,
|
||||
HTTP: &http.Client{Timeout: 10 * time.Second},
|
||||
}, true
|
||||
}
|
||||
|
||||
type supabaseRow struct {
|
||||
Backend string `json:"backend"`
|
||||
Model string `json:"model"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
PromptHash string `json:"prompt_hash"`
|
||||
LatencyMs int `json:"latency_ms"`
|
||||
CostUSDEstimate *float64 `json:"cost_usd_estimate,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
}
|
||||
|
||||
// Record inserts one row into mai.imagen_usage.
|
||||
func (s *SupabaseSink) Record(ctx context.Context, row backend.UsageRow) error {
|
||||
body, err := json.Marshal(supabaseRow{
|
||||
Backend: row.Backend,
|
||||
Model: row.Model,
|
||||
Seed: row.Seed,
|
||||
PromptHash: row.PromptHash,
|
||||
LatencyMs: row.LatencyMs,
|
||||
CostUSDEstimate: row.CostUSDEstimate,
|
||||
Caller: row.Caller,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("usage: marshal: %w", err)
|
||||
}
|
||||
|
||||
endpoint := s.URL + "/rest/v1/imagen_usage"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("apikey", s.APIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+s.APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept-Profile", supabaseSchema)
|
||||
req.Header.Set("Content-Profile", supabaseSchema)
|
||||
req.Header.Set("Prefer", "return=minimal")
|
||||
|
||||
resp, err := s.HTTP.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("usage: POST: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("usage: POST %d: %s", resp.StatusCode, snip(respBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Row is the read-side row shape (only the fields the CLI needs).
|
||||
type Row struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Backend string `json:"backend"`
|
||||
Model string `json:"model"`
|
||||
Seed *int64 `json:"seed"`
|
||||
PromptHash string `json:"prompt_hash"`
|
||||
LatencyMs *int `json:"latency_ms"`
|
||||
CostUSDEstimate *float64 `json:"cost_usd_estimate"`
|
||||
Caller *string `json:"caller"`
|
||||
}
|
||||
|
||||
// Query returns rows from mai.imagen_usage filtered by created_at >= since.
|
||||
// Pass zero time to fetch the full table (capped server-side by PostgREST
|
||||
// — we set a hard 5000-row limit here too).
|
||||
func (s *SupabaseSink) Query(ctx context.Context, since time.Time) ([]Row, error) {
|
||||
q := url.Values{}
|
||||
q.Set("select", "created_at,backend,model,seed,prompt_hash,latency_ms,cost_usd_estimate,caller")
|
||||
q.Set("order", "created_at.desc")
|
||||
q.Set("limit", "5000")
|
||||
if !since.IsZero() {
|
||||
q.Set("created_at", "gte."+since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
endpoint := s.URL + "/rest/v1/imagen_usage?" + q.Encode()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("apikey", s.APIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+s.APIKey)
|
||||
req.Header.Set("Accept-Profile", supabaseSchema)
|
||||
|
||||
resp, err := s.HTTP.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("usage: GET: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("usage: GET %d: %s", resp.StatusCode, snip(body))
|
||||
}
|
||||
var rows []Row
|
||||
if err := json.Unmarshal(body, &rows); err != nil {
|
||||
return nil, fmt.Errorf("usage: parse rows: %w (body: %s)", err, snip(body))
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func snip(b []byte) string {
|
||||
const max = 500
|
||||
s := strings.TrimSpace(string(b))
|
||||
if len(s) > max {
|
||||
s = s[:max] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
24
scripts/comfyui.service
Normal file
24
scripts/comfyui.service
Normal file
@@ -0,0 +1,24 @@
|
||||
[Unit]
|
||||
Description=ComfyUI image generation server
|
||||
Documentation=https://github.com/comfyanonymous/ComfyUI
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=m
|
||||
Group=m
|
||||
WorkingDirectory=/home/m/dev/comfyui
|
||||
ExecStart=/home/m/dev/comfyui/.venv/bin/python /home/m/dev/comfyui/main.py \
|
||||
--listen 0.0.0.0 --port 8188 \
|
||||
--output-directory /home/m/dev/comfyui/output \
|
||||
--temp-directory /home/m/dev/comfyui/temp
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
TimeoutStopSec=30
|
||||
NoNewPrivileges=true
|
||||
PrivateTmp=true
|
||||
LimitNOFILE=65535
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
37
scripts/download-flux-schnell.sh
Executable file
37
scripts/download-flux-schnell.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
# Download FLUX.1 schnell + accompanying VAE/text encoders into a ComfyUI tree.
|
||||
# Uses ungated mirrors — the official Black-Forest-Labs repo is gated and
|
||||
# requires an HF token. See docs/setup-comfyui-mrock.md.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
ROOT="${1:-$HOME/dev/comfyui/models}"
|
||||
|
||||
if [ ! -d "$ROOT" ]; then
|
||||
echo "models root $ROOT does not exist — pass it as the first argument" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p "$ROOT/unet" "$ROOT/vae" "$ROOT/clip"
|
||||
|
||||
CKPT="https://huggingface.co/Comfy-Org/flux1-schnell/resolve/main/flux1-schnell.safetensors"
|
||||
VAE="https://huggingface.co/sirorable/flux-ae-vae/resolve/main/ae.safetensors"
|
||||
CLIP_L="https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors"
|
||||
T5="https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors"
|
||||
|
||||
dl() {
|
||||
local url=$1 dest=$2
|
||||
if [ -s "$dest" ]; then
|
||||
echo "skip $dest (already present)"
|
||||
return
|
||||
fi
|
||||
echo "downloading $url -> $dest"
|
||||
curl -L --fail --retry 3 --retry-delay 5 -C - -o "$dest" "$url"
|
||||
}
|
||||
|
||||
dl "$CKPT" "$ROOT/unet/flux1-schnell.safetensors"
|
||||
dl "$VAE" "$ROOT/vae/ae.safetensors"
|
||||
dl "$CLIP_L" "$ROOT/clip/clip_l.safetensors"
|
||||
dl "$T5" "$ROOT/clip/t5xxl_fp8_e4m3fn.safetensors"
|
||||
|
||||
echo "done"
|
||||
87
scripts/flux-schnell-poc.json
Normal file
87
scripts/flux-schnell-poc.json
Normal file
@@ -0,0 +1,87 @@
|
||||
{
|
||||
"prompt": {
|
||||
"6": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"text": "a small fishbowl with a cat staring out, photo, soft light",
|
||||
"clip": ["11", 0]
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": {
|
||||
"samples": ["31", 0],
|
||||
"vae": ["10", 0]
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"class_type": "SaveImage",
|
||||
"inputs": {
|
||||
"filename_prefix": "imagen-poc",
|
||||
"images": ["8", 0]
|
||||
}
|
||||
},
|
||||
"10": {
|
||||
"class_type": "VAELoader",
|
||||
"inputs": {
|
||||
"vae_name": "ae.safetensors"
|
||||
}
|
||||
},
|
||||
"11": {
|
||||
"class_type": "DualCLIPLoader",
|
||||
"inputs": {
|
||||
"clip_name1": "t5xxl_fp8_e4m3fn.safetensors",
|
||||
"clip_name2": "clip_l.safetensors",
|
||||
"type": "flux"
|
||||
}
|
||||
},
|
||||
"12": {
|
||||
"class_type": "UNETLoader",
|
||||
"inputs": {
|
||||
"unet_name": "flux1-schnell.safetensors",
|
||||
"weight_dtype": "fp8_e4m3fn"
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"text": "",
|
||||
"clip": ["11", 0]
|
||||
}
|
||||
},
|
||||
"27": {
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"inputs": {
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"batch_size": 1
|
||||
}
|
||||
},
|
||||
"30": {
|
||||
"class_type": "ModelSamplingFlux",
|
||||
"inputs": {
|
||||
"model": ["12", 0],
|
||||
"max_shift": 1.15,
|
||||
"base_shift": 0.5,
|
||||
"width": 1024,
|
||||
"height": 1024
|
||||
}
|
||||
},
|
||||
"31": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["30", 0],
|
||||
"seed": 1234567,
|
||||
"steps": 4,
|
||||
"cfg": 1.0,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"denoise": 1.0,
|
||||
"positive": ["6", 0],
|
||||
"negative": ["13", 0],
|
||||
"latent_image": ["27", 0]
|
||||
}
|
||||
}
|
||||
},
|
||||
"client_id": "imagen-poc-001"
|
||||
}
|
||||
Reference in New Issue
Block a user