11 Commits

Author SHA1 Message Date
mAi
2758c5a500 mAi: #8 - imagen.jobs queue + worker subcommand (flexsiebels write path)
Async write path for the flexsiebels owner-mode UI: flexsiebels INSERTs into
imagen.jobs, the worker on mRiver claims pending rows via LISTEN/NOTIFY +
5s safety poll, runs the same generate pipeline imagen generate uses, and
writes the result through internal/cloud into imagen.images.

- Schema migration imagen_jobs_init: table + status CHECK + two indexes +
  owner-scoped RLS + grants + AFTER INSERT trigger publishing on the
  imagen_jobs channel via pg_notify.
- internal/worker: DB-agnostic loop over a Queue interface. Drains the
  whole pending backlog on each wake. Job-scoped contexts are derived
  from Background so SIGTERM lets the in-flight generation finish (no
  half-state). ResetStaleRunning at startup unsticks rows left over from
  a previous crash. Eight unit tests cover the done / failed / missing-id /
  drain / NOTIFY-wake / shutdown / transient-error paths against a fake
  queue (no real Postgres in CI).
- cmd/imagen/worker.go: pgx-backed Queue (one dedicated conn for LISTEN +
  UPDATE), plus the workerPipeline that reuses buildBackend +
  attachUsageSink + prompt.Apply + buildWriter + maybeCloudSync. The
  per-job owner_user_id overrides the env-level fallback so each row in
  imagen.images is attributed correctly.
- maybeCloudSync now returns (*cloud.SyncResult, error) so the worker can
  link imagen.jobs.image_id to the inserted imagen.images row. The CLI
  generate path keeps printing its stderr summary unchanged.
- scripts/imagen-worker.service + .env.example for the systemd --user unit
  on mRiver. EnvironmentFile lives in ~/.dotfiles and is never committed.
- docs/setup-worker-mriver.md walks through installation + the spec's
  SQL-INSERT smoke; docs/architecture.md grows an "async write path"
  section.
- worker_integration_test.go (env-guarded by IMAGEN_WORKER_INTEGRATION=1)
  drives one real job through the full pipeline against msupabase using
  the mock backend, then verifies imagen.images + Storage object landed
  and the row flipped to done with image_id linked. Verified end-to-end:
  pickup latency ~7ms, total 74ms, failure path captures error text.
2026-05-11 10:23:33 +02:00
mAi
cb6656c436 Merge mai/hermes/issue-7-imagen-7-cloud: Supabase cloud-sync for flexsiebels viewer (#7) 2026-05-11 01:53:12 +02:00
mAi
e22f286024 mAi: #7 - cloud-sync to Supabase Storage + imagen.images
Every successful imagen generate now (a) uploads the PNG to the private
imagen-generated bucket and (b) inserts a row into imagen.images, the
data plane the flexsiebels owner-mode viewer reads from.

Schema, RLS, indexes, bucket and PostgREST exposure landed via four
applied migrations on msupabase: imagen_schema_init,
imagen_schema_grants, imagen_storage_policies, imagen_pgrst_expose
(authenticator role-level ALTER + reload). Owner UUID for m:
ac6c9501-3757-4a6d-8b97-2cff4288382b — documented in the config sample.

Code: new internal/cloud/ package mirroring the internal/usage/ shape.
PostgREST POST against the imagen schema (Accept-Profile + Content-
Profile headers), Storage upload via PUT with x-upsert, retry on 5xx /
transport but not 4xx, owner_user_id required (the column is NOT NULL
and the read-side RLS policy needs it).

Wiring in cmd/imagen/generate.go: --no-cloud flag, output.cloud_sync
config knob (auto|on|off mirroring --preview), $IMAGEN_CLOUD_SYNC env
override. The hook reads the just-written PNG + sidecar from disk and
calls cloud.Sync; failures emit "imagen: cloud sync: <err>" to stderr
without changing exit code, so a Supabase blip never loses the artefact.
output.Outputs grew Date/Slug/Seed fields so storage_path mirrors the
local filename's prefix exactly (no UTC-vs-local drift).

Config: owner_user_id field added; sample comment points at the
auth.users lookup. imagen config validate warns on stderr when
cloud_sync is on/auto but owner_user_id is empty.

Tests: cloud_test.go covers happy path, retry-on-5xx, no-retry-on-4xx,
missing-owner-uuid, missing-date-or-slug, signed URL, and the partial-
success case where the upload landed but the DB insert failed.
generate_test.go covers the precedence chain for cloud-sync mode
resolution. Build + tests clean across the tree.

Real smoke against mRock: generation through flux-schnell-local writes
the local PNG + sidecar AND uploads to imagen-generated/2026-05-11/...
AND inserts into imagen.images. Signed URL round-trips the same bytes.
--no-cloud verified to skip both Storage and DB.
2026-05-11 01:51:09 +02:00
mAi
2d5896e27d Merge mai/hermes/issue-3-imagen-3: Replicate API backend + cost-tracking + usage CLI (#3) 2026-05-08 17:32:09 +02:00
mAi
b282325663 mAi: #3 - Replicate adapter, mai.imagen_usage cost-tracking, usage CLI
Implements the Replicate API backend (FLUX schnell / FLUX dev) per ImaGen
issue #3:

- internal/backend/replicate.go — Backend adapter. Supports model
  refs as "owner/name" (uses /v1/models/{owner}/{name}/predictions) and
  "owner/name:hash" (uses /v1/predictions with explicit version). Polls
  /v1/predictions/{id} every 500ms with model-aware timeout (60s schnell,
  120s dev). Resilience: 401 names api_token_env, 429 with exp backoff
  up to 3 retries (honours Retry-After), 5xx retries once, image
  download retries once on transient failure.
- internal/backend/replicate_pricing.go — hardcoded per-image USD rates
  for known FLUX models, snapshotted from replicate.com/pricing with a
  refresh TODO.
- internal/backend/replicate_test.go — mocked-HTTP unit tests covering
  happy path (model + version-pinned), 401, 429 retry policy, failed
  prediction, poll timeout, image-download retry, ctx cancel, BackendOpts
  passthrough, default_steps, aspect-ratio reduction, sha256 prompt hash.
- internal/usage/usage.go — Supabase REST sink + read-side query for
  mai.imagen_usage. Adapter writes are best-effort: failures warn but
  the image still lands.
- cmd/imagen/usage.go — `imagen usage [--since DATE] [--raw]` reads
  the table and prints a tab-aligned grouped or raw table with totals.
- cmd/imagen/backends.go — instances of type=replicate now report
  "ok" or "not configured (set REPLICATE_API_TOKEN)" depending on env.
- internal/config/config.go — sample adds flux-schnell-replicate +
  flux-dev-replicate; default_backend stays flux-schnell-local.
- Supabase migration mai.imagen_usage (id, created_at, backend, model,
  seed, prompt_hash, latency_ms, cost_usd_estimate, caller) + indexes
  on (created_at DESC) and (caller). The raw prompt is never stored.

Caller identity resolves from MAI_FROM_ID, then the tmux pane's
@mai-name option, mirroring the maimcp identity logic. Prompt hash is
sha256 of the user-facing prompt; raw prompt never reaches the table.
2026-05-08 17:28:29 +02:00
mAi
a1d0165445 Merge mai/hermes/issue-5-imagen-5-tmux: tmux-window preview for generate (#5) 2026-05-08 17:12:57 +02:00
mAi
2a8bd4313b mAi: #5 - tmux-window preview for generate
Adds an optional `imagen generate` post-step that opens a sibling
tmux window running tmux-img --hold <path>.

- internal/preview: Mode (auto|on|off), Resolve, and a Spawner that
  shells out to tmux new-window. Typed errors for missing tmux,
  missing tmux-img, and "preview forced on outside $TMUX".
- cmd/imagen/generate: --preview / --no-preview flags plus
  $IMAGEN_PREVIEW. Resolution chain: config -> env -> flag.
  auto requires both stdout-is-tty and $TMUX. Failures are
  warnings - the image is already on disk.
- internal/config: output.preview field, validated to auto|on|off,
  threaded into the sample.
- Tests for ParseMode, Resolve, Spawn argv (incl. shell quoting of
  paths with apostrophes), missing-binary errors, and the CLI
  resolution table.
- Docs (usage + architecture) updated.

/imagine SKILL.md edit lives in dotfiles - deferred to coordinate
with #4.
2026-05-08 17:09:59 +02:00
mAi
4183d4c55a Merge mai/hermes/issue-2-imagen-2-comfyui: ComfyUI/FLUX schnell on mRock + Go adapter (#2) 2026-05-08 17:01:02 +02:00
mAi
127bbf3ed5 mAi: #2 - phase 2 ComfyUI Go adapter, tests, config sample
internal/backend/comfyui.go implements the Backend interface against
ComfyUI's /prompt + /history + /view HTTP API. Workflow is the canonical
FLUX.1 schnell shape — UNETLoader + DualCLIPLoader (clip_l + t5xxl fp8) +
VAELoader + ModelSamplingFlux + KSampler — assembled as a Go map per
request so Width / Height / Seed / Steps / sampler / scheduler all flow
into the right node inputs.

Resilience: one retry on /prompt 5xx and transient network errors, no
retry on 4xx. Connection-refused / timeouts surface a 'boot-whitetower
mrock' hint. node_errors mentioning a missing unet point users at
docs/setup-comfyui-mrock.md (matches both the 4xx and 200-with-errors
shapes ComfyUI uses across versions).

Result.Metadata carries model, seed_used, latency_ms, steps, sampler,
scheduler, width, height, prompt_id, client_id, plus best-effort
vram_used_mib pulled from /system_stats post-gen.

Tests use httptest with poll interval squashed to 1ms — no real mRock
dependency. Coverage: happy path, defaults, retry-once on 5xx, give-up
after two 5xx, no-retry on 4xx, missing-model hint (both 4xx and
200+node_errors paths), history-error surfaced, /view 4xx, unreachable
host, ctx cancel during poll, workflow-shape assertion, registration.

Config sample: flux-schnell-local is now default_backend; the user-facing
block names the unet file by basename (the mapping into models/unet/ is
the server's convention, captured in docs/setup-comfyui-mrock.md from
phase 1).

Smoke verified end-to-end: imagen generate ... --backend
flux-schnell-local --size 1024x1024 --output /tmp/cat-via-cli.png on
mRock returned a 1024x1024 PNG of a cat in a fishbowl in 10.3s with a
sidecar carrying seed + latency_ms + the rest of the metadata.
2026-05-08 16:59:21 +02:00
mAi
a24ac2826f mAi: #2 - phase 1 PoC: ComfyUI on mRock + first FLUX schnell image
Native systemd install (matches Ollama pattern on Arch — Docker on mRock
has no nvidia runtime; native venv via uv is the lighter path). The
Black-Forest-Labs FLUX.1-schnell HF repo is gated, so the download script
points at ungated mirrors (Comfy-Org/flux1-schnell + sirorable/flux-ae-vae)
that ship the same Apache-2.0 weights.

First image — cat in a fishbowl, 1024x1024, 4 steps — generated end-to-end
in 9.79s via curl + workflow JSON; stored at
/home/m/dev/ImaGen/poc/first-image.png on mRiver (not committed; transient
PoC artefact). Go adapter is phase 2.
2026-05-08 16:50:16 +02:00
mAi
20490913c1 Merge mai/bohr/issue-211-bootstrap: framework skeleton (#211) 2026-05-08 14:37:24 +02:00
36 changed files with 5783 additions and 34 deletions

1
.gitignore vendored
View File

@@ -7,3 +7,4 @@
.env.local
/imagen
/coverage.txt
/.m/

View File

@@ -10,13 +10,17 @@ and lifecycle of its own block in `~/.config/imagen.yaml`.
## Architecture
```
cmd/imagen/ CLI shell — generate, backends, config, serve
cmd/imagen/ CLI shell — generate, worker, backends, config, serve
internal/backend/ Backend interface + Registry + Mock reference impl
internal/prompt/ Style preset registry (embedded styles.yaml)
internal/output/ Filename templating, image writer, JSON sidecar
internal/config/ YAML loader, validation, sample generator
internal/cloud/ Supabase Storage + imagen.images writer
internal/usage/ mai.imagen_usage cost-tracking sink
internal/worker/ imagen.jobs queue consumer (DB-agnostic via Queue interface)
internal/server/ HTTP stub (not implemented yet — follow-up issue)
docs/ architecture.md, usage.md
scripts/ imagen-worker.service + env template, ComfyUI scripts
docs/ architecture.md, usage.md, setup-worker-mriver.md
```
Data flow for `imagen generate`:

View File

@@ -10,6 +10,27 @@ import (
"mgit.msbls.de/m/ImaGen/internal/config"
)
// instanceStatus checks adapter-specific preconditions (e.g. the
// Replicate API token env var being set) and returns a short
// user-facing status string.
func instanceStatus(spec config.BackendSpec) string {
if !backend.Default.Has(spec.Type) {
return fmt.Sprintf("type %q not compiled in", spec.Type)
}
switch spec.Type {
case backend.ReplicateType:
envName, _ := spec.Raw["api_token_env"].(string)
if envName == "" {
envName = "REPLICATE_API_TOKEN"
}
if os.Getenv(envName) == "" {
return fmt.Sprintf("not configured (set %s)", envName)
}
return "ok"
}
return "registered"
}
func runBackends(args []string) error {
fs := flag.NewFlagSet("backends", flag.ContinueOnError)
var configPath string
@@ -27,10 +48,7 @@ func runBackends(args []string) error {
fmt.Fprintln(tw, "INSTANCE\tTYPE\tSTATUS")
if cfg != nil {
for name, spec := range cfg.Backends {
status := "registered"
if !backend.Default.Has(spec.Type) {
status = fmt.Sprintf("type %q not compiled in", spec.Type)
}
status := instanceStatus(spec)
marker := ""
if name == cfg.DefaultBackend {
marker = " (default)"

View File

@@ -39,6 +39,18 @@ func runConfig(args []string) error {
}
fmt.Fprintf(os.Stdout, "OK — %d backend(s) defined, default=%q\n",
len(cfg.Backends), cfg.DefaultBackend)
// Soft warnings — surfaced on stderr so they're visible but don't
// fail the validate exit code.
cloudMode := cfg.Output.CloudSync
if cloudMode == "" {
cloudMode = "auto"
}
if cloudMode != "off" && cfg.OwnerUserID == "" {
fmt.Fprintln(os.Stderr,
"warning: cloud_sync is "+cloudMode+" but owner_user_id is empty — DB inserts will be skipped.")
fmt.Fprintln(os.Stderr,
" look it up: SELECT id FROM auth.users WHERE email = '<your-supabase-email>';")
}
return nil
default:
return userErr("unknown config subcommand %q (init|validate|path)", args[0])

View File

@@ -2,30 +2,39 @@ package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"mgit.msbls.de/m/ImaGen/internal/backend"
"mgit.msbls.de/m/ImaGen/internal/cloud"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/preview"
"mgit.msbls.de/m/ImaGen/internal/prompt"
"mgit.msbls.de/m/ImaGen/internal/usage"
)
func runGenerate(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("generate", flag.ContinueOnError)
var (
backendName string
size string
outPath string
seed int64
steps int
style string
negative string
configPath string
noSidecar bool
backendName string
size string
outPath string
seed int64
steps int
style string
negative string
configPath string
noSidecar bool
previewOn bool
previewOff bool
noCloud bool
)
fs.StringVar(&backendName, "backend", "", "backend instance name (default: config.default_backend)")
fs.StringVar(&size, "size", "1024x1024", "WxH, e.g. 1024x1024")
@@ -36,6 +45,9 @@ func runGenerate(ctx context.Context, args []string) error {
fs.StringVar(&negative, "negative", "", "negative prompt (ignored by backends that don't support it)")
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
fs.BoolVar(&noSidecar, "no-sidecar", false, "skip the JSON sidecar even if config enables it")
fs.BoolVar(&previewOn, "preview", false, "force tmux preview window on (errors outside $TMUX)")
fs.BoolVar(&previewOff, "no-preview", false, "skip the tmux preview window")
fs.BoolVar(&noCloud, "no-cloud", false, "skip Supabase upload + imagen.images insert for this generation")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), `Usage: imagen generate "<prompt>" [flags]`)
fs.PrintDefaults()
@@ -76,6 +88,7 @@ func runGenerate(ctx context.Context, args []string) error {
if err != nil {
return err
}
attachUsageSink(be)
finalPrompt, err := prompt.Apply(rawPrompt, style)
if err != nil {
@@ -118,9 +131,245 @@ func runGenerate(ctx context.Context, args []string) error {
if paths.SidecarPath != "" {
fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath)
}
if result, err := maybeCloudSync(ctx, cfg, noCloud, "", paths, in, res, w, h); err != nil {
// cloud-sync failures are warnings — the image already wrote.
fmt.Fprintln(os.Stderr, "imagen: cloud sync:", err)
} else if result != nil && result.ImageID != "" {
fmt.Fprintf(os.Stderr, "cloud: imagen.images.id=%s storage_path=%s\n", result.ImageID, result.StoragePath)
}
if err := maybePreview(cfg, previewOn, previewOff, paths.ImagePath, rawPrompt); err != nil {
// preview failures are warnings — the image already wrote.
fmt.Fprintln(os.Stderr, "imagen: preview:", err)
}
return nil
}
// resolveCloudSyncMode applies the precedence chain config -> env -> flag.
// Flags win, env beats config, config beats the implicit auto default.
// Mirrors resolvePreviewMode shape.
func resolveCloudSyncMode(cfg *config.Config, noCloudFlag bool, env string) (string, error) {
mode := "auto"
if cfg != nil && cfg.Output.CloudSync != "" {
mode = cfg.Output.CloudSync
}
if env != "" {
switch env {
case "auto", "on", "off":
mode = env
default:
return "", fmt.Errorf("$IMAGEN_CLOUD_SYNC = %q (must be auto|on|off)", env)
}
}
if noCloudFlag {
mode = "off"
}
return mode, nil
}
// maybeCloudSync resolves the effective mode and, if it says yes, uploads
// the PNG and inserts the row. Returns the SyncResult on success so callers
// that need the imagen.images.id (e.g. the worker linking a job row) can pick
// it up. ownerOverride, when non-empty, wins over config + env — the worker
// passes the job row's owner_user_id so each job is attributed correctly.
func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, ownerOverride string, paths *output.Outputs, in output.Inputs, res *backend.Result, width, height int) (*cloud.SyncResult, error) {
mode, err := resolveCloudSyncMode(cfg, noCloud, os.Getenv("IMAGEN_CLOUD_SYNC"))
if err != nil {
return nil, err
}
if mode == "off" {
return nil, nil
}
sink, ok := cloud.NewFromEnv()
if !ok {
if mode == "on" {
return nil, fmt.Errorf("cloud_sync=on but SUPABASE_URL / SUPABASE_SERVICE_KEY not set in env")
}
// auto + missing env = silent skip.
return nil, nil
}
switch {
case ownerOverride != "":
sink.OwnerUserID = ownerOverride
case cfg != nil && cfg.OwnerUserID != "":
// Config-supplied owner_user_id takes precedence over $IMAGEN_OWNER_USER_ID.
sink.OwnerUserID = cfg.OwnerUserID
}
if sink.OwnerUserID == "" {
if mode == "on" {
return nil, fmt.Errorf("cloud_sync=on but owner_user_id not set in config and $IMAGEN_OWNER_USER_ID is empty")
}
// auto + missing UUID = silent skip.
return nil, nil
}
pngBytes, readErr := os.ReadFile(paths.ImagePath)
if readErr != nil {
return nil, fmt.Errorf("read local image: %w", readErr)
}
// Reuse the writer's date/slug/seed so storage_path mirrors the local
// filename's prefix exactly — viewers can join `imagen.images` on
// either side without timezone drift.
date := paths.Date
slug := paths.Slug
if date == "" || slug == "" {
now := time.Now()
date = now.Format("2006-01-02")
slug = output.Slug(in.Prompt)
}
ext := in.Ext
if ext == "" {
ext = strings.TrimPrefix(filepath.Ext(paths.ImagePath), ".")
}
if ext == "" {
ext = "png"
}
// Snapshot the sidecar (if it exists) so the row carries the same
// metadata view a downstream viewer would see on disk.
var sidecar map[string]any
if paths.SidecarPath != "" {
if scBytes, err := os.ReadFile(paths.SidecarPath); err == nil {
_ = json.Unmarshal(scBytes, &sidecar)
}
}
model := metaString(res.Metadata, "model")
steps := metaInt(res.Metadata, "steps")
cost := metaFloatPtr(res.Metadata, "cost_usd_estimate")
latency := metaInt(res.Metadata, "latency_ms")
seed := paths.Seed
if seed == 0 {
seed = in.Seed
}
syncReq := cloud.SyncRequest{
Date: date,
Slug: slug,
Seed: seed,
Ext: ext,
PNG: pngBytes,
MimeType: res.MimeType,
Prompt: in.Prompt,
Backend: in.Backend,
Model: model,
Steps: steps,
Width: width,
Height: height,
LatencyMs: latency,
CostUSDEstimate: cost,
Sidecar: sidecar,
}
syncCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
return sink.Sync(syncCtx, syncReq)
}
func metaString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func metaInt(m map[string]any, key string) int {
v, ok := m[key]
if !ok {
return 0
}
switch n := v.(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
}
return 0
}
func metaFloatPtr(m map[string]any, key string) *float64 {
v, ok := m[key]
if !ok {
return nil
}
switch n := v.(type) {
case float64:
return &n
case float32:
f := float64(n)
return &f
case int:
f := float64(n)
return &f
case int64:
f := float64(n)
return &f
}
return nil
}
// resolvePreviewMode applies the precedence chain config -> env -> flag.
// Flags win, env beats config, config beats the implicit auto default.
func resolvePreviewMode(cfg *config.Config, flagOn, flagOff bool, env string) (preview.Mode, error) {
mode := preview.ModeAuto
if cfg != nil && cfg.Output.Preview != "" {
m, err := preview.ParseMode(cfg.Output.Preview)
if err != nil {
return "", fmt.Errorf("config output.preview: %w", err)
}
mode = m
}
if env != "" {
m, err := preview.ParseMode(env)
if err != nil {
return "", fmt.Errorf("$IMAGEN_PREVIEW: %w", err)
}
mode = m
}
if flagOn && flagOff {
return "", userErr("--preview and --no-preview are mutually exclusive")
}
if flagOn {
mode = preview.ModeOn
}
if flagOff {
mode = preview.ModeOff
}
return mode, nil
}
// maybePreview resolves the effective preview mode and, if it says yes,
// spawns a tmux window via tmux-img. Always non-fatal.
func maybePreview(cfg *config.Config, flagOn, flagOff bool, imagePath, rawPrompt string) error {
mode, err := resolvePreviewMode(cfg, flagOn, flagOff, os.Getenv("IMAGEN_PREVIEW"))
if err != nil {
return err
}
decision, err := preview.Resolve(mode, os.Getenv("TMUX") != "", stdoutIsTTY())
if err != nil {
return err
}
if !decision.ShouldPreview {
return nil
}
spawner := &preview.Spawner{}
return spawner.Spawn(imagePath, output.Slug(rawPrompt))
}
func stdoutIsTTY() bool {
fi, err := os.Stdout.Stat()
if err != nil {
return false
}
return fi.Mode()&os.ModeCharDevice != 0
}
// splitLeadingPositional separates the positional args at the start of args
// from the rest (which begins with the first flag). A literal "--" terminator
// pushes everything after it into the positional list and out of flag parsing.
@@ -153,6 +402,21 @@ func parseSize(s string) (int, int, error) {
return w, h, nil
}
// attachUsageSink wires a Supabase cost-tracking sink into the backend
// when it accepts one and the env is configured. Adapters that record
// usage expose a public Sink field of type backend.UsageSink.
func attachUsageSink(be backend.Backend) {
r, ok := be.(*backend.Replicate)
if !ok {
return
}
sink, ok := usage.NewSupabaseSinkFromEnv()
if !ok {
return
}
r.Sink = sink
}
func buildBackend(cfg *config.Config, name string) (backend.Backend, error) {
if cfg != nil {
spec, ok := cfg.Backends[name]

View File

@@ -0,0 +1,87 @@
package main
import (
"testing"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/preview"
)
func TestResolvePreviewMode(t *testing.T) {
type tc struct {
name string
cfg *config.Config
flagOn bool
flagOff bool
env string
want preview.Mode
wantError bool
}
cases := []tc{
{name: "all-empty-defaults-to-auto", want: preview.ModeAuto},
{name: "config-on", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, want: preview.ModeOn},
{name: "config-off", cfg: &config.Config{Output: config.OutputConfig{Preview: "off"}}, want: preview.ModeOff},
{name: "config-auto-explicit", cfg: &config.Config{Output: config.OutputConfig{Preview: "auto"}}, want: preview.ModeAuto},
{name: "env-overrides-config", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, env: "off", want: preview.ModeOff},
{name: "flag-on-overrides-env-off", env: "off", flagOn: true, want: preview.ModeOn},
{name: "flag-off-overrides-env-on", env: "on", flagOff: true, want: preview.ModeOff},
{name: "flag-off-overrides-config-on", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, flagOff: true, want: preview.ModeOff},
{name: "both-flags-error", flagOn: true, flagOff: true, wantError: true},
{name: "bad-env-errors", env: "yes", wantError: true},
{name: "bad-config-errors", cfg: &config.Config{Output: config.OutputConfig{Preview: "yes"}}, wantError: true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := resolvePreviewMode(c.cfg, c.flagOn, c.flagOff, c.env)
if c.wantError {
if err == nil {
t.Fatalf("expected error, got mode %q", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != c.want {
t.Errorf("mode = %q, want %q", got, c.want)
}
})
}
}
func TestResolveCloudSyncMode(t *testing.T) {
type tc struct {
name string
cfg *config.Config
noCloud bool
env string
want string
wantError bool
}
cases := []tc{
{name: "all-empty-defaults-to-auto", want: "auto"},
{name: "config-on", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, want: "on"},
{name: "config-off", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "off"}}, want: "off"},
{name: "env-overrides-config", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, env: "off", want: "off"},
{name: "flag-overrides-env-and-config", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, env: "on", noCloud: true, want: "off"},
{name: "flag-overrides-config-on", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, noCloud: true, want: "off"},
{name: "bad-env-errors", env: "yes", wantError: true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := resolveCloudSyncMode(c.cfg, c.noCloud, c.env)
if c.wantError {
if err == nil {
t.Fatalf("expected error, got mode %q", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != c.want {
t.Errorf("mode = %q, want %q", got, c.want)
}
})
}
}

View File

@@ -14,14 +14,16 @@ import (
_ "mgit.msbls.de/m/ImaGen/internal/backend"
)
const usage = `imagen — model-agnostic image generation
const helpText = `imagen — model-agnostic image generation
Usage:
imagen generate <prompt> [flags] generate one image
imagen worker [flags] consume the imagen.jobs queue (daemon)
imagen backends list registered backend types
imagen config init print a sample imagen.yaml on stdout
imagen config validate validate the active config
imagen serve [--addr :8080] (stub) start the HTTP server
imagen usage [--since DATE] show cost-tracking rows
imagen version print version
imagen help show this help
@@ -33,7 +35,7 @@ var Version = "dev"
func main() {
if len(os.Args) < 2 {
fmt.Fprint(os.Stderr, usage)
fmt.Fprint(os.Stderr, helpText)
os.Exit(2)
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
@@ -44,18 +46,22 @@ func main() {
switch os.Args[1] {
case "generate":
err = runGenerate(ctx, args)
case "worker":
err = runWorker(ctx, args)
case "backends":
err = runBackends(args)
case "config":
err = runConfig(args)
case "serve":
err = runServe(args)
case "usage":
err = runUsage(ctx, args)
case "version", "-v", "--version":
fmt.Println(Version)
case "help", "-h", "--help":
fmt.Print(usage)
fmt.Print(helpText)
default:
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], usage)
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], helpText)
os.Exit(2)
}
if err != nil {

189
cmd/imagen/usage.go Normal file
View File

@@ -0,0 +1,189 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"sort"
"strings"
"text/tabwriter"
"time"
"mgit.msbls.de/m/ImaGen/internal/usage"
)
// runUsage handles `imagen usage [--since DATE]`. Reads mai.imagen_usage
// via Supabase REST and prints a tab-aligned table grouped by week +
// backend + model + caller, with totals at the bottom.
func runUsage(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("usage", flag.ContinueOnError)
var (
since string
raw bool
)
fs.StringVar(&since, "since", "", "ISO date (YYYY-MM-DD) — only rows on/after this UTC date")
fs.BoolVar(&raw, "raw", false, "print one line per row instead of grouped")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), "Usage: imagen usage [--since YYYY-MM-DD] [--raw]")
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
return err
}
var sinceT time.Time
if since != "" {
t, err := time.Parse("2006-01-02", since)
if err != nil {
return userErr("--since must be YYYY-MM-DD: %v", err)
}
sinceT = t
}
sink, ok := usage.NewSupabaseSinkFromEnv()
if !ok {
return userErr("SUPABASE_URL and SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) must be set to read mai.imagen_usage")
}
rows, err := sink.Query(ctx, sinceT)
if err != nil {
return err
}
if raw {
printRawRows(rows)
return nil
}
printGroupedRows(rows)
return nil
}
func printRawRows(rows []usage.Row) {
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "TIME\tBACKEND\tMODEL\tCALLER\tLATENCY_MS\tCOST_USD")
var totalCost float64
for _, r := range rows {
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n",
r.CreatedAt.Local().Format("2006-01-02 15:04"),
r.Backend,
r.Model,
derefString(r.Caller),
intOrDash(r.LatencyMs),
costOrDash(r.CostUSDEstimate),
)
if r.CostUSDEstimate != nil {
totalCost += *r.CostUSDEstimate
}
}
fmt.Fprintf(tw, "\t\t\t\t%d rows\t%.4f USD\n", len(rows), totalCost)
_ = tw.Flush()
}
type group struct {
week string
backend string
model string
caller string
count int
cost float64
costSet bool
}
type groupKey struct {
week, backend, model, caller string
}
func printGroupedRows(rows []usage.Row) {
groups := map[groupKey]*group{}
for _, r := range rows {
caller := derefString(r.Caller)
k := groupKey{
week: weekStart(r.CreatedAt).Format("2006-01-02"),
backend: r.Backend,
model: r.Model,
caller: caller,
}
g, ok := groups[k]
if !ok {
g = &group{week: k.week, backend: r.Backend, model: r.Model, caller: caller}
groups[k] = g
}
g.count++
if r.CostUSDEstimate != nil {
g.cost += *r.CostUSDEstimate
g.costSet = true
}
}
keys := make([]groupKey, 0, len(groups))
for k := range groups {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
if keys[i].week != keys[j].week {
return keys[i].week > keys[j].week // newest first
}
if keys[i].backend != keys[j].backend {
return keys[i].backend < keys[j].backend
}
if keys[i].model != keys[j].model {
return keys[i].model < keys[j].model
}
return keys[i].caller < keys[j].caller
})
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "WEEK_OF\tBACKEND\tMODEL\tCALLER\tCOUNT\tCOST_USD")
var totalCount int
var totalCost float64
for _, k := range keys {
g := groups[k]
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%s\n",
g.week, g.backend, g.model, g.caller, g.count, costStr(g.cost, g.costSet),
)
totalCount += g.count
totalCost += g.cost
}
fmt.Fprintf(tw, "\t\t\tTOTAL\t%d\t%.4f USD\n", totalCount, totalCost)
_ = tw.Flush()
}
// weekStart returns the Monday of the week containing t (UTC).
func weekStart(t time.Time) time.Time {
t = t.UTC()
wd := int(t.Weekday())
if wd == 0 {
wd = 7 // shift Sunday to end-of-week
}
delta := time.Duration(wd-1) * -24 * time.Hour
d := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
return d.Add(delta)
}
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}
func intOrDash(p *int) string {
if p == nil {
return "-"
}
return fmt.Sprintf("%d", *p)
}
func costOrDash(p *float64) string {
if p == nil {
return "-"
}
return fmt.Sprintf("%.4f", *p)
}
func costStr(v float64, set bool) string {
if !set {
return "-"
}
return strings.TrimSpace(fmt.Sprintf("%.4f", v))
}

287
cmd/imagen/worker.go Normal file
View File

@@ -0,0 +1,287 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/jackc/pgx/v5"
"mgit.msbls.de/m/ImaGen/internal/backend"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/prompt"
"mgit.msbls.de/m/ImaGen/internal/worker"
)
// runWorker is the `imagen worker` subcommand: a long-running daemon that
// consumes the imagen.jobs queue and writes results into imagen.images via
// the same cloud-sync path generate uses.
func runWorker(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("worker", flag.ContinueOnError)
var (
configPath string
pollInterval time.Duration
jobTimeout time.Duration
)
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
fs.DurationVar(&pollInterval, "poll-interval", 5*time.Second, "safety-poll cadence between LISTEN wakeups")
fs.DurationVar(&jobTimeout, "job-timeout", 5*time.Minute, "max wall-time per job before the worker marks it failed")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), `Usage: imagen worker [flags]
Long-running daemon. LISTENs on the Postgres 'imagen_jobs' channel and polls
imagen.jobs every --poll-interval as a safety net, claims pending rows, runs
the generation pipeline, then updates the row with status + image_id.
Env:
IMAGEN_WORKER_DATABASE_URL Postgres DSN for direct LISTEN + UPDATE.
Required (PostgREST cannot LISTEN).
SUPABASE_URL, SUPABASE_SERVICE_KEY, IMAGEN_OWNER_USER_ID
Reused from generate's cloud-sync path; the
worker writes imagen.images rows through the
same code path. Per-job owner_user_id from the
job row overrides IMAGEN_OWNER_USER_ID.`)
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
return err
}
cfg, cfgErr := config.Load(configPath)
if cfgErr != nil && !os.IsNotExist(cfgErr) {
return cfgErr
}
dsn := os.Getenv("IMAGEN_WORKER_DATABASE_URL")
if dsn == "" {
return userErr("IMAGEN_WORKER_DATABASE_URL not set; the worker needs a direct Postgres DSN for LISTEN/NOTIFY")
}
q, err := dialQueue(ctx, dsn)
if err != nil {
return fmt.Errorf("queue: %w", err)
}
defer q.Close()
p := &workerPipeline{cfg: cfg}
w := worker.New(q, p, worker.Config{
PollInterval: pollInterval,
JobTimeout: jobTimeout,
Logger: func(format string, a ...any) { fmt.Fprintf(os.Stderr, format+"\n", a...) },
})
fmt.Fprintln(os.Stderr, "imagen worker: ready (poll-interval", pollInterval, "job-timeout", jobTimeout, ")")
return w.Run(ctx)
}
// pgxQueue is the production Queue. It opens one dedicated connection used
// for both LISTEN (long-lived) and UPDATE operations. A second connection
// would split state needlessly — a single worker process processes one job
// at a time so the connection is never contended.
type pgxQueue struct {
conn *pgx.Conn
}
func dialQueue(ctx context.Context, dsn string) (*pgxQueue, error) {
conn, err := pgx.Connect(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("pgx.Connect: %w", err)
}
if _, err := conn.Exec(ctx, "LISTEN imagen_jobs"); err != nil {
conn.Close(ctx)
return nil, fmt.Errorf("LISTEN imagen_jobs: %w", err)
}
return &pgxQueue{conn: conn}, nil
}
func (q *pgxQueue) Close() {
if q == nil || q.conn == nil {
return
}
// Best-effort: a 5s budget is enough to send a polite TerminateMessage.
shutdown, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = q.conn.Close(shutdown)
}
// ClaimNextPending atomically marks the oldest pending row 'running' and
// returns it. FOR UPDATE SKIP LOCKED is belt + braces against a second worker
// process — out of scope for v1 but cheap insurance.
func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
const stmt = `
UPDATE imagen.jobs
SET status='running', started_at=now()
WHERE id = (
SELECT id FROM imagen.jobs
WHERE status='pending'
ORDER BY created_at
LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING id, owner_user_id, prompt, backend,
COALESCE(model,''),
COALESCE(width, 0), COALESCE(height, 0),
COALESCE(steps, 0), COALESCE(seed, 0),
COALESCE(style,'')`
var j worker.Job
err := q.conn.QueryRow(ctx, stmt).Scan(
&j.ID, &j.OwnerUserID, &j.Prompt, &j.Backend,
&j.Model, &j.Width, &j.Height, &j.Steps, &j.Seed, &j.Style,
)
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
if err != nil {
return nil, err
}
return &j, nil
}
func (q *pgxQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='done', image_id=$2, completed_at=now() WHERE id=$1`,
jobID, imageID)
return err
}
func (q *pgxQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
// Trim outrageously long error text so a 10MB stack-trace doesn't end up
// in the row (callers see a summary, full text goes to stderr / logs).
const maxLen = 2000
if len(msg) > maxLen {
msg = msg[:maxLen] + "... [truncated]"
}
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='failed', error=$2, completed_at=now() WHERE id=$1`,
jobID, msg)
return err
}
// WaitForJob blocks until a NOTIFY arrives on imagen_jobs, the timeout fires,
// or ctx is cancelled. Notifications during a previous processJob are queued
// by pgx and delivered on the next call — we don't lose wake-ups even when
// processing took longer than poll-interval.
func (q *pgxQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
_, err := q.conn.WaitForNotification(waitCtx)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil // poll cadence fired
}
if errors.Is(err, context.Canceled) {
return context.Canceled
}
return err
}
return nil
}
// ResetStaleRunning bumps any rows stuck in 'running' back to 'pending' so
// they get re-claimed. Called once at startup. A row stuck in 'running' came
// from a previous worker crash; without this, flexsiebels would poll
// forever on a job nobody is processing.
func (q *pgxQueue) ResetStaleRunning(ctx context.Context) error {
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='pending', started_at=NULL WHERE status='running'`)
return err
}
// workerPipeline is the Pipeline implementation that drives a single job
// through buildBackend → prompt enrichment → generate → write disk →
// cloud-sync, then returns the imagen.images.id back to the worker so it
// can link the row.
type workerPipeline struct {
cfg *config.Config
}
func (p *workerPipeline) Run(ctx context.Context, job worker.Job) worker.Outcome {
if job.OwnerUserID == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: missing owner_user_id", job.ID)}
}
if job.Prompt == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: empty prompt", job.ID)}
}
if job.Backend == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: missing backend", job.ID)}
}
be, err := buildBackend(p.cfg, job.Backend)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("backend %q: %w", job.Backend, err)}
}
attachUsageSink(be)
finalPrompt, err := prompt.Apply(job.Prompt, job.Style)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("style: %w", err)}
}
req := backend.Request{
Prompt: finalPrompt,
Width: job.Width,
Height: job.Height,
Steps: job.Steps,
Seed: job.Seed,
Style: job.Style,
}
res, err := be.Generate(ctx, req)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("generate: %w", err)}
}
defer res.ImageReader.Close()
writer := buildWriter(p.cfg, false)
in := output.Inputs{
Prompt: job.Prompt,
Backend: be.Name(),
Seed: seedFromMetadata(res.Metadata, job.Seed),
Ext: extFromMime(res.MimeType),
Metadata: res.Metadata,
}
paths, err := writer.Write(res.ImageReader, in)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("write disk: %w", err)}
}
// Worker is queue-driven: cloud-sync is mandatory because flexsiebels
// needs imagen.images.id to render the result. Pass cloud_sync=on via
// the override path (third arg = ownerUserID); we set the mode by
// disallowing the 'off' branch through the cfg later if the user
// explicitly turned it off in config.
if cloudModeOff(p.cfg) {
// We refuse to silently drop a queued job. If cloud sync is off in
// config, the worker can't serve flexsiebels at all.
return worker.Outcome{Err: fmt.Errorf("output.cloud_sync=off in config; the worker requires cloud_sync=on or auto")}
}
syncRes, syncErr := maybeCloudSync(ctx, p.cfg, false, job.OwnerUserID, paths, in, res, dimOrFallback(job.Width, res, "width"), dimOrFallback(job.Height, res, "height"))
if syncErr != nil {
return worker.Outcome{Err: fmt.Errorf("cloud sync: %w", syncErr)}
}
if syncRes == nil || syncRes.ImageID == "" {
return worker.Outcome{Err: fmt.Errorf("cloud sync returned no imagen.images id (check SUPABASE_URL + SUPABASE_SERVICE_KEY)")}
}
return worker.Outcome{ImageID: syncRes.ImageID}
}
func cloudModeOff(cfg *config.Config) bool {
if cfg == nil {
return false
}
return strings.EqualFold(cfg.Output.CloudSync, "off")
}
// dimOrFallback returns job.<dim> when the job specified one, otherwise the
// dimension reported by the backend's metadata. Some backends (Replicate
// when given an aspect ratio) round the requested size to their nearest
// supported value; this keeps the row honest about what was actually generated.
func dimOrFallback(jobDim int, res *backend.Result, key string) int {
if jobDim > 0 {
return jobDim
}
return metaInt(res.Metadata, key)
}

View File

@@ -0,0 +1,129 @@
package main
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/worker"
)
// TestWorker_Integration_EndToEnd runs the full pipeline against a real
// msupabase instance: insert a row into imagen.jobs, let the worker claim
// it, generate via the mock backend (no Replicate spend, no ComfyUI
// dependency), write to Supabase Storage + imagen.images, then flip the job
// to 'done' with the linked image_id.
//
// Guarded by IMAGEN_WORKER_INTEGRATION=1. Required env beyond that:
//
// IMAGEN_WORKER_DATABASE_URL postgres DSN (direct, not PostgREST)
// SUPABASE_URL e.g. https://supa.flexsiebels.de
// SUPABASE_SERVICE_KEY service-role JWT
// IMAGEN_OWNER_USER_ID UUID of an auth.users row (RLS fallback)
//
// The test creates and later deletes its own job row so repeated runs don't
// leave debris.
func TestWorker_Integration_EndToEnd(t *testing.T) {
if os.Getenv("IMAGEN_WORKER_INTEGRATION") != "1" {
t.Skip("set IMAGEN_WORKER_INTEGRATION=1 to run the integration test")
}
dsn := os.Getenv("IMAGEN_WORKER_DATABASE_URL")
if dsn == "" {
t.Fatal("IMAGEN_WORKER_DATABASE_URL must be set for the integration test")
}
if os.Getenv("SUPABASE_URL") == "" || os.Getenv("SUPABASE_SERVICE_KEY") == "" {
t.Fatal("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set for the integration test")
}
owner := os.Getenv("IMAGEN_OWNER_USER_ID")
if owner == "" {
t.Fatal("IMAGEN_OWNER_USER_ID must be set for the integration test")
}
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
q, err := dialQueue(ctx, dsn)
if err != nil {
t.Fatalf("dialQueue: %v", err)
}
defer q.Close()
// Insert the test job on a separate connection (the worker's conn is
// busy LISTENing). Mock backend = no external dependency.
insertConn, err := pgx.Connect(ctx, dsn)
if err != nil {
t.Fatalf("insert conn: %v", err)
}
defer insertConn.Close(ctx)
var jobID string
prompt := fmt.Sprintf("imagen integration test %d", time.Now().UnixNano())
err = insertConn.QueryRow(ctx, `
INSERT INTO imagen.jobs (owner_user_id, prompt, backend, width, height)
VALUES ($1, $2, 'mock', 64, 64)
RETURNING id`,
owner, prompt).Scan(&jobID)
if err != nil {
t.Fatalf("insert job: %v", err)
}
t.Logf("inserted imagen.jobs id=%s", jobID)
// Tidy up at the end of the test so a re-run starts clean.
defer func() {
cleanup, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _ = insertConn.Exec(cleanup, `DELETE FROM imagen.jobs WHERE id=$1`, jobID)
}()
// Use a per-test temp dir so the generated PNG doesn't litter the repo.
tmpDir := t.TempDir()
cfg := &config.Config{Output: config.OutputConfig{Directory: tmpDir}}
p := &workerPipeline{cfg: cfg}
w := worker.New(q, p, worker.Config{
PollInterval: 1 * time.Second,
JobTimeout: 30 * time.Second,
Logger: func(format string, a ...any) { t.Logf("worker: "+format, a...) },
})
// Run the worker until it processes one job (the one we just inserted)
// or the test context times out.
runCtx, runCancel := context.WithCancel(ctx)
done := make(chan struct{})
go func() {
_ = w.Run(runCtx)
close(done)
}()
// Poll for completion.
deadline := time.Now().Add(60 * time.Second)
var status, imageID string
for time.Now().Before(deadline) {
err = insertConn.QueryRow(ctx,
`SELECT status, COALESCE(image_id::text,'') FROM imagen.jobs WHERE id=$1`,
jobID).Scan(&status, &imageID)
if err != nil {
t.Fatalf("poll: %v", err)
}
if status == "done" || status == "failed" {
break
}
time.Sleep(500 * time.Millisecond)
}
runCancel()
<-done
if status != "done" {
var errText string
_ = insertConn.QueryRow(ctx,
`SELECT COALESCE(error,'') FROM imagen.jobs WHERE id=$1`, jobID).Scan(&errText)
t.Fatalf("job not done within timeout: status=%q error=%q", status, errText)
}
if imageID == "" {
t.Fatalf("job done but image_id is empty")
}
t.Logf("job done: image_id=%s", imageID)
}

View File

@@ -7,7 +7,7 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
```
┌───────────────────────┐
│ cmd/imagen │ CLI dispatch
│ cmd/imagen │ CLI dispatch (generate / worker / …)
│ (or HTTP server) │
└──────────┬────────────┘
@@ -15,6 +15,10 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
│ internal/prompt │ style preset → prompt suffix
│ internal/output │ filename templating, sidecar
│ internal/config │ YAML loader, validation
│ internal/preview │ tmux-img window spawner
│ internal/cloud │ Supabase Storage + imagen.images
│ internal/usage │ mai.imagen_usage cost-tracking
│ internal/worker │ imagen.jobs queue consumer
└──────────┬────────────┘
┌──────────▼────────────┐
@@ -102,9 +106,37 @@ contains the prompt, backend instance name, seed, ISO timestamp, and the
- Network errors during `Generate` — wrap and return; no retry policy yet
(decide per-adapter, or move to a shared retry helper if a pattern emerges).
## Async write path: `imagen worker` + `imagen.jobs`
`imagen generate` is the synchronous CLI. For web callers (flexsiebels'
owner-mode UI) `cmd/imagen worker` runs as a daemon that consumes the
`imagen.jobs` table.
```
flexsiebels POST imagen worker (mRiver, systemd)
→ INSERT INTO LISTEN imagen_jobs ◄── pg_notify trigger
imagen.jobs(pending) claim row (UPDATE … RETURNING)
dispatch through internal/backend
write disk + cloud-sync via internal/cloud
UPDATE imagen.jobs SET status='done', image_id=…
```
The queue table lives next to `imagen.images` in the same `imagen` schema.
Owner-scoped RLS lets the flexsiebels user INSERT + read their own rows;
the worker writes (status updates + image_id link) via service-role which
bypasses RLS. A 5-second safety poll fires on every wake-up to cover
dropped NOTIFY events and worker cold starts with a non-empty queue. See
`docs/setup-worker-mriver.md` for the systemd installation.
The worker reuses `internal/backend`, `internal/output`, and
`internal/cloud` unchanged — it is purely an orchestration layer around
the same pipeline `imagen generate` drives.
## Out of scope (today)
- Image post-processing (cropping, watermarking).
- Cost-tracking (lands with the Replicate adapter, since only API backends bill).
- Multi-image `n>1` per request — backends that support it can expose it via
`BackendOpts`; the framework doesn't have a first-class field yet.
- Job cancellation / kill switch — separate follow-up issue.
- Concurrent workers / multi-host scale-out — `FOR UPDATE SKIP LOCKED` in
the claim query makes it cheap to add, but a single worker is the v1 setup.

181
docs/setup-comfyui-mrock.md Normal file
View File

@@ -0,0 +1,181 @@
# ComfyUI on mRock — install + ops
ImaGen's `flux-schnell-local` backend talks to ComfyUI on mRock at
`http://mrock:8188` (Tailscale-internal). This document is the reproducible
install path from a clean mRock state.
mRock runs Arch Linux + systemd with an NVIDIA RTX 4070 Ti SUPER (16 GB
VRAM). Ollama is already a native systemd service, so ComfyUI follows the
same pattern (native Python venv + systemd unit) instead of Docker — Docker
on mRock has no `nvidia` runtime configured, and adding one is more invasive
than another systemd unit.
## Prerequisites on mRock
- Python via `uv` (already installed).
- NVIDIA driver new enough for CUDA 12.4. `nvidia-smi --query-gpu=driver_version`
should show >= 550. Driver 595 is what mRock has today.
- ~35 GB free on `/home` for the model files.
- `ollama.service` running on port 11434 — coexistence notes below.
## 1. Clone ComfyUI + Python venv
```bash
mkdir -p ~/dev && cd ~/dev
git clone --depth 1 https://github.com/comfyanonymous/ComfyUI.git comfyui
cd comfyui
uv venv --python 3.12 .venv
source .venv/bin/activate.fish
# PyTorch CUDA 12.4 wheels — match the system driver
uv pip install --no-cache torch torchvision torchaudio \
--index-url https://download.pytorch.org/whl/cu124
uv pip install --no-cache -r requirements.txt
```
Verify CUDA is wired up:
```bash
.venv/bin/python -c \
"import torch; print(torch.__version__, torch.cuda.is_available(), torch.cuda.get_device_name(0))"
# expected: 2.6.0+cu124 True NVIDIA GeForce RTX 4070 Ti SUPER
```
## 2. Models — FLUX.1 schnell
The Black-Forest-Labs primary repo (`black-forest-labs/FLUX.1-schnell`) is
**gated**`curl` against it without an HF token returns HTTP 401. We pull
the weights from ungated mirrors of the same Apache-2.0 release.
| File | Where it goes | Source |
|------|---------------|--------|
| `flux1-schnell.safetensors` (~23.8 GB, fp16) | `models/unet/` | `Comfy-Org/flux1-schnell` |
| `ae.safetensors` (~335 MB) | `models/vae/` | `sirorable/flux-ae-vae` |
| `clip_l.safetensors` (~246 MB) | `models/clip/` | `comfyanonymous/flux_text_encoders` |
| `t5xxl_fp8_e4m3fn.safetensors` (~4.9 GB) | `models/clip/` | `comfyanonymous/flux_text_encoders` |
```bash
cd ~/dev/comfyui/models
curl -L -o unet/flux1-schnell.safetensors \
https://huggingface.co/Comfy-Org/flux1-schnell/resolve/main/flux1-schnell.safetensors
curl -L -o vae/ae.safetensors \
https://huggingface.co/sirorable/flux-ae-vae/resolve/main/ae.safetensors
curl -L -o clip/clip_l.safetensors \
https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors
curl -L -o clip/t5xxl_fp8_e4m3fn.safetensors \
https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors
```
If a new HF token is configured later (`~/.cache/huggingface/token`), the
official `black-forest-labs/FLUX.1-schnell` URL is byte-identical and can be
swapped in.
## 3. systemd unit
Drop `/etc/systemd/system/comfyui.service`:
```ini
[Unit]
Description=ComfyUI image generation server
Documentation=https://github.com/comfyanonymous/ComfyUI
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
User=m
Group=m
WorkingDirectory=/home/m/dev/comfyui
ExecStart=/home/m/dev/comfyui/.venv/bin/python /home/m/dev/comfyui/main.py \
--listen 0.0.0.0 --port 8188 \
--output-directory /home/m/dev/comfyui/output \
--temp-directory /home/m/dev/comfyui/temp
Restart=on-failure
RestartSec=5
TimeoutStopSec=30
NoNewPrivileges=true
PrivateTmp=true
LimitNOFILE=65535
[Install]
WantedBy=multi-user.target
```
Then:
```bash
sudo systemctl daemon-reload
sudo systemctl enable --now comfyui.service
systemctl status comfyui.service
```
The service binds `0.0.0.0:8188`. Tailscale's wireguard fence is the only
auth — do **not** expose port 8188 to the public internet.
## 4. Health check
```bash
curl -fsS --max-time 5 http://mrock:8188/system_stats | jq '.devices[0]'
# expected: name "cuda:0 NVIDIA GeForce RTX 4070 Ti SUPER ...", vram_total ~16 GB
```
`imagen backends` (from a host with the ImaGen CLI installed) should also
report `flux-schnell-local: ok`.
## 5. VRAM coexistence with Ollama
mRock has 16 GB VRAM total. Ollama parks ~8 GB resident for its current
model. FLUX schnell at fp16 weights with `weight_dtype=fp8_e4m3fn` (the
default the adapter requests) needs roughly 1012 GB peak for a 1024×1024
generation, so concurrent Ollama + FLUX on mRock will OOM.
Two practical options:
- **Stop Ollama before generating** — `sudo systemctl stop ollama` frees
the GPU, run the generation, `sudo systemctl start ollama` afterwards.
Adequate while we don't have many concurrent users.
- **Move Ollama off mRock** — when ImaGen is in regular use, push Ollama to
another host so the GPU is dedicated. Tracked separately.
Both decisions live with whoever operates the box; the adapter does not try
to manage Ollama.
## 6. Smoke test (direct, without the imagen CLI)
```bash
# 1) Submit a workflow
curl -fsS --max-time 30 -X POST -H 'Content-Type: application/json' \
-d @flux-schnell-workflow.json \
http://mrock:8188/prompt
# returns: {"prompt_id": "...", "number": ..., "node_errors": {}}
# 2) Poll history until the prompt completes
PID=... # from above
until curl -fsS http://mrock:8188/history/$PID | jq -e ".\"$PID\".status.completed == true" >/dev/null; do
sleep 1
done
# 3) Pull the image
NAME=$(curl -fsS http://mrock:8188/history/$PID \
| jq -r ".\"$PID\".outputs[\"9\"].images[0].filename")
curl -fsS "http://mrock:8188/view?filename=$NAME&type=output" -o /tmp/cat.png
file /tmp/cat.png # PNG image data, 1024 x 1024
```
The full ImaGen smoke test is in [usage.md](usage.md) once the Go adapter
ships.
## Troubleshooting
- **`vram_free` < 6 GB in `/system_stats`**: another GPU process is holding
memory. Usually Ollama (`sudo systemctl stop ollama`).
- **Workflow returns `node_errors` with `Required input is missing` for
CLIPLoader**: text encoder filenames don't match step 2 check that
`clip_l.safetensors` and `t5xxl_fp8_e4m3fn.safetensors` are in
`models/clip/`, not `models/text_encoders/`.
- **`Access to model … is restricted`** during a model pull: the script is
hitting a gated mirror. Use the ungated URLs from step 2.
- **Service won't start**: check `journalctl -u comfyui --since '5 min ago'`.
Common cause is a stale `pip` install re-run step 1.

View File

@@ -0,0 +1,97 @@
# `imagen worker` on mRiver
The worker is a long-running daemon that consumes the `imagen.jobs` queue
(written by flexsiebels' owner-mode UI) and writes the resulting image to
Supabase Storage + `imagen.images` via the same cloud-sync path the CLI
`imagen generate` uses.
## Architecture
```
flexsiebels (owner UI)
|
v INSERT INTO imagen.jobs (...)
|
msupabase Postgres
|
| AFTER INSERT trigger:
| pg_notify('imagen_jobs', NEW.id)
v
imagen worker (mRiver) ── LISTEN imagen_jobs
|
| 1. claim oldest 'pending' row (status='running')
| 2. dispatch to backend (FLUX schnell local / FLUX dev replicate / …)
| 3. write PNG to disk
| 4. upload to Storage + INSERT into imagen.images
| 5. UPDATE imagen.jobs SET status='done', image_id=...
v
flexsiebels polls GET .../jobs/<id> → renders the rendered card
```
A 5-second safety poll covers dropped NOTIFY events and worker cold starts
with a non-empty queue.
## One-time setup
```bash
# 1. Build the binary (or `task build`).
cd ~/dev/ImaGen
go build -o bin/imagen ./cmd/imagen
# 2. Write the environment file.
cp scripts/imagen-worker.env.example ~/.dotfiles/.env.imagen-worker
chmod 600 ~/.dotfiles/.env.imagen-worker
$EDITOR ~/.dotfiles/.env.imagen-worker # fill in real DSN, service key
# 3. Install the user systemd unit.
mkdir -p ~/.config/systemd/user
cp scripts/imagen-worker.service ~/.config/systemd/user/imagen-worker.service
systemctl --user daemon-reload
systemctl --user enable --now imagen-worker.service
# 4. Tail the logs.
journalctl --user -u imagen-worker -f
```
## Required env vars
See `scripts/imagen-worker.env.example` for the canonical list. Required:
- `IMAGEN_WORKER_DATABASE_URL` — direct Postgres DSN. PostgREST cannot LISTEN.
- `SUPABASE_URL`, `SUPABASE_SERVICE_KEY` — same pair `imagen generate`
reads for the cloud-sync writer.
- `IMAGEN_OWNER_USER_ID` — fallback owner UUID; per-job row's
`owner_user_id` overrides this.
Optional, depending on enabled backends:
- `REPLICATE_API_TOKEN` if any job will request a Replicate-typed backend.
## Operating
```bash
systemctl --user status imagen-worker # health
systemctl --user restart imagen-worker # pick up a new binary
journalctl --user -u imagen-worker -n 200 # recent log lines
```
On startup the worker calls `ResetStaleRunning` once, flipping any rows
left in `'running'` from a previous crash back to `'pending'` so they get
re-claimed by the 5-second poll.
## Smoke test
With the worker running, INSERT a test job:
```sql
INSERT INTO imagen.jobs (owner_user_id, prompt, backend, width, height)
VALUES (
'ac6c9501-3757-4a6d-8b97-2cff4288382b',
'a tiny owl wearing wire-rim glasses, photo',
'flux-schnell-local', 1024, 1024
);
```
Within ~10 seconds the row should show `status='done'`, a populated
`image_id` linking to a real `imagen.images` row, and a Storage object at
`<YYYY-MM-DD>/<slug>-<seed>.png` in the `imagen-generated` bucket.

View File

@@ -24,8 +24,29 @@ imagen version print version
| `--negative` | empty | Negative prompt (ignored by some adapters) |
| `--output` | empty (= use naming template) | Explicit path |
| `--no-sidecar` | `false` | Skip the JSON sidecar even if config enables it |
| `--preview` | (auto) | Force open a tmux preview window via `tmux-img` |
| `--no-preview` | (auto) | Suppress the preview window (use for batch / CI callers) |
| `--no-cloud` | `false` | Skip Supabase upload + `imagen.images` insert for this call |
| `--config` | `~/.config/imagen.yaml` | Override config path |
### Preview window
After a successful generate, imagen optionally opens a sibling tmux window
named `img:<slug>` running `tmux-img --hold <path>`. The new window is
spawned in the background (`tmux new-window -d`) so the generating pane
keeps focus and its terminal output.
Resolution order is **config → `$IMAGEN_PREVIEW` → flag** (later wins):
- `output.preview` in `imagen.yaml`: `auto` (default) | `on` | `off`
- `IMAGEN_PREVIEW=auto|on|off` overrides config
- `--preview` / `--no-preview` override env
`auto` previews iff stdout is a TTY *and* `$TMUX` is set. `on` previews
unconditionally and errors outside a tmux session. `off` never previews.
Preview failures are non-fatal — the image already wrote.
## Examples
```sh
@@ -71,3 +92,53 @@ API-backed adapters read tokens from env vars referenced by the config
export REPLICATE_API_TOKEN=...
imagen generate "a cat" --backend flux-dev-replicate
```
## Cost-tracking (Replicate)
Successful generations through the Replicate adapter write one row to
`mai.imagen_usage` on Supabase: backend, model, latency, per-image cost
estimate, prompt sha256 hash (never the prompt itself), and the caller
identity (resolved from `MAI_FROM_ID` or the tmux pane's `@mai-name`).
The writer is best-effort. If `SUPABASE_URL` / `SUPABASE_SERVICE_KEY` are
unset, or the database write fails, the image still lands and the CLI
prints a warning to stderr.
Inspect spend:
```sh
imagen usage # all rows, grouped by week + backend + model + caller
imagen usage --since 2026-05-01 # only rows on/after a UTC date
imagen usage --since 2026-05-01 --raw
```
Per-model rates live in `internal/backend/replicate_pricing.go` — they
are snapshotted from <https://replicate.com/pricing> and refreshed on a
quarterly cadence.
## Cloud-sync (Supabase)
Successful generations also upload the PNG to the private Supabase
Storage bucket `imagen-generated` (path: `<YYYY-MM-DD>/<slug>-<seed>.png`)
and insert a row into `imagen.images`. The row carries the prompt,
sha256-hashed prompt, backend, model, seed/steps/width/height, latency,
cost estimate, the full local sidecar JSON, and an empty `tags` array
ready for the flexsiebels viewer to fill in.
Configuration:
- `owner_user_id` in `imagen.yaml` — m's `auth.users.id`. Empty disables
inserts (the column is `NOT NULL`).
- `output.cloud_sync` in `imagen.yaml`: `auto` (default — on iff
SUPABASE creds + `owner_user_id` are set), `on` (errors if either is
missing), `off`.
- `IMAGEN_CLOUD_SYNC=auto|on|off` overrides config.
- `--no-cloud` overrides everything for one call.
Reuses the same Supabase env (`SUPABASE_URL` + `SUPABASE_SERVICE_KEY` or
`MAI_SUPABASE_KEY`) as cost-tracking. Service-role bypasses RLS for
inserts; the `owner_user_id = auth.uid()` policy on the table gates the
read path the flexsiebels viewer hits.
Failures (Storage 5xx, DB unreachable) emit `imagen: cloud sync: <err>`
to stderr and the local PNG + sidecar stay put. Exit code is unchanged.

15
go.mod
View File

@@ -1,5 +1,16 @@
module mgit.msbls.de/m/ImaGen
go 1.24
go 1.25.0
require gopkg.in/yaml.v3 v3.0.1
require (
github.com/jackc/pgx/v5 v5.9.2
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
golang.org/x/text v0.29.0 // indirect
)

33
go.sum
View File

@@ -1,4 +1,35 @@
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

557
internal/backend/comfyui.go Normal file
View File

@@ -0,0 +1,557 @@
package backend
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// ComfyType is the type-name adapters register under for ComfyUI instances.
const ComfyType = "comfyui"
// Comfy is the ComfyUI adapter. It speaks the public `/prompt` + `/history`
// + `/view` HTTP API and submits a fixed FLUX.1 schnell workflow built from
// the values in Request.
//
// Concurrency: a single Comfy is safe to share across goroutines as long as
// the underlying http.Client is. Generate does not hold long-lived state.
type Comfy struct {
instance string
base string
model string
vae string
clipL string
clipT5 string
dtype string
defaultSteps int
defaultSampler string
defaultScheduler string
httpClient *http.Client
pollInterval time.Duration
pollTimeout time.Duration
// Hooks for tests; production paths use the package-level defaults.
randSeed func() int64
clientIDFn func() string
}
// NewComfy is the registry constructor. cfg is the adapter's slice of
// imagen.yaml. Required keys: base_url, model. The rest have sensible FLUX
// schnell defaults.
func NewComfy(name string, cfg map[string]any) (Backend, error) {
if name == "" {
return nil, fmt.Errorf("comfyui: empty instance name")
}
base := strings.TrimRight(getString(cfg, "base_url", ""), "/")
if base == "" {
return nil, fmt.Errorf("comfyui[%s]: base_url is required", name)
}
if _, err := url.Parse(base); err != nil {
return nil, fmt.Errorf("comfyui[%s]: base_url %q invalid: %w", name, base, err)
}
model := getString(cfg, "model", "")
if model == "" {
return nil, fmt.Errorf("comfyui[%s]: model is required", name)
}
c := &Comfy{
instance: name,
base: base,
model: model,
vae: getString(cfg, "vae", "ae.safetensors"),
clipL: getString(cfg, "clip_l", "clip_l.safetensors"),
clipT5: getString(cfg, "clip_t5", "t5xxl_fp8_e4m3fn.safetensors"),
dtype: getString(cfg, "weight_dtype", "fp8_e4m3fn"),
defaultSteps: getInt(cfg, "default_steps", 4),
defaultSampler: getString(cfg, "default_sampler", "euler"),
defaultScheduler: getString(cfg, "default_scheduler", "simple"),
httpClient: &http.Client{Timeout: 60 * time.Second},
pollInterval: 250 * time.Millisecond,
pollTimeout: 120 * time.Second,
randSeed: cryptoSeed,
clientIDFn: randClientID,
}
return c, nil
}
// Name returns the instance name from imagen.yaml.
func (c *Comfy) Name() string { return c.instance }
// Generate submits one workflow to ComfyUI, waits for it to render, and
// returns the resulting PNG.
func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) {
width := orDefaultInt(req.Width, 1024)
height := orDefaultInt(req.Height, 1024)
steps := orDefaultInt(req.Steps, c.defaultSteps)
sampler := c.defaultSampler
scheduler := c.defaultScheduler
if v, ok := req.BackendOpts["sampler"].(string); ok && v != "" {
sampler = v
}
if v, ok := req.BackendOpts["scheduler"].(string); ok && v != "" {
scheduler = v
}
seed := req.Seed
if seed == 0 {
seed = c.randSeed()
}
workflow := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler)
clientID := c.clientIDFn()
start := time.Now()
promptID, err := c.submitPrompt(ctx, workflow, clientID)
if err != nil {
return nil, err
}
filename, err := c.waitForCompletion(ctx, promptID)
if err != nil {
return nil, err
}
imgBytes, err := c.fetchImage(ctx, filename)
if err != nil {
return nil, err
}
latencyMs := time.Since(start).Milliseconds()
meta := map[string]any{
"backend": c.instance,
"backend_type": ComfyType,
"model": c.model,
"seed": seed,
"steps": steps,
"sampler": sampler,
"scheduler": scheduler,
"width": width,
"height": height,
"latency_ms": latencyMs,
"prompt_id": promptID,
"client_id": clientID,
}
if vram := c.vramUsedMiB(ctx); vram > 0 {
meta["vram_used_mib"] = vram
}
return &Result{
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
MimeType: "image/png",
Metadata: meta,
}, nil
}
// submitPrompt POSTs the workflow and extracts the prompt_id.
//
// Retries once on a 5xx or transient network error. 4xx responses are not
// retried — they are treated as configuration bugs (missing model, bad
// workflow shape, etc.) and surfaced with a hint pointing at the docs when
// the body matches a known pattern.
func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clientID string) (string, error) {
body, err := json.Marshal(map[string]any{
"prompt": workflow,
"client_id": clientID,
})
if err != nil {
return "", fmt.Errorf("comfyui: marshal workflow: %w", err)
}
var lastErr error
for attempt := range 2 {
if attempt > 0 {
select {
case <-ctx.Done():
return "", ctx.Err()
case <-time.After(time.Second):
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/prompt", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
lastErr = c.connError(err)
continue
}
respBody, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
switch {
case resp.StatusCode >= 200 && resp.StatusCode < 300:
return parsePromptID(respBody, c.model)
case resp.StatusCode >= 500:
lastErr = fmt.Errorf("comfyui /prompt %d: %s", resp.StatusCode, snip(respBody))
continue
default:
return "", c.classifyBadRequest(resp.StatusCode, respBody)
}
}
return "", lastErr
}
// waitForCompletion polls /history/{id} until the prompt finishes and
// returns the filename of the produced image.
func (c *Comfy) waitForCompletion(ctx context.Context, promptID string) (string, error) {
deadline := time.Now().Add(c.pollTimeout)
for {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
if time.Now().After(deadline) {
return "", fmt.Errorf("comfyui: prompt %s did not complete within %s", promptID, c.pollTimeout)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/history/"+promptID, nil)
if err != nil {
return "", err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return "", c.connError(err)
}
body, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("comfyui /history/%s %d: %s", promptID, resp.StatusCode, snip(body))
}
filename, done, err := parseHistory(body, promptID)
if err != nil {
return "", err
}
if done {
return filename, nil
}
select {
case <-ctx.Done():
return "", ctx.Err()
case <-time.After(c.pollInterval):
}
}
}
// fetchImage downloads the produced image bytes via /view.
func (c *Comfy) fetchImage(ctx context.Context, filename string) ([]byte, error) {
q := url.Values{
"filename": {filename},
"type": {"output"},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/view?"+q.Encode(), nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, c.connError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("comfyui /view %d: %s", resp.StatusCode, snip(body))
}
return io.ReadAll(resp.Body)
}
// vramUsedMiB returns total - free VRAM on device 0 from /system_stats, or
// 0 if the endpoint isn't available. Best-effort, never an error.
func (c *Comfy) vramUsedMiB(ctx context.Context) int64 {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.base+"/system_stats", nil)
if err != nil {
return 0
}
resp, err := c.httpClient.Do(req)
if err != nil {
return 0
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return 0
}
var s struct {
Devices []struct {
VRAMTotal int64 `json:"vram_total"`
VRAMFree int64 `json:"vram_free"`
} `json:"devices"`
}
if err := json.NewDecoder(resp.Body).Decode(&s); err != nil {
return 0
}
if len(s.Devices) == 0 {
return 0
}
used := s.Devices[0].VRAMTotal - s.Devices[0].VRAMFree
if used < 0 {
return 0
}
return used / (1024 * 1024)
}
// connError translates a Go networking error into a user-actionable message,
// pointing at the boot-whitetower script when mRock looks asleep.
func (c *Comfy) connError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return err
}
msg := err.Error()
var opErr *net.OpError
asOp := errors.As(err, &opErr)
switch {
case asOp,
strings.Contains(msg, "connection refused"),
strings.Contains(msg, "no such host"),
strings.Contains(msg, "no route to host"),
strings.Contains(msg, "network is unreachable"),
strings.Contains(msg, "i/o timeout"):
return fmt.Errorf("comfyui at %s unreachable (%v) — if mRock is asleep, run: boot-whitetower mrock", c.base, err)
}
return fmt.Errorf("comfyui at %s: %w", c.base, err)
}
// classifyBadRequest interprets a 4xx body. Some ComfyUI builds use 400 for
// workflow-validation failures and put the diagnostics in node_errors; older
// builds use 200 + node_errors. This handles the 4xx flavour.
func (c *Comfy) classifyBadRequest(status int, body []byte) error {
if hint, ok := missingModelHint(body, c.model); ok {
return fmt.Errorf("comfyui /prompt %d: %s — see docs/setup-comfyui-mrock.md", status, hint)
}
return fmt.Errorf("comfyui /prompt %d: %s", status, snip(body))
}
// buildWorkflow assembles the canonical FLUX.1 schnell ComfyUI workflow,
// node-IDs matching the upstream "flux-schnell" template so anyone debugging
// in the ComfyUI UI sees a familiar shape.
func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string) map[string]any {
return map[string]any{
"6": map[string]any{
"class_type": "CLIPTextEncode",
"inputs": map[string]any{
"text": prompt,
"clip": []any{"11", 0},
},
},
"8": map[string]any{
"class_type": "VAEDecode",
"inputs": map[string]any{
"samples": []any{"31", 0},
"vae": []any{"10", 0},
},
},
"9": map[string]any{
"class_type": "SaveImage",
"inputs": map[string]any{
"filename_prefix": "imagen",
"images": []any{"8", 0},
},
},
"10": map[string]any{
"class_type": "VAELoader",
"inputs": map[string]any{"vae_name": c.vae},
},
"11": map[string]any{
"class_type": "DualCLIPLoader",
"inputs": map[string]any{
"clip_name1": c.clipT5,
"clip_name2": c.clipL,
"type": "flux",
},
},
"12": map[string]any{
"class_type": "UNETLoader",
"inputs": map[string]any{
"unet_name": c.model,
"weight_dtype": c.dtype,
},
},
"13": map[string]any{
"class_type": "CLIPTextEncode",
"inputs": map[string]any{
"text": negative,
"clip": []any{"11", 0},
},
},
"27": map[string]any{
"class_type": "EmptySD3LatentImage",
"inputs": map[string]any{
"width": w,
"height": h,
"batch_size": 1,
},
},
"30": map[string]any{
"class_type": "ModelSamplingFlux",
"inputs": map[string]any{
"model": []any{"12", 0},
"max_shift": 1.15,
"base_shift": 0.5,
"width": w,
"height": h,
},
},
"31": map[string]any{
"class_type": "KSampler",
"inputs": map[string]any{
"model": []any{"30", 0},
"seed": seed,
"steps": steps,
"cfg": 1.0,
"sampler_name": sampler,
"scheduler": scheduler,
"denoise": 1.0,
"positive": []any{"6", 0},
"negative": []any{"13", 0},
"latent_image": []any{"27", 0},
},
},
}
}
// parsePromptID handles the 2xx /prompt response. ComfyUI sometimes 200s a
// validation failure and stuffs node_errors in the body — this function
// turns that into the same user-facing error as a 4xx with the same body.
func parsePromptID(body []byte, model string) (string, error) {
var resp struct {
PromptID string `json:"prompt_id"`
NodeErrors map[string]any `json:"node_errors"`
Error json.RawMessage `json:"error"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return "", fmt.Errorf("comfyui /prompt: parse response: %w (body: %s)", err, snip(body))
}
if len(resp.NodeErrors) > 0 || len(resp.Error) > 0 {
if hint, ok := missingModelHint(body, model); ok {
return "", fmt.Errorf("comfyui /prompt: %s — see docs/setup-comfyui-mrock.md", hint)
}
return "", fmt.Errorf("comfyui /prompt rejected workflow: %s", snip(body))
}
if resp.PromptID == "" {
return "", fmt.Errorf("comfyui /prompt: empty prompt_id (body: %s)", snip(body))
}
return resp.PromptID, nil
}
// parseHistory inspects a /history/{id} body and returns either the produced
// filename + done=true, or done=false to signal "keep polling".
func parseHistory(body []byte, promptID string) (string, bool, error) {
var entries map[string]struct {
Status struct {
Completed bool `json:"completed"`
StatusStr string `json:"status_str"`
} `json:"status"`
Outputs map[string]struct {
Images []struct {
Filename string `json:"filename"`
Subfolder string `json:"subfolder"`
Type string `json:"type"`
} `json:"images"`
} `json:"outputs"`
}
if err := json.Unmarshal(body, &entries); err != nil {
return "", false, fmt.Errorf("comfyui /history: parse: %w (body: %s)", err, snip(body))
}
e, ok := entries[promptID]
if !ok {
return "", false, nil
}
if e.Status.StatusStr == "error" {
return "", false, fmt.Errorf("comfyui prompt %s errored: %s", promptID, snip(body))
}
if !e.Status.Completed {
return "", false, nil
}
for _, out := range e.Outputs {
if len(out.Images) > 0 {
return out.Images[0].Filename, true, nil
}
}
return "", true, fmt.Errorf("comfyui prompt %s completed but produced no images", promptID)
}
// missingModelHint returns a user-actionable message when the response body
// indicates the configured unet model isn't loaded on the server. ComfyUI
// uses both the human-readable "Value not in list" message and the enum
// "value_not_in_list" type — match either.
func missingModelHint(body []byte, model string) (string, bool) {
s := string(body)
hasMarker := strings.Contains(s, "Value not in list") || strings.Contains(s, "value_not_in_list")
if hasMarker && strings.Contains(s, "unet_name") {
return fmt.Sprintf("model %q not present in the ComfyUI server's models/unet/", model), true
}
return "", false
}
func cryptoSeed() int64 {
var b [8]byte
if _, err := rand.Read(b[:]); err != nil {
return time.Now().UnixNano()
}
return int64(binary.BigEndian.Uint64(b[:]) >> 1)
}
func randClientID() string {
var b [8]byte
_, _ = rand.Read(b[:])
return fmt.Sprintf("imagen-%x", b)
}
func getString(m map[string]any, k, def string) string {
if v, ok := m[k].(string); ok && v != "" {
return v
}
return def
}
func getInt(m map[string]any, k string, def int) int {
if v, ok := m[k]; ok {
switch n := v.(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
}
}
return def
}
func orDefaultInt(v, def int) int {
if v == 0 {
return def
}
return v
}
func snip(b []byte) string {
const max = 500
s := strings.TrimSpace(string(b))
if len(s) > max {
s = s[:max] + "..."
}
return s
}
func init() {
Register(ComfyType, NewComfy)
}

View File

@@ -0,0 +1,494 @@
package backend
import (
"bytes"
"context"
"encoding/json"
"fmt"
"image"
"image/color"
"image/png"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
)
// fakeComfy is a programmable mock of the ComfyUI HTTP API. Tests configure
// its behaviour by adjusting the public fields before issuing the request.
type fakeComfy struct {
t *testing.T
// /prompt
promptStatus int
promptBody []byte
promptCalls atomic.Int32
failPromptUntil int32 // first N /prompt calls return promptFailStatus
promptFailStatus int
promptFailBody []byte
// /history — start by returning {} (no entry), flip to completed once
// historyReadyAfter polls have happened.
historyReadyAfter int32
historyCalls atomic.Int32
historyError bool
// /view
viewStatus int
viewBody []byte
viewType string
// /system_stats
statsTotal int64
statsFree int64
server *httptest.Server
}
func (f *fakeComfy) handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/prompt" && r.Method == http.MethodPost:
n := f.promptCalls.Add(1)
if n <= int32(f.failPromptUntil) {
w.WriteHeader(f.promptFailStatus)
_, _ = w.Write(f.promptFailBody)
return
}
w.WriteHeader(f.promptStatus)
_, _ = w.Write(f.promptBody)
case strings.HasPrefix(r.URL.Path, "/history/") && r.Method == http.MethodGet:
n := f.historyCalls.Add(1)
id := strings.TrimPrefix(r.URL.Path, "/history/")
w.WriteHeader(http.StatusOK)
if f.historyError {
_, _ = fmt.Fprintf(w, `{"%s":{"status":{"completed":false,"status_str":"error"},"outputs":{}}}`, id)
return
}
if n <= f.historyReadyAfter {
_, _ = w.Write([]byte(`{}`))
return
}
_, _ = fmt.Fprintf(w,
`{"%s":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`,
id,
)
case r.URL.Path == "/view" && r.Method == http.MethodGet:
ct := f.viewType
if ct == "" {
ct = "image/png"
}
w.Header().Set("Content-Type", ct)
w.WriteHeader(f.viewStatus)
_, _ = w.Write(f.viewBody)
case r.URL.Path == "/system_stats" && r.Method == http.MethodGet:
w.Header().Set("Content-Type", "application/json")
body := map[string]any{
"system": map[string]any{},
"devices": []map[string]any{
{"vram_total": f.statsTotal, "vram_free": f.statsFree},
},
}
_ = json.NewEncoder(w).Encode(body)
default:
f.t.Errorf("fakeComfy: unexpected request %s %s", r.Method, r.URL.Path)
http.NotFound(w, r)
}
})
}
func (f *fakeComfy) start() {
f.server = httptest.NewServer(f.handler())
f.t.Cleanup(f.server.Close)
}
// newFakeComfy spins up a fakeComfy with happy-path defaults.
func newFakeComfy(t *testing.T) *fakeComfy {
t.Helper()
f := &fakeComfy{
t: t,
promptStatus: http.StatusOK,
promptBody: []byte(`{"prompt_id":"pid-abc","number":1,"node_errors":{}}`),
viewStatus: http.StatusOK,
viewBody: mustPNG(t, 16, 16),
statsTotal: 16 * 1024 * 1024 * 1024,
statsFree: 8 * 1024 * 1024 * 1024,
}
f.start()
return f
}
// newComfy returns a Comfy pointed at f, with poll interval squashed for fast
// tests and deterministic seed/client_id.
func newComfy(t *testing.T, f *fakeComfy) *Comfy {
t.Helper()
be, err := NewComfy("flux-test", map[string]any{
"base_url": f.server.URL,
"model": "flux1-schnell.safetensors",
"default_steps": 4,
})
if err != nil {
t.Fatalf("NewComfy: %v", err)
}
c := be.(*Comfy)
c.pollInterval = time.Millisecond
c.pollTimeout = 5 * time.Second
c.randSeed = func() int64 { return 42 }
c.clientIDFn = func() string { return "imagen-test" }
return c
}
func mustPNG(t *testing.T, w, h int) []byte {
t.Helper()
img := image.NewRGBA(image.Rect(0, 0, w, h))
for y := range h {
for x := range w {
img.Set(x, y, color.RGBA{R: 200, G: 100, B: 50, A: 255})
}
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
t.Fatalf("encode png: %v", err)
}
return buf.Bytes()
}
func TestComfyConstructorRequiresBaseAndModel(t *testing.T) {
if _, err := NewComfy("x", map[string]any{}); err == nil {
t.Errorf("expected error for missing base_url")
}
if _, err := NewComfy("x", map[string]any{"base_url": "http://h:1"}); err == nil {
t.Errorf("expected error for missing model")
}
if _, err := NewComfy("", map[string]any{"base_url": "http://h:1", "model": "m"}); err == nil {
t.Errorf("expected error for empty instance name")
}
}
func TestComfyHappyPath(t *testing.T) {
f := newFakeComfy(t)
f.historyReadyAfter = 2 // exercise the polling loop
c := newComfy(t, f)
res, err := c.Generate(context.Background(), Request{
Prompt: "a small fishbowl with a cat",
Width: 512,
Height: 512,
Steps: 4,
Seed: 1234567,
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
defer res.ImageReader.Close()
if res.MimeType != "image/png" {
t.Errorf("mime = %q", res.MimeType)
}
body, err := io.ReadAll(res.ImageReader)
if err != nil {
t.Fatalf("read body: %v", err)
}
if !bytes.Equal(body, f.viewBody) {
t.Errorf("image body did not round-trip")
}
if seed, _ := res.Metadata["seed"].(int64); seed != 1234567 {
t.Errorf("metadata seed = %v", res.Metadata["seed"])
}
if model, _ := res.Metadata["model"].(string); model != "flux1-schnell.safetensors" {
t.Errorf("metadata model = %v", res.Metadata["model"])
}
if steps, _ := res.Metadata["steps"].(int); steps != 4 {
t.Errorf("metadata steps = %v", res.Metadata["steps"])
}
if pid, _ := res.Metadata["prompt_id"].(string); pid != "pid-abc" {
t.Errorf("metadata prompt_id = %v", res.Metadata["prompt_id"])
}
if _, ok := res.Metadata["latency_ms"]; !ok {
t.Errorf("metadata missing latency_ms")
}
// vram_used_mib is best-effort but should be present given our mock stats
if vram, _ := res.Metadata["vram_used_mib"].(int64); vram != 8192 {
t.Errorf("metadata vram_used_mib = %v, want 8192", res.Metadata["vram_used_mib"])
}
if got := f.historyCalls.Load(); got < 3 {
t.Errorf("expected at least 3 /history polls, got %d", got)
}
}
func TestComfyDefaultsAppliedWhenZero(t *testing.T) {
f := newFakeComfy(t)
c := newComfy(t, f)
res, err := c.Generate(context.Background(), Request{Prompt: "p"}) // all-zero
if err != nil {
t.Fatalf("Generate: %v", err)
}
defer res.ImageReader.Close()
_, _ = io.ReadAll(res.ImageReader)
if w, _ := res.Metadata["width"].(int); w != 1024 {
t.Errorf("width default = %v", res.Metadata["width"])
}
if steps, _ := res.Metadata["steps"].(int); steps != 4 {
t.Errorf("steps default = %v", res.Metadata["steps"])
}
if seed, _ := res.Metadata["seed"].(int64); seed != 42 {
t.Errorf("seed default (test rand hook) = %v", res.Metadata["seed"])
}
if s, _ := res.Metadata["sampler"].(string); s != "euler" {
t.Errorf("sampler default = %q", s)
}
}
func TestComfyPromptRetriesOnce5xx(t *testing.T) {
f := newFakeComfy(t)
f.failPromptUntil = 1
f.promptFailStatus = http.StatusBadGateway
f.promptFailBody = []byte("upstream busy")
c := newComfy(t, f)
res, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
if err != nil {
t.Fatalf("Generate (with one 502 then OK): %v", err)
}
defer res.ImageReader.Close()
_, _ = io.ReadAll(res.ImageReader)
if got := f.promptCalls.Load(); got != 2 {
t.Errorf("expected exactly 2 /prompt calls (1 fail + 1 retry), got %d", got)
}
}
func TestComfyPromptGivesUpAfterTwo5xx(t *testing.T) {
f := newFakeComfy(t)
f.failPromptUntil = 99 // every call fails
f.promptFailStatus = http.StatusServiceUnavailable
f.promptFailBody = []byte("nope")
c := newComfy(t, f)
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
if err == nil {
t.Fatal("expected error after sustained 503s")
}
if !strings.Contains(err.Error(), "503") {
t.Errorf("expected error to mention 503, got %v", err)
}
if got := f.promptCalls.Load(); got != 2 {
t.Errorf("expected exactly 2 /prompt calls (no further retries), got %d", got)
}
}
func TestComfyPromptDoesNotRetryOn4xx(t *testing.T) {
f := newFakeComfy(t)
f.failPromptUntil = 99
f.promptFailStatus = http.StatusBadRequest
f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation"},"node_errors":{"some":"thing"}}`)
c := newComfy(t, f)
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
if err == nil {
t.Fatal("expected error for 400")
}
if got := f.promptCalls.Load(); got != 1 {
t.Errorf("expected exactly 1 /prompt call (no retry on 4xx), got %d", got)
}
}
func TestComfyMissingModelHintsAtSetupDoc(t *testing.T) {
f := newFakeComfy(t)
f.failPromptUntil = 99
f.promptFailStatus = http.StatusBadRequest
f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation","message":"Prompt outputs failed validation"},"node_errors":{"12":{"errors":[{"type":"value_not_in_list","message":"Value not in list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`)
c := newComfy(t, f)
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
if err == nil {
t.Fatal("expected error")
}
msg := err.Error()
if !strings.Contains(msg, "docs/setup-comfyui-mrock.md") {
t.Errorf("error should point at the setup doc, got %v", err)
}
if !strings.Contains(msg, "flux1-schnell.safetensors") {
t.Errorf("error should name the missing model, got %v", err)
}
}
func TestComfyMissingModelOn200WithNodeErrors(t *testing.T) {
// Older ComfyUI builds 200 a workflow-validation failure.
f := newFakeComfy(t)
f.promptStatus = http.StatusOK
f.promptBody = []byte(`{"prompt_id":"","node_errors":{"12":{"errors":[{"type":"value_not_in_list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`)
c := newComfy(t, f)
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
if err == nil {
t.Fatal("expected error for node_errors on 200")
}
if !strings.Contains(err.Error(), "docs/setup-comfyui-mrock.md") {
t.Errorf("error should point at the setup doc, got %v", err)
}
}
func TestComfyHistoryErrorSurfaced(t *testing.T) {
f := newFakeComfy(t)
f.historyError = true
c := newComfy(t, f)
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
if err == nil {
t.Fatal("expected error when history reports execution error")
}
if !strings.Contains(err.Error(), "errored") {
t.Errorf("expected 'errored' in message, got %v", err)
}
}
func TestComfyViewFailureSurfaced(t *testing.T) {
f := newFakeComfy(t)
f.viewStatus = http.StatusNotFound
f.viewBody = []byte("nope")
c := newComfy(t, f)
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
if err == nil {
t.Fatal("expected error when /view 404s")
}
if !strings.Contains(err.Error(), "404") {
t.Errorf("expected status code in error, got %v", err)
}
}
func TestComfyUnreachableHostMentionsBootHelper(t *testing.T) {
be, err := NewComfy("flux-test", map[string]any{
"base_url": "http://127.0.0.1:1", // closed port; connection refused
"model": "flux1-schnell.safetensors",
})
if err != nil {
t.Fatalf("NewComfy: %v", err)
}
c := be.(*Comfy)
c.httpClient = &http.Client{Timeout: 500 * time.Millisecond}
_, err = c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
if err == nil {
t.Fatal("expected error for unreachable host")
}
if !strings.Contains(err.Error(), "boot-whitetower mrock") {
t.Errorf("expected boot-helper hint, got %v", err)
}
}
func TestComfyContextCancelStopsPolling(t *testing.T) {
f := newFakeComfy(t)
f.historyReadyAfter = 1_000_000 // never finishes
c := newComfy(t, f)
c.pollInterval = 5 * time.Millisecond
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
_, err := c.Generate(ctx, Request{Prompt: "p", Width: 64, Height: 64})
if err == nil {
t.Fatal("expected ctx.Err()")
}
if !strings.Contains(err.Error(), "context deadline exceeded") {
t.Errorf("expected deadline exceeded, got %v", err)
}
}
func TestComfyWorkflowReflectsRequest(t *testing.T) {
// Capture the workflow body to assert KSampler + EmptyLatentImage values.
var captured []byte
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/prompt":
captured, _ = io.ReadAll(r.Body)
_, _ = w.Write([]byte(`{"prompt_id":"pid","number":1,"node_errors":{}}`))
case "/history/pid":
_, _ = w.Write([]byte(`{"pid":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`))
case "/view":
_, _ = w.Write(mustPNG(t, 8, 8))
case "/system_stats":
_, _ = w.Write([]byte(`{"devices":[{"vram_total":1,"vram_free":1}]}`))
default:
http.NotFound(w, r)
}
}))
t.Cleanup(srv.Close)
be, err := NewComfy("flux-test", map[string]any{
"base_url": srv.URL,
"model": "custom.safetensors",
"default_steps": 7,
"default_sampler": "dpmpp_2m",
"default_scheduler": "karras",
})
if err != nil {
t.Fatalf("NewComfy: %v", err)
}
c := be.(*Comfy)
c.pollInterval = time.Millisecond
c.randSeed = func() int64 { return 9999 }
res, err := c.Generate(context.Background(), Request{
Prompt: "a cat",
NegativePrompt: "blurry",
Width: 768,
Height: 512,
Steps: 11,
Seed: 555,
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
res.ImageReader.Close()
var sent struct {
Prompt map[string]map[string]any `json:"prompt"`
ClientID string `json:"client_id"`
}
if err := json.Unmarshal(captured, &sent); err != nil {
t.Fatalf("unmarshal captured: %v", err)
}
ks := sent.Prompt["31"]["inputs"].(map[string]any)
if ks["seed"].(float64) != 555 {
t.Errorf("KSampler seed = %v, want 555", ks["seed"])
}
if ks["steps"].(float64) != 11 {
t.Errorf("KSampler steps = %v, want 11", ks["steps"])
}
if ks["sampler_name"].(string) != "dpmpp_2m" {
t.Errorf("sampler_name = %v", ks["sampler_name"])
}
if ks["scheduler"].(string) != "karras" {
t.Errorf("scheduler = %v", ks["scheduler"])
}
latent := sent.Prompt["27"]["inputs"].(map[string]any)
if latent["width"].(float64) != 768 || latent["height"].(float64) != 512 {
t.Errorf("EmptySD3LatentImage size = %vx%v", latent["width"], latent["height"])
}
unet := sent.Prompt["12"]["inputs"].(map[string]any)
if unet["unet_name"].(string) != "custom.safetensors" {
t.Errorf("unet_name = %v", unet["unet_name"])
}
neg := sent.Prompt["13"]["inputs"].(map[string]any)
if neg["text"].(string) != "blurry" {
t.Errorf("negative prompt not threaded: %v", neg["text"])
}
if !strings.HasPrefix(sent.ClientID, "imagen-") && sent.ClientID == "" {
t.Errorf("client_id should be set: %q", sent.ClientID)
}
}
func TestComfyTypeIsRegistered(t *testing.T) {
if !Default.Has(ComfyType) {
t.Errorf("comfyui type not registered in Default")
}
}

View File

@@ -0,0 +1,567 @@
package backend
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"time"
)
// ReplicateType is the type-name adapters register under for Replicate
// instances.
const ReplicateType = "replicate"
// Replicate is the Replicate API adapter. It speaks the public REST API
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
// to submit, then polls /v1/predictions/{id} until the prediction
// settles, then downloads the produced image.
//
// Concurrency: a single Replicate is safe to share across goroutines.
type Replicate struct {
instance string
apiBase string
apiToken string
tokenEnv string
model string // "owner/name" or "owner/name:version-hash"
owner string
name string
version string // optional; empty means "use the model-based predictions endpoint"
defaultSteps int
defaultAspect string
httpClient *http.Client
pollInterval time.Duration
pollTimeout time.Duration
// Hooks for tests; production paths use the package-level defaults.
randSeed func() int64
initialBackoff time.Duration
// Sink is where successful generations are recorded for cost-tracking.
// nil means "do not record". The framework wires this up in the CLI;
// adapter tests inject a fake.
Sink UsageSink
}
// UsageSink writes one row per successful generation. Implementations
// should treat write failures as warnings, not errors — the image has
// already landed on disk; failing the call would lose the artefact.
type UsageSink interface {
Record(ctx context.Context, row UsageRow) error
}
// UsageRow is the cost-tracking row stored in mai.imagen_usage. Note the
// prompt itself is intentionally NOT included — only the sha256 hash.
type UsageRow struct {
Backend string
Model string
Seed *int64
PromptHash string
LatencyMs int
CostUSDEstimate *float64
Caller string
}
// NewReplicate is the registry constructor. cfg is the adapter's slice
// of imagen.yaml.
func NewReplicate(name string, cfg map[string]any) (Backend, error) {
if name == "" {
return nil, fmt.Errorf("replicate: empty instance name")
}
tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN")
model := getString(cfg, "model", "")
if model == "" {
return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name)
}
owner, modelName, version, err := parseModelRef(model)
if err != nil {
return nil, fmt.Errorf("replicate[%s]: %w", name, err)
}
apiBase := strings.TrimRight(getString(cfg, "api_base", "https://api.replicate.com"), "/")
pollTimeout := timeoutForModel(modelName)
r := &Replicate{
instance: name,
apiBase: apiBase,
tokenEnv: tokenEnv,
apiToken: os.Getenv(tokenEnv),
model: model,
owner: owner,
name: modelName,
version: version,
defaultSteps: getInt(cfg, "default_steps", 0),
defaultAspect: getString(cfg, "default_aspect_ratio", "1:1"),
httpClient: &http.Client{Timeout: 30 * time.Second},
pollInterval: 500 * time.Millisecond,
pollTimeout: pollTimeout,
randSeed: cryptoSeed,
initialBackoff: time.Second,
}
return r, nil
}
// Name returns the instance name from imagen.yaml.
func (r *Replicate) Name() string { return r.instance }
// Generate submits one prediction and returns the resulting PNG.
func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) {
if r.apiToken == "" {
return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv)
}
width := orDefaultInt(req.Width, 1024)
height := orDefaultInt(req.Height, 1024)
aspect := computeAspectRatio(width, height, r.defaultAspect)
steps := orDefaultInt(req.Steps, r.defaultSteps)
seed := req.Seed
if seed == 0 {
seed = r.randSeed()
}
input := map[string]any{
"prompt": req.Prompt,
"aspect_ratio": aspect,
"num_outputs": 1,
"output_format": "png",
"seed": seed,
}
if steps > 0 {
input["num_inference_steps"] = steps
}
if req.NegativePrompt != "" {
input["negative_prompt"] = req.NegativePrompt
}
maps.Copy(input, req.BackendOpts)
start := time.Now()
pred, err := r.submitPrediction(ctx, input)
if err != nil {
return nil, err
}
final, err := r.waitForCompletion(ctx, pred.ID)
if err != nil {
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
}
imgURL, err := pickFirstOutputURL(final.Output)
if err != nil {
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
}
imgBytes, mime, err := r.fetchImage(ctx, imgURL)
if err != nil {
return nil, err
}
latencyMs := int(time.Since(start).Milliseconds())
predictTime := final.Metrics.PredictTime
costEst, costKnown := replicatePerImageUSD(r.model)
meta := map[string]any{
"backend": r.instance,
"backend_type": ReplicateType,
"model": r.model,
"model_version": final.Version,
"prediction_id": pred.ID,
"seed": seed,
"width": width,
"height": height,
"aspect_ratio": aspect,
"latency_ms": int64(latencyMs),
}
if predictTime > 0 {
meta["predict_time_seconds"] = predictTime
}
if costKnown {
meta["cost_usd_estimate"] = costEst
}
if r.Sink != nil {
row := UsageRow{
Backend: r.instance,
Model: r.model,
PromptHash: hashPrompt(req.Prompt),
LatencyMs: latencyMs,
Caller: ResolveCaller(),
}
row.Seed = new(int64)
*row.Seed = seed
if costKnown {
c := costEst
row.CostUSDEstimate = &c
}
if err := r.Sink.Record(ctx, row); err != nil {
fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err)
}
}
return &Result{
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
MimeType: mime,
Metadata: meta,
}, nil
}
// replicatePrediction is what the REST API returns under /v1/predictions
// and /v1/models/{owner}/{name}/predictions. Only the fields we use are
// declared.
type replicatePrediction struct {
ID string `json:"id"`
Status string `json:"status"`
Version string `json:"version"`
Error json.RawMessage `json:"error"`
Output json.RawMessage `json:"output"`
Metrics struct {
PredictTime float64 `json:"predict_time"`
} `json:"metrics"`
}
// submitPrediction creates a prediction and returns it. Picks the
// model-based endpoint when no version was given (recommended for
// Replicate's official models), and the legacy /v1/predictions otherwise.
func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) {
var (
reqURL string
body []byte
err error
)
if r.version == "" {
reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions", r.apiBase, r.owner, r.name)
body, err = json.Marshal(map[string]any{"input": input})
} else {
reqURL = r.apiBase + "/v1/predictions"
body, err = json.Marshal(map[string]any{
"version": r.version,
"input": input,
})
}
if err != nil {
return nil, fmt.Errorf("replicate: marshal prediction body: %w", err)
}
respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body)
if err != nil {
return nil, err
}
var pred replicatePrediction
if err := json.Unmarshal(respBody, &pred); err != nil {
return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody))
}
if pred.ID == "" {
return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody))
}
return &pred, nil
}
// waitForCompletion polls /v1/predictions/{id} until the status is a
// terminal value (succeeded, failed, canceled) or the timeout fires.
func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) {
deadline := time.Now().Add(r.pollTimeout)
getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id)
start := time.Now()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if time.Now().After(deadline) {
return nil, fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond))
}
body, err := r.doWithRetry(ctx, http.MethodGet, getURL, nil)
if err != nil {
return nil, err
}
var pred replicatePrediction
if err := json.Unmarshal(body, &pred); err != nil {
return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body))
}
switch pred.Status {
case "succeeded":
return &pred, nil
case "failed":
return nil, fmt.Errorf("status=failed: %s", snip(pred.Error))
case "canceled":
return nil, fmt.Errorf("status=canceled")
case "starting", "processing", "":
default:
return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(r.pollInterval):
}
}
}
// doWithRetry executes a Replicate API request and applies the resilience
// policy: 401 surfaces a clean message naming the env var; 429 retries
// with exponential backoff up to three times; 5xx retries once; other
// errors surface unchanged.
func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
const max429Retries = 3
backoff := r.initialBackoff
if backoff <= 0 {
backoff = time.Second
}
var lastErr error
for attempt := 0; ; attempt++ {
req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Token "+r.apiToken)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := r.httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
if attempt >= 1 {
return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err)
}
lastErr = err
if !sleepCtx(ctx, backoff) {
return nil, ctx.Err()
}
continue
}
respBody, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
switch {
case resp.StatusCode >= 200 && resp.StatusCode < 300:
return respBody, nil
case resp.StatusCode == http.StatusUnauthorized:
return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, resp.StatusCode, r.tokenEnv)
case resp.StatusCode == http.StatusTooManyRequests:
if attempt >= max429Retries {
return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody))
}
wait := backoffFor429(resp.Header.Get("Retry-After"), backoff)
lastErr = fmt.Errorf("429 (retry %d): %s", attempt+1, snip(respBody))
if !sleepCtx(ctx, wait) {
return nil, ctx.Err()
}
backoff *= 2
continue
case resp.StatusCode >= 500:
if attempt >= 1 {
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
}
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(respBody))
if !sleepCtx(ctx, backoff) {
return nil, ctx.Err()
}
continue
default:
_ = lastErr
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
}
}
}
// fetchImage downloads the rendered image from the Replicate-provided
// CDN URL. One retry on a generic network error.
func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) {
var lastErr error
for attempt := range 2 {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil)
if err != nil {
return nil, "", err
}
resp, err := r.httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, "", err
}
lastErr = err
continue
}
body, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
lastErr = readErr
continue
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if resp.StatusCode >= 500 && attempt == 0 {
lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body))
continue
}
return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body))
}
mime := resp.Header.Get("Content-Type")
if mime == "" {
mime = "image/png"
}
return body, mime, nil
}
return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr)
}
// parseModelRef accepts "owner/name" or "owner/name:version-hash".
func parseModelRef(ref string) (owner, name, version string, err error) {
rest := ref
if i := strings.IndexByte(rest, ':'); i >= 0 {
version = rest[i+1:]
rest = rest[:i]
}
parts := strings.SplitN(rest, "/", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
}
return parts[0], parts[1], version, nil
}
// timeoutForModel picks the polling timeout. FLUX dev takes notably longer
// than schnell; everything else gets the dev timeout for safety.
func timeoutForModel(name string) time.Duration {
switch strings.ToLower(name) {
case "flux-schnell":
return 60 * time.Second
default:
return 120 * time.Second
}
}
// pickFirstOutputURL extracts the first image URL from the output field.
// Replicate returns either a single string or an array of strings for image
// models — we accept both.
func pickFirstOutputURL(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", fmt.Errorf("output is empty")
}
var s string
if err := json.Unmarshal(raw, &s); err == nil && s != "" {
return s, nil
}
var arr []string
if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 && arr[0] != "" {
return arr[0], nil
}
return "", fmt.Errorf("output is not a string or non-empty string array (got: %s)", snip(raw))
}
// computeAspectRatio reduces width:height to a Replicate-supported aspect
// ratio when the reduction lands on one of the canonical values; otherwise
// returns fallback.
func computeAspectRatio(w, h int, fallback string) string {
if w <= 0 || h <= 0 {
return fallback
}
g := gcd(w, h)
a, b := w/g, h/g
s := fmt.Sprintf("%d:%d", a, b)
if isReplicateAspectRatio(s) {
return s
}
return fallback
}
func isReplicateAspectRatio(s string) bool {
switch s {
case "1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21":
return true
}
return false
}
func gcd(a, b int) int {
for b != 0 {
a, b = b, a%b
}
if a < 0 {
return -a
}
return a
}
// hashPrompt returns the sha256 hex digest of the prompt. The raw prompt
// is intentionally never written to the cost-tracking table.
func hashPrompt(p string) string {
sum := sha256.Sum256([]byte(p))
return hex.EncodeToString(sum[:])
}
// ResolveCaller returns the agent identity the cost-tracking row is
// attributed to. Order of resolution mirrors the maimcp identity logic:
// MAI_FROM_ID env var first, then the tmux pane's @mai-name option, then
// "unknown".
func ResolveCaller() string {
if v := strings.TrimSpace(os.Getenv("MAI_FROM_ID")); v != "" {
return v
}
if pane := os.Getenv("TMUX_PANE"); pane != "" {
out, err := exec.Command("tmux", "display-message", "-p", "-t", pane, "#{@mai-name}").Output()
if err == nil {
if name := strings.TrimSpace(string(out)); name != "" {
return name
}
}
}
return "unknown"
}
// backoffFor429 honours a Retry-After header (in seconds) when present
// and within reason, otherwise falls back to the caller's backoff.
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
if retryAfter == "" {
return fallback
}
d, err := time.ParseDuration(retryAfter + "s")
if err != nil || d <= 0 {
return fallback
}
if d > 30*time.Second {
return 30 * time.Second
}
return d
}
func sleepCtx(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(d):
return true
}
}
func bytesReader(b []byte) io.Reader {
if b == nil {
return nil
}
return bytes.NewReader(b)
}
// shortPath strips the host so error messages don't leak the API base.
func shortPath(u string) string {
if i := strings.Index(u, "/v1/"); i >= 0 {
return u[i:]
}
return u
}
func init() {
Register(ReplicateType, NewReplicate)
}

View File

@@ -0,0 +1,42 @@
package backend
import "strings"
// Replicate pricing snapshot.
//
// Source: https://replicate.com/pricing and the per-model "Run" tab on
// each model page. Replicate bills per second of GPU time, but the
// black-forest-labs FLUX models also publish a flat per-image price for
// the typical settings — that flat number is what we hardcode here.
//
// Snapshot date: 2026-05-08. TODO(refresh): re-check quarterly. If the
// rates drift more than ~10%, update the table and bump snapshotDate.
const replicatePricingSnapshotDate = "2026-05-08"
// replicatePerImageUSD is the per-image cost estimate keyed by Replicate
// model identifier ("owner/name", with any ":version" trimmed). Returns
// the rate and true if the model is known, 0 and false otherwise — an
// unknown model writes a row with NULL cost rather than a wrong number.
func replicatePerImageUSD(model string) (float64, bool) {
key := normalisePricingKey(model)
switch key {
case "black-forest-labs/flux-schnell":
return 0.003, true
case "black-forest-labs/flux-dev":
return 0.025, true
case "black-forest-labs/flux-pro":
return 0.055, true
case "black-forest-labs/flux-1.1-pro":
return 0.040, true
}
return 0, false
}
// normalisePricingKey strips the optional ":version" suffix and lowercases
// the owner/name pair. "Owner/Name:hash" → "owner/name".
func normalisePricingKey(model string) string {
if i := strings.IndexByte(model, ':'); i >= 0 {
model = model[:i]
}
return strings.ToLower(model)
}

View File

@@ -0,0 +1,675 @@
package backend
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// fakeReplicate is a programmable mock of the Replicate REST API.
type fakeReplicate struct {
t *testing.T
mu sync.Mutex
// Responses for the predictions submission. If the request matches the
// model-based path, modelEndpointHits increments; for /v1/predictions,
// versionEndpointHits increments.
createStatus int
createBody []byte
createCalls atomic.Int32
// Auth-fail policy: if true, every request to the prediction endpoints
// returns 401 with a stock body.
auth401 bool
// 429-then-OK policy: first N create calls return 429.
create429Until int32
retryAfter string
// Sequence of /predictions/{id} responses, walked in order; once
// exhausted, the last entry is returned indefinitely.
pollResponses []string
pollIdx atomic.Int32
pollCalls atomic.Int32
// Image download policy.
imageStatus int
imageBody []byte
image5xxFirst int32
imageCalls atomic.Int32
server *httptest.Server
}
func (f *fakeReplicate) handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"):
f.handleCreate(w, r)
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
f.handleCreate(w, r)
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
f.handlePoll(w, r)
case r.Method == http.MethodGet && r.URL.Path == "/img":
f.handleImage(w, r)
default:
f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path)
http.NotFound(w, r)
}
})
}
func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) {
n := f.createCalls.Add(1)
if f.auth401 {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"detail":"Invalid token"}`))
return
}
if n <= f.create429Until {
if f.retryAfter != "" {
w.Header().Set("Retry-After", f.retryAfter)
}
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"detail":"too many requests"}`))
return
}
w.WriteHeader(f.createStatus)
_, _ = w.Write(f.createBody)
}
func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) {
f.pollCalls.Add(1)
idx := int(f.pollIdx.Add(1)) - 1
if idx >= len(f.pollResponses) {
idx = len(f.pollResponses) - 1
}
if idx < 0 {
http.Error(w, "no poll response configured", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(f.pollResponses[idx]))
}
func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) {
n := f.imageCalls.Add(1)
if n <= f.image5xxFirst {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte("upstream unavailable"))
return
}
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(f.imageStatus)
_, _ = w.Write(f.imageBody)
}
func (f *fakeReplicate) start() {
f.server = httptest.NewServer(f.handler())
f.t.Cleanup(f.server.Close)
}
func (f *fakeReplicate) imageURL() string { return f.server.URL + "/img" }
func newFakeReplicate(t *testing.T) *fakeReplicate {
t.Helper()
f := &fakeReplicate{
t: t,
createStatus: http.StatusCreated,
imageStatus: http.StatusOK,
imageBody: mustPNG(t, 8, 8),
}
f.start()
// Default happy-path responses now that the server URL is known.
f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`)
f.pollResponses = []string{
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
`{"id":"pred-abc","status":"processing","version":"v1","output":null}`,
fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()),
}
return f
}
func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate {
t.Helper()
be, err := NewReplicate("flux-test", map[string]any{
"api_token_env": "TEST_REPLICATE_TOKEN",
"model": model,
"api_base": f.server.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake-token"
r.pollInterval = time.Millisecond
r.pollTimeout = 5 * time.Second
r.initialBackoff = time.Millisecond
r.randSeed = func() int64 { return 42 }
return r
}
func TestReplicateConstructorRejectsBadInputs(t *testing.T) {
if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil {
t.Errorf("expected error for empty instance name")
}
if _, err := NewReplicate("x", map[string]any{}); err == nil {
t.Errorf("expected error for missing model")
}
if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil {
t.Errorf("expected error for malformed model")
}
}
func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.apiToken = "" // simulate the env var being unset
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error when token is missing")
}
if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") {
t.Errorf("error should name the env var: %v", err)
}
}
func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) {
f := newFakeReplicate(t)
sink := &recordingSink{}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.Sink = sink
res, err := r.Generate(context.Background(), Request{
Prompt: "a tiny dragon",
Width: 1024, Height: 1024,
Seed: 0,
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
defer res.ImageReader.Close()
body, _ := io.ReadAll(res.ImageReader)
if !bytes.Equal(body, f.imageBody) {
t.Errorf("image body did not round-trip")
}
if mime := res.MimeType; mime != "image/png" {
t.Errorf("mime = %q", mime)
}
if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" {
t.Errorf("metadata model = %v", got)
}
if got, _ := res.Metadata["model_version"].(string); got != "v1" {
t.Errorf("metadata model_version = %v", got)
}
if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 {
t.Errorf("metadata predict_time_seconds = %v", got)
}
if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 {
t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok)
}
if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" {
t.Errorf("aspect_ratio = %q", got)
}
if got := f.pollCalls.Load(); got < 3 {
t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got)
}
if len(sink.rows) != 1 {
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
}
row := sink.rows[0]
if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" {
t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model)
}
if row.PromptHash == "" || row.PromptHash == "a tiny dragon" {
t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash)
}
if len(row.PromptHash) != 64 {
t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash))
}
if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 {
t.Errorf("sink row cost = %v", row.CostUSDEstimate)
}
if row.LatencyMs <= 0 {
t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs)
}
}
func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) {
f := newFakeReplicate(t)
var sentBody []byte
var sentPath string
mu := sync.Mutex{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
mu.Lock()
sentBody, _ = io.ReadAll(r.Body)
sentPath = r.URL.Path
mu.Unlock()
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`))
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
_, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL())
default:
f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-dev:abc123",
"api_base": srv.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
res, err := r.Generate(context.Background(), Request{Prompt: "x"})
if err != nil {
t.Fatalf("Generate: %v", err)
}
res.ImageReader.Close()
mu.Lock()
defer mu.Unlock()
if sentPath != "/v1/predictions" {
t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath)
}
var body map[string]any
if err := json.Unmarshal(sentBody, &body); err != nil {
t.Fatalf("unmarshal sent body: %v", err)
}
if body["version"] != "abc123" {
t.Errorf("expected version=abc123 in body, got %v", body["version"])
}
}
func TestReplicate401SurfacesEnvHint(t *testing.T) {
f := newFakeReplicate(t)
f.auth401 = true
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected 401 to surface as error")
}
msg := err.Error()
if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") {
t.Errorf("error should name env var: %v", err)
}
if !strings.Contains(msg, "401") {
t.Errorf("error should mention 401: %v", err)
}
}
func TestReplicate429RetriesThenSucceeds(t *testing.T) {
f := newFakeReplicate(t)
f.create429Until = 2 // first two calls 429 then succeed
f.retryAfter = "" // force the adapter's exp-backoff path
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = time.Millisecond
// Squash the backoff so the test is fast.
r.httpClient = &http.Client{Timeout: 5 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := r.Generate(ctx, Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate after 429s: %v", err)
}
res.ImageReader.Close()
if got := f.createCalls.Load(); got != 3 {
t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got)
}
}
func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) {
f := newFakeReplicate(t)
f.create429Until = 99 // every call 429
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
_, err := r.Generate(ctx, Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error after sustained 429s")
}
if !strings.Contains(err.Error(), "429") {
t.Errorf("expected 429 in error: %v", err)
}
// max429Retries=3 → 1 initial + 3 retries = 4 total
if got := f.createCalls.Load(); got != 4 {
t.Errorf("expected 4 create calls (1+3 retries), got %d", got)
}
}
func TestReplicateFailedPredictionSurfacesError(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
`{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`,
}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error for failed prediction")
}
if !strings.Contains(err.Error(), "failed") {
t.Errorf("error should mention failure: %v", err)
}
if !strings.Contains(err.Error(), "NSFW") {
t.Errorf("error should include the API error message: %v", err)
}
}
func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) {
f := newFakeReplicate(t)
// Always return processing → adapter times out.
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = 5 * time.Millisecond
r.pollTimeout = 30 * time.Millisecond
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected timeout error")
}
if !strings.Contains(err.Error(), "did not complete") {
t.Errorf("expected 'did not complete' in error, got %v", err)
}
if !strings.Contains(err.Error(), "waited") {
t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err)
}
}
func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) {
f := newFakeReplicate(t)
f.image5xxFirst = 1 // first download 502, second OK
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (download retry): %v", err)
}
res.ImageReader.Close()
if got := f.imageCalls.Load(); got != 2 {
t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got)
}
}
func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) {
f := newFakeReplicate(t)
f.image5xxFirst = 99 // every download fails
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error after sustained image-download 5xx")
}
if got := f.imageCalls.Load(); got != 2 {
t.Errorf("expected 2 image fetches (no further retries), got %d", got)
}
}
func TestReplicateContextCancelStopsPolling(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = 5 * time.Millisecond
r.pollTimeout = 5 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
_, err := r.Generate(ctx, Request{Prompt: "p"})
if err == nil {
t.Fatal("expected ctx error")
}
if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") {
t.Errorf("expected deadline exceeded, got %v", err)
}
}
func TestReplicateBackendOptsMergedIntoInput(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
body, _ := io.ReadAll(r.Body)
var top struct {
Input map[string]any `json:"input"`
}
_ = json.Unmarshal(body, &top)
captured = top.Input
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-schnell",
"api_base": srv.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
// We expect this to fail at "pickFirstOutputURL" (empty output) but the
// captured input should still have been recorded.
_, _ = r.Generate(context.Background(), Request{
Prompt: "p",
Width: 1024, Height: 1024,
BackendOpts: map[string]any{
"output_quality": 90,
"go_fast": true,
},
})
if captured == nil {
t.Fatal("create endpoint not hit")
}
if captured["output_quality"] != float64(90) {
t.Errorf("output_quality not threaded: %v", captured["output_quality"])
}
if captured["go_fast"] != true {
t.Errorf("go_fast not threaded: %v", captured["go_fast"])
}
if captured["prompt"] != "p" {
t.Errorf("prompt not threaded: %v", captured["prompt"])
}
}
func TestReplicateOutputArrayAccepted(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{
fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()),
}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (output as array): %v", err)
}
res.ImageReader.Close()
}
func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "stability-ai/sdxl")
sink := &recordingSink{}
r.Sink = sink
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (unknown model): %v", err)
}
res.ImageReader.Close()
if _, present := res.Metadata["cost_usd_estimate"]; present {
t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"])
}
if len(sink.rows) != 1 {
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
}
if sink.rows[0].CostUSDEstimate != nil {
t.Errorf("expected nil cost in sink row for unknown model")
}
}
func TestReplicateSinkFailureIsWarningNotError(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") })
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("sink failure should not fail Generate: %v", err)
}
res.ImageReader.Close()
}
func TestReplicateDefaultStepsApplied(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
body, _ := io.ReadAll(r.Body)
var top struct {
Input map[string]any `json:"input"`
}
_ = json.Unmarshal(body, &top)
captured = top.Input
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-dev",
"api_base": srv.URL,
"default_steps": 28,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
_, _ = r.Generate(context.Background(), Request{Prompt: "p"})
if captured == nil {
t.Fatal("create endpoint not hit")
}
if captured["num_inference_steps"] != float64(28) {
t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"])
}
}
func TestComputeAspectRatio(t *testing.T) {
cases := []struct {
w, h int
fallback string
want string
}{
{1024, 1024, "1:1", "1:1"},
{1920, 1080, "1:1", "16:9"},
{2560, 1440, "1:1", "16:9"},
{1024, 768, "1:1", "4:3"},
{1024, 1280, "1:1", "4:5"},
{1000, 1234, "1:1", "1:1"}, // weird ratio falls back
{0, 1024, "1:1", "1:1"},
}
for _, c := range cases {
got := computeAspectRatio(c.w, c.h, c.fallback)
if got != c.want {
t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want)
}
}
}
func TestParseModelRef(t *testing.T) {
owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell")
if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" {
t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err)
}
owner, name, ver, err = parseModelRef("owner/name:hash123")
if err != nil || owner != "owner" || name != "name" || ver != "hash123" {
t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err)
}
if _, _, _, err := parseModelRef("noslash"); err == nil {
t.Errorf("expected error for malformed ref")
}
}
func TestHashPromptStable(t *testing.T) {
a := hashPrompt("hello")
b := hashPrompt("hello")
c := hashPrompt("hello!")
if a != b {
t.Errorf("hashPrompt should be deterministic")
}
if a == c {
t.Errorf("different prompts should hash differently")
}
if len(a) != 64 {
t.Errorf("sha256 hex should be 64 chars, got %d", len(a))
}
}
func TestReplicatePricingKnownModels(t *testing.T) {
if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 {
t.Errorf("schnell rate = %v (ok=%v)", v, ok)
}
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 {
t.Errorf("dev rate = %v (ok=%v)", v, ok)
}
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 {
t.Errorf("versioned ref should resolve to base price: %v %v", v, ok)
}
if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok {
t.Errorf("unknown model should report ok=false")
}
}
func TestReplicateTypeIsRegistered(t *testing.T) {
if !Default.Has(ReplicateType) {
t.Errorf("replicate type not registered in Default")
}
}
// recordingSink captures rows for assertion.
type recordingSink struct {
mu sync.Mutex
rows []UsageRow
}
func (s *recordingSink) Record(_ context.Context, row UsageRow) error {
s.mu.Lock()
defer s.mu.Unlock()
s.rows = append(s.rows, row)
return nil
}
type sinkFunc func(context.Context, UsageRow) error
func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) }

356
internal/cloud/cloud.go Normal file
View File

@@ -0,0 +1,356 @@
// Package cloud syncs a generated image to Supabase Storage and inserts
// a row into imagen.images. Both steps are best-effort: callers log the
// returned error and proceed, because the local PNG + sidecar are already
// on disk by the time Sync runs and a cloud blip should not lose the
// artefact.
//
// The single source of truth for the row schema is the imagen_schema_init
// migration — see internal docs in the issue body for #7.
package cloud
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"
)
// supabaseSchema is the PostgREST profile header value the imagen schema
// is exposed under (see ALTER ROLE authenticator SET pgrst.db_schemas).
const supabaseSchema = "imagen"
// bucketName is the Supabase Storage bucket all generated images land in.
const bucketName = "imagen-generated"
// Sink writes one PNG + one row per generation. It is safe to share
// across goroutines.
type Sink struct {
// URL is SUPABASE_URL — e.g. https://supa.flexsiebels.de.
URL string
// APIKey is the service-role key (SUPABASE_SERVICE_KEY). Storage uploads
// and DB inserts both bypass RLS with this key — the policies on the
// table + bucket are the contract for the read side.
APIKey string
// OwnerUserID is m's auth.users.id. It populates owner_user_id on every
// row. Empty means the sink refuses to insert (the column is NOT NULL
// and the user-mode reader needs it for the RLS policy).
OwnerUserID string
// HTTP is the http client; tests inject one pointing at httptest.
HTTP *http.Client
// MaxRetries is the number of additional attempts after the first
// failure for retryable (5xx) responses. Zero means single-shot.
MaxRetries int
// InitialBackoff is the wait before the first retry; doubles per attempt.
// Set very small in tests.
InitialBackoff time.Duration
}
// NewFromEnv returns a sink populated from SUPABASE_URL +
// SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) + IMAGEN_OWNER_USER_ID.
// Returns ok=false if the URL or key are missing — the caller treats that
// as "cloud-sync disabled by environment".
func NewFromEnv() (*Sink, bool) {
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
if u == "" {
return nil, false
}
key := os.Getenv("SUPABASE_SERVICE_KEY")
if key == "" {
key = os.Getenv("MAI_SUPABASE_KEY")
}
if key == "" {
return nil, false
}
return &Sink{
URL: u,
APIKey: key,
OwnerUserID: os.Getenv("IMAGEN_OWNER_USER_ID"),
HTTP: &http.Client{Timeout: 30 * time.Second},
MaxRetries: 2,
InitialBackoff: time.Second,
}, true
}
// SyncRequest is the cross-backend ingredient set Sync needs. Date is
// formatted as YYYY-MM-DD; Slug + Seed are reused from the local
// filename so storage_path mirrors disk layout.
type SyncRequest struct {
Date string
Slug string
Seed int64
Ext string // "png", "jpg", "webp" — no leading dot
PNG []byte
MimeType string
Prompt string
Backend string
Model string
Steps int
Width int
Height int
LatencyMs int
CostUSDEstimate *float64
Sidecar map[string]any
}
// SyncResult tells the caller what landed where.
type SyncResult struct {
StoragePath string // e.g. "2026-05-11/lighthouse-42.png"
ImageID string // imagen.images.id (UUID)
}
// Sync uploads the bytes and inserts the metadata row. Returns the row's
// id and storage_path on success; any non-nil error is what the caller
// surfaces as "imagen: cloud sync: <err>" and otherwise ignores.
func (s *Sink) Sync(ctx context.Context, req SyncRequest) (*SyncResult, error) {
if s == nil {
return nil, fmt.Errorf("cloud sink not configured")
}
if s.OwnerUserID == "" {
return nil, fmt.Errorf("owner_user_id not set (config or $IMAGEN_OWNER_USER_ID); refusing to insert NULL into imagen.images")
}
if req.Date == "" || req.Slug == "" {
return nil, fmt.Errorf("date and slug are required for storage_path")
}
ext := req.Ext
if ext == "" {
ext = "png"
}
storagePath := fmt.Sprintf("%s/%s-%d.%s", req.Date, req.Slug, req.Seed, ext)
if err := s.upload(ctx, storagePath, req.PNG, req.MimeType); err != nil {
return nil, fmt.Errorf("storage upload: %w", err)
}
id, err := s.insertRow(ctx, storagePath, req)
if err != nil {
return &SyncResult{StoragePath: storagePath}, fmt.Errorf("db insert: %w", err)
}
return &SyncResult{StoragePath: storagePath, ImageID: id}, nil
}
// upload PUTs the PNG into the imagen-generated bucket. We use
// Content-Type so signed URLs render in the browser without a download
// prompt. POST would error on second-write; PUT (with x-upsert: true) is
// idempotent for re-runs of the same date+slug+seed.
func (s *Sink) upload(ctx context.Context, storagePath string, body []byte, mime string) error {
if mime == "" {
mime = "image/png"
}
endpoint := fmt.Sprintf("%s/storage/v1/object/%s/%s", s.URL, bucketName, pathEscape(storagePath))
return s.doRetry(ctx, func(ctx context.Context) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", mime)
req.Header.Set("x-upsert", "true")
return s.HTTP.Do(req)
})
}
// insertRow POSTs to PostgREST against the imagen schema. Prefer:
// return=representation gives us the inserted id back without a second
// round-trip.
func (s *Sink) insertRow(ctx context.Context, storagePath string, req SyncRequest) (string, error) {
row := map[string]any{
"owner_user_id": s.OwnerUserID,
"prompt": req.Prompt,
"prompt_hash": hashPrompt(req.Prompt),
"backend": req.Backend,
"storage_path": storagePath,
}
if req.Model != "" {
row["model"] = req.Model
}
if req.Seed != 0 {
row["seed"] = req.Seed
}
if req.Steps != 0 {
row["steps"] = req.Steps
}
if req.Width != 0 {
row["width"] = req.Width
}
if req.Height != 0 {
row["height"] = req.Height
}
if req.LatencyMs != 0 {
row["latency_ms"] = req.LatencyMs
}
if req.CostUSDEstimate != nil {
row["cost_usd_estimate"] = *req.CostUSDEstimate
}
if len(req.Sidecar) > 0 {
row["sidecar"] = req.Sidecar
}
body, err := json.Marshal(row)
if err != nil {
return "", fmt.Errorf("marshal row: %w", err)
}
endpoint := s.URL + "/rest/v1/images"
respBody, err := s.doRetryRead(ctx, func(ctx context.Context) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Profile", supabaseSchema)
req.Header.Set("Content-Profile", supabaseSchema)
req.Header.Set("Prefer", "return=representation")
return s.HTTP.Do(req)
})
if err != nil {
return "", err
}
var rows []struct {
ID string `json:"id"`
}
if err := json.Unmarshal(respBody, &rows); err != nil {
return "", fmt.Errorf("parse insert response: %w (body: %s)", err, snip(respBody))
}
if len(rows) == 0 {
return "", fmt.Errorf("insert returned 0 rows (body: %s)", snip(respBody))
}
return rows[0].ID, nil
}
// SignedURL asks the Storage API for a time-limited URL. ttlSeconds is
// the validity window. Returned URL is host-qualified and ready to hand
// to a browser.
func (s *Sink) SignedURL(ctx context.Context, storagePath string, ttlSeconds int) (string, error) {
if s == nil {
return "", fmt.Errorf("cloud sink not configured")
}
if ttlSeconds <= 0 {
ttlSeconds = 3600
}
endpoint := fmt.Sprintf("%s/storage/v1/object/sign/%s/%s", s.URL, bucketName, pathEscape(storagePath))
body, err := json.Marshal(map[string]any{"expiresIn": ttlSeconds})
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
resp, err := s.HTTP.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("sign %d: %s", resp.StatusCode, snip(respBody))
}
var parsed struct {
SignedURL string `json:"signedURL"`
}
if err := json.Unmarshal(respBody, &parsed); err != nil {
return "", fmt.Errorf("parse sign response: %w (body: %s)", err, snip(respBody))
}
if parsed.SignedURL == "" {
return "", fmt.Errorf("empty signedURL in response: %s", snip(respBody))
}
full := parsed.SignedURL
if strings.HasPrefix(full, "/") {
full = s.URL + full
}
return full, nil
}
// doRetry runs op up to MaxRetries+1 times. 5xx and transport errors are
// retried with exponential backoff; 4xx surfaces immediately as a
// permanent error (caller's bug in the row, not a network blip).
func (s *Sink) doRetry(ctx context.Context, op func(context.Context) (*http.Response, error)) error {
_, err := s.doRetryRead(ctx, op)
return err
}
// doRetryRead is the read-the-body variant. Returns the 2xx response
// body bytes; non-2xx is wrapped in an error. Same retry semantics as
// doRetry: 5xx/transport retries with exponential backoff, 4xx is fatal.
func (s *Sink) doRetryRead(ctx context.Context, op func(context.Context) (*http.Response, error)) ([]byte, error) {
backoff := s.InitialBackoff
if backoff == 0 {
backoff = time.Second
}
attempts := s.MaxRetries + 1
if attempts < 1 {
attempts = 1
}
var lastErr error
for i := 0; i < attempts; i++ {
if i > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(backoff):
}
backoff *= 2
}
resp, err := op(ctx)
if err != nil {
lastErr = err
continue
}
body, readErr := io.ReadAll(resp.Body)
resp.Body.Close()
if readErr != nil {
lastErr = fmt.Errorf("read body: %w", readErr)
continue
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return body, nil
}
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, snip(body))
}
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(body))
}
return nil, lastErr
}
func hashPrompt(p string) string {
sum := sha256.Sum256([]byte(p))
return hex.EncodeToString(sum[:])
}
// pathEscape encodes each path segment but keeps the slashes — the
// Storage API treats the part after the bucket name as a virtual file
// path with directory separators.
func pathEscape(p string) string {
parts := strings.Split(p, "/")
for i, seg := range parts {
parts[i] = url.PathEscape(seg)
}
return path.Join(parts...)
}
func snip(b []byte) string {
const max = 500
s := strings.TrimSpace(string(b))
if len(s) > max {
s = s[:max] + "..."
}
return s
}

View File

@@ -0,0 +1,326 @@
package cloud
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"
)
// fakeSupabase is a tiny stand-in for Supabase Storage + PostgREST. It
// records what came in and returns canned responses based on path.
type fakeSupabase struct {
t *testing.T
mux *http.ServeMux
server *httptest.Server
uploadCalls int32
insertCalls int32
uploadBytes []byte
uploadHdr http.Header
insertBody []byte
insertHdr http.Header
}
func newFakeSupabase(t *testing.T, opts ...func(*fakeSupabase)) *fakeSupabase {
f := &fakeSupabase{t: t}
f.mux = http.NewServeMux()
// Storage upload — anything under /storage/v1/object/<bucket>/...
f.mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&f.uploadCalls, 1)
body, _ := io.ReadAll(r.Body)
f.uploadBytes = body
f.uploadHdr = r.Header.Clone()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"Key":"imagen-generated/somepath"}`))
})
// Storage sign URL
f.mux.HandleFunc("/storage/v1/object/sign/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"signedURL":"/storage/v1/object/sign/imagen-generated/some.png?token=abc"}`))
})
// PostgREST insert
f.mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&f.insertCalls, 1)
body, _ := io.ReadAll(r.Body)
f.insertBody = body
f.insertHdr = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`[{"id":"00000000-0000-0000-0000-000000000abc"}]`))
})
for _, opt := range opts {
opt(f)
}
f.server = httptest.NewServer(f.mux)
t.Cleanup(f.server.Close)
return f
}
func newSink(server *httptest.Server) *Sink {
return &Sink{
URL: server.URL,
APIKey: "fake-service-key",
OwnerUserID: "00000000-0000-0000-0000-000000000001",
HTTP: server.Client(),
MaxRetries: 2,
InitialBackoff: time.Millisecond,
}
}
func TestSyncHappyPath(t *testing.T) {
f := newFakeSupabase(t)
s := newSink(f.server)
cost := 0.003
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11",
Slug: "lighthouse",
Seed: 42,
Ext: "png",
PNG: []byte("PNGbytes"),
MimeType: "image/png",
Prompt: "a tiny lighthouse on a stormy cliff",
Backend: "flux-schnell-local",
Model: "flux1-schnell",
Steps: 4,
Width: 1024,
Height: 1024,
LatencyMs: 1500,
CostUSDEstimate: &cost,
Sidecar: map[string]any{
"timestamp": "2026-05-11T01:30:00Z",
"backend": "flux-schnell-local",
},
})
if err != nil {
t.Fatalf("Sync: %v", err)
}
if res.StoragePath != "2026-05-11/lighthouse-42.png" {
t.Errorf("storage_path = %q", res.StoragePath)
}
if res.ImageID != "00000000-0000-0000-0000-000000000abc" {
t.Errorf("image_id = %q", res.ImageID)
}
if got := atomic.LoadInt32(&f.uploadCalls); got != 1 {
t.Errorf("upload calls = %d, want 1", got)
}
if got := atomic.LoadInt32(&f.insertCalls); got != 1 {
t.Errorf("insert calls = %d, want 1", got)
}
if !bytes.Equal(f.uploadBytes, []byte("PNGbytes")) {
t.Errorf("uploaded bytes = %q", f.uploadBytes)
}
// Verify the row payload carries the prompt + computed hash + non-zero
// metadata. Empty fields should be omitted from the JSON body so RLS
// won't see surprise keys.
var row map[string]any
if err := json.Unmarshal(f.insertBody, &row); err != nil {
t.Fatalf("insert body parse: %v\n%s", err, f.insertBody)
}
if row["prompt"] != "a tiny lighthouse on a stormy cliff" {
t.Errorf("row.prompt = %v", row["prompt"])
}
if row["owner_user_id"] != "00000000-0000-0000-0000-000000000001" {
t.Errorf("row.owner_user_id = %v", row["owner_user_id"])
}
if row["storage_path"] != "2026-05-11/lighthouse-42.png" {
t.Errorf("row.storage_path = %v", row["storage_path"])
}
hash, _ := row["prompt_hash"].(string)
if len(hash) != 64 {
t.Errorf("prompt_hash should be 64-char sha256 hex, got %q", hash)
}
if row["backend"] != "flux-schnell-local" {
t.Errorf("row.backend = %v", row["backend"])
}
if row["seed"].(float64) != 42 {
t.Errorf("row.seed = %v", row["seed"])
}
if row["latency_ms"].(float64) != 1500 {
t.Errorf("row.latency_ms = %v", row["latency_ms"])
}
if row["cost_usd_estimate"].(float64) != 0.003 {
t.Errorf("row.cost = %v", row["cost_usd_estimate"])
}
if row["sidecar"] == nil {
t.Errorf("row.sidecar missing")
}
// PostgREST schema headers — hardcoded to "imagen".
if got := f.insertHdr.Get("Accept-Profile"); got != "imagen" {
t.Errorf("Accept-Profile = %q", got)
}
if got := f.insertHdr.Get("Content-Profile"); got != "imagen" {
t.Errorf("Content-Profile = %q", got)
}
if got := f.insertHdr.Get("Authorization"); !strings.HasPrefix(got, "Bearer ") {
t.Errorf("Authorization = %q", got)
}
// Storage upsert should be set so re-runs of the same date+slug+seed
// don't fail with 409.
if got := f.uploadHdr.Get("x-upsert"); got != "true" {
t.Errorf("x-upsert = %q", got)
}
}
func TestSyncRetryOn5xx(t *testing.T) {
var uploadAttempts int32
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&uploadAttempts, 1)
// Two 503s, then OK.
if n < 3 {
http.Error(w, "service unavailable", http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`[{"id":"row-id"}]`))
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err != nil {
t.Fatalf("Sync (with retry): %v", err)
}
if got := atomic.LoadInt32(&uploadAttempts); got != 3 {
t.Errorf("upload attempts = %d, want 3", got)
}
if res.ImageID != "row-id" {
t.Errorf("image_id = %q", res.ImageID)
}
}
func TestSyncNoRetryOn4xx(t *testing.T) {
var uploadAttempts int32
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&uploadAttempts, 1)
http.Error(w, `{"message":"bad request"}`, http.StatusBadRequest)
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
_, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error on 400")
}
if !strings.Contains(err.Error(), "400") {
t.Errorf("error should mention 400 status: %v", err)
}
if got := atomic.LoadInt32(&uploadAttempts); got != 1 {
t.Errorf("upload attempts = %d, want 1 (no retry on 4xx)", got)
}
}
func TestSyncMissingOwnerUserID(t *testing.T) {
srv := httptest.NewServer(http.NewServeMux())
defer srv.Close()
s := &Sink{
URL: srv.URL,
APIKey: "k",
// OwnerUserID intentionally empty.
HTTP: srv.Client(),
InitialBackoff: time.Millisecond,
}
_, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error when owner_user_id unset")
}
if !strings.Contains(err.Error(), "owner_user_id") {
t.Errorf("error should mention owner_user_id: %v", err)
}
}
func TestSyncRequiresDateAndSlug(t *testing.T) {
srv := httptest.NewServer(http.NewServeMux())
defer srv.Close()
s := newSink(srv)
_, err := s.Sync(context.Background(), SyncRequest{
Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error for missing date")
}
}
func TestSignedURL(t *testing.T) {
f := newFakeSupabase(t)
s := newSink(f.server)
got, err := s.SignedURL(context.Background(), "2026-05-11/x.png", 60)
if err != nil {
t.Fatalf("SignedURL: %v", err)
}
want := f.server.URL + "/storage/v1/object/sign/imagen-generated/some.png?token=abc"
if got != want {
t.Errorf("signed URL = %q, want %q", got, want)
}
}
func TestSyncDBFailureSurfacesPathOnError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "schema cache miss", http.StatusInternalServerError)
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 9, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error from DB insert failure")
}
// Storage upload succeeded — caller can still see the upload landed.
if res == nil || res.StoragePath != "2026-05-11/x-9.png" {
t.Errorf("expected storage_path on partial success, got %+v", res)
}
}
func TestPathEscape(t *testing.T) {
cases := map[string]string{
"2026-05-11/lighthouse-42.png": "2026-05-11/lighthouse-42.png",
"2026-05-11/two words.png": "2026-05-11/two%20words.png",
"with#hash/and?query.png": "with%23hash/and%3Fquery.png",
}
for in, want := range cases {
got := pathEscape(in)
if got != want {
t.Errorf("pathEscape(%q) = %q, want %q", in, got, want)
}
// Sanity: every part should round-trip via url.PathUnescape.
for _, seg := range strings.Split(got, "/") {
if _, err := url.PathUnescape(seg); err != nil {
t.Errorf("segment %q failed unescape: %v", seg, err)
}
}
}
}

View File

@@ -15,15 +15,29 @@ import (
// Config is the top-level shape of imagen.yaml.
type Config struct {
DefaultBackend string `yaml:"default_backend"`
// OwnerUserID is m's auth.users.id on msupabase. The cloud-sync writer
// uses it to populate imagen.images.owner_user_id (NOT NULL, owns RLS).
// Empty disables DB inserts even when cloud_sync is on.
OwnerUserID string `yaml:"owner_user_id"`
Output OutputConfig `yaml:"output"`
Backends map[string]BackendSpec `yaml:"backends"`
}
// OutputConfig controls where generated images and metadata sidecars land.
// OutputConfig controls where generated images and metadata sidecars land,
// and whether `imagen generate` opens a tmux preview window.
type OutputConfig struct {
Directory string `yaml:"directory"`
Naming string `yaml:"naming"`
WriteMetadataJSON bool `yaml:"write_metadata_json"`
// Preview is the tri-state preview mode: "auto" (default), "on", "off".
// Empty / unset is treated as "auto". $IMAGEN_PREVIEW and the
// --preview/--no-preview flags override this in turn.
Preview string `yaml:"preview"`
// CloudSync controls whether successful generations also upload to
// Supabase Storage and insert into imagen.images. Tri-state mirroring
// Preview: "auto" (default — on when SUPABASE_URL + SUPABASE_SERVICE_KEY
// are set), "on" (errors if env unset), "off". --no-cloud overrides.
CloudSync string `yaml:"cloud_sync"`
}
// BackendSpec is one entry under `backends:`. Type identifies the adapter;
@@ -78,6 +92,16 @@ func (c *Config) Validate() error {
return fmt.Errorf("default_backend %q is not defined under backends:", c.DefaultBackend)
}
}
switch c.Output.Preview {
case "", "auto", "on", "off":
default:
return fmt.Errorf("output.preview = %q (must be auto|on|off)", c.Output.Preview)
}
switch c.Output.CloudSync {
case "", "auto", "on", "off":
default:
return fmt.Errorf("output.cloud_sync = %q (must be auto|on|off)", c.Output.CloudSync)
}
for name, spec := range c.Backends {
if name == "" {
return errors.New("empty backend name")
@@ -95,28 +119,57 @@ const Sample = `# imagen.yaml — config for the imagen CLI.
# implementing the Backend interface, registering its type name, and listing
# an instance here.
default_backend: mock
default_backend: flux-schnell-local
# Owner UUID for the cloud-sync row in imagen.images. Look up via:
# SELECT id FROM auth.users WHERE email = '<your-supabase-email>';
# Empty disables imagen.images inserts even when cloud_sync is on.
owner_user_id: ""
output:
directory: ~/Pictures/imagen
naming: "{date}-{slug}-{seed}.png"
write_metadata_json: true
# Open a tmux window with tmux-img after a successful generation.
# auto (default): preview iff stdout is a TTY and $TMUX is set.
# on: always preview (errors outside a tmux session).
# off: never preview (use this for batch / CI callers).
preview: auto
# Sync the PNG to Supabase Storage (bucket: imagen-generated) and insert
# a row into imagen.images. Reads SUPABASE_URL + SUPABASE_SERVICE_KEY
# from env (same as mai.imagen_usage cost-tracking).
# auto (default): on iff env is configured AND owner_user_id is set.
# on: always upload (errors if env or owner_user_id is missing).
# off: never upload. --no-cloud also forces off per-call.
cloud_sync: auto
backends:
mock:
type: mock
flux-schnell-local:
type: comfyui
base_url: http://mrock:8188
# Filename of the unet checkpoint inside the ComfyUI server's
# models/unet/ directory. See docs/setup-comfyui-mrock.md.
model: flux1-schnell.safetensors
default_steps: 4
default_sampler: euler
default_scheduler: simple
mock:
type: mock
flux-schnell-replicate:
type: replicate
api_token_env: REPLICATE_API_TOKEN
model: black-forest-labs/flux-schnell
default_steps: 4
default_aspect_ratio: "1:1"
flux-dev-replicate:
type: replicate
api_token_env: REPLICATE_API_TOKEN
model: black-forest-labs/flux-dev
default_steps: 28
default_aspect_ratio: "1:1"
dalle3:
type: openai

View File

@@ -16,7 +16,7 @@ func TestLoadAndValidate(t *testing.T) {
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.DefaultBackend != "mock" {
if cfg.DefaultBackend != "flux-schnell-local" {
t.Errorf("default = %q", cfg.DefaultBackend)
}
mock, ok := cfg.Backends["mock"]
@@ -30,9 +30,15 @@ func TestLoadAndValidate(t *testing.T) {
if !ok {
t.Fatalf("flux backend missing")
}
if flux.Type != "comfyui" {
t.Errorf("flux type = %q", flux.Type)
}
if flux.Raw["base_url"] != "http://mrock:8188" {
t.Errorf("flux base_url = %v", flux.Raw["base_url"])
}
if flux.Raw["model"] != "flux1-schnell.safetensors" {
t.Errorf("flux model = %v", flux.Raw["model"])
}
}
func TestValidateRejectsUnknownDefault(t *testing.T) {
@@ -54,6 +60,67 @@ func TestValidateRejectsMissingType(t *testing.T) {
}
}
func TestValidatePreviewMode(t *testing.T) {
for _, mode := range []string{"", "auto", "on", "off"} {
c := &Config{Output: OutputConfig{Preview: mode}}
if err := c.Validate(); err != nil {
t.Errorf("preview=%q: unexpected error %v", mode, err)
}
}
bad := &Config{Output: OutputConfig{Preview: "yes"}}
if err := bad.Validate(); err == nil {
t.Errorf("expected error for invalid preview value")
}
}
func TestValidateCloudSyncMode(t *testing.T) {
for _, mode := range []string{"", "auto", "on", "off"} {
c := &Config{Output: OutputConfig{CloudSync: mode}}
if err := c.Validate(); err != nil {
t.Errorf("cloud_sync=%q: unexpected error %v", mode, err)
}
}
bad := &Config{Output: OutputConfig{CloudSync: "yes"}}
if err := bad.Validate(); err == nil {
t.Errorf("expected error for invalid cloud_sync value")
}
}
func TestSampleParsesCloudSyncAuto(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "imagen.yaml")
if err := os.WriteFile(path, []byte(Sample), 0o644); err != nil {
t.Fatalf("write sample: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.Output.CloudSync != "auto" {
t.Errorf("Output.CloudSync = %q, want auto", cfg.Output.CloudSync)
}
// owner_user_id is intentionally empty in the sample — operators fill
// it in after looking up their auth.users.id.
if cfg.OwnerUserID != "" {
t.Errorf("Sample OwnerUserID should be empty, got %q", cfg.OwnerUserID)
}
}
func TestSampleParsesPreviewAuto(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "imagen.yaml")
if err := os.WriteFile(path, []byte(Sample), 0o644); err != nil {
t.Fatalf("write sample: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.Output.Preview != "auto" {
t.Errorf("Output.Preview = %q, want auto", cfg.Output.Preview)
}
}
func TestExpandPath(t *testing.T) {
home, _ := os.UserHomeDir()
cases := map[string]string{

View File

@@ -35,6 +35,13 @@ type Inputs struct {
type Outputs struct {
ImagePath string
SidecarPath string
// Date is the YYYY-MM-DD the writer used for the filename. Cloud sync
// reuses this so storage_path matches the local filename's date.
Date string
// Slug is the filename-safe prompt fragment the writer used.
Slug string
// Seed is the seed value baked into the filename.
Seed int64
}
// Write streams img to disk and, if enabled, writes a sidecar. The image
@@ -50,10 +57,12 @@ func (w *Writer) Write(img io.Reader, in Inputs) (*Outputs, error) {
if tmpl == "" {
tmpl = "{date}-{slug}-{seed}.{ext}"
}
date := now.Format("2006-01-02")
slug := Slug(in.Prompt)
name := renderTemplate(tmpl, map[string]string{
"date": now.Format("2006-01-02"),
"date": date,
"time": now.Format("150405"),
"slug": Slug(in.Prompt),
"slug": slug,
"seed": fmt.Sprintf("%d", in.Seed),
"backend": in.Backend,
"ext": strings.TrimPrefix(ext, "."),
@@ -80,7 +89,7 @@ func (w *Writer) Write(img io.Reader, in Inputs) (*Outputs, error) {
return nil, fmt.Errorf("close %s: %w", imagePath, err)
}
out := &Outputs{ImagePath: imagePath}
out := &Outputs{ImagePath: imagePath, Date: date, Slug: slug, Seed: in.Seed}
if w.WriteSidecar {
sidecar := imagePath + ".json"
@@ -122,7 +131,7 @@ func (w *Writer) WriteToPath(img io.Reader, path string, in Inputs) (*Outputs, e
if err := f.Close(); err != nil {
return nil, fmt.Errorf("close %s: %w", path, err)
}
out := &Outputs{ImagePath: path}
out := &Outputs{ImagePath: path, Date: now.Format("2006-01-02"), Slug: Slug(in.Prompt), Seed: in.Seed}
if w.WriteSidecar {
sidecar := path + ".json"
body := map[string]any{

119
internal/preview/tmux.go Normal file
View File

@@ -0,0 +1,119 @@
// Package preview opens a tmux window showing a generated image via tmux-img.
// Mode resolution and the actual spawn are kept separate so the CLI can
// decide-then-act and tests can drive each half independently.
package preview
import (
"errors"
"fmt"
"os/exec"
"strings"
)
// Mode is the tri-state preview setting: auto (default), on (force), off.
type Mode string
const (
ModeAuto Mode = "auto"
ModeOn Mode = "on"
ModeOff Mode = "off"
)
// ParseMode normalises a string into a Mode. Empty parses to ModeAuto so
// callers can pass through unset config / env values.
func ParseMode(s string) (Mode, error) {
switch strings.ToLower(strings.TrimSpace(s)) {
case "", "auto":
return ModeAuto, nil
case "on":
return ModeOn, nil
case "off":
return ModeOff, nil
}
return "", fmt.Errorf("invalid preview mode %q (auto|on|off)", s)
}
// Decision is the answer to "should we preview, and why".
type Decision struct {
ShouldPreview bool
Reason string
}
// Resolve maps (mode, runtime context) to a Decision.
//
// - off -> never preview
// - on -> preview, but error if not in tmux (forced on outside tmux)
// - auto -> preview iff inTmux && stdoutTTY
func Resolve(mode Mode, inTmux, stdoutTTY bool) (Decision, error) {
switch mode {
case ModeOff:
return Decision{ShouldPreview: false, Reason: "preview=off"}, nil
case ModeOn:
if !inTmux {
return Decision{}, ErrNoTmuxForced
}
return Decision{ShouldPreview: true, Reason: "preview=on"}, nil
case ModeAuto, "":
if !inTmux {
return Decision{ShouldPreview: false, Reason: "auto: $TMUX unset"}, nil
}
if !stdoutTTY {
return Decision{ShouldPreview: false, Reason: "auto: stdout not a tty"}, nil
}
return Decision{ShouldPreview: true, Reason: "auto"}, nil
}
return Decision{}, fmt.Errorf("invalid preview mode %q", mode)
}
// Errors returned by Spawn and Resolve. Each names the missing piece and,
// where relevant, where to install it.
var (
ErrTmuxMissing = errors.New("tmux: binary not found on $PATH (required for image preview)")
ErrTmuxImgMissing = errors.New("tmux-img: binary not found on $PATH (install at ~/.local/bin/tmux-img)")
ErrNoTmuxForced = errors.New("--preview requires $TMUX (are you in a tmux session?)")
)
// Spawner spawns the tmux preview window. The exec.LookPath / cmd.Run hooks
// exist so tests can inject fakes without touching $PATH.
type Spawner struct {
LookPath func(string) (string, error)
Run func(*exec.Cmd) error
}
// Spawn opens a new tmux window named img:<slug> running tmux-img --hold
// <imagePath>. -d keeps focus in the current pane. Caller is expected to
// have already verified that we are inside a tmux session.
func (s *Spawner) Spawn(imagePath, slug string) error {
look := s.LookPath
if look == nil {
look = exec.LookPath
}
run := s.Run
if run == nil {
run = func(c *exec.Cmd) error { return c.Run() }
}
tmuxBin, err := look("tmux")
if err != nil {
return ErrTmuxMissing
}
tmuxImgBin, err := look("tmux-img")
if err != nil {
return ErrTmuxImgMissing
}
name := "img:" + slug
shellCmd := fmt.Sprintf("%s --hold %s",
shellQuote(tmuxImgBin), shellQuote(imagePath))
cmd := exec.Command(tmuxBin, "new-window", "-d", "-n", name, shellCmd)
if err := run(cmd); err != nil {
return fmt.Errorf("tmux new-window: %w", err)
}
return nil
}
// shellQuote single-quotes s for /bin/sh — tmux passes the trailing arg of
// new-window through a shell.
func shellQuote(s string) string {
return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}

View File

@@ -0,0 +1,170 @@
package preview
import (
"errors"
"os/exec"
"strings"
"testing"
)
func TestParseMode(t *testing.T) {
cases := map[string]Mode{
"": ModeAuto,
"auto": ModeAuto,
"AUTO": ModeAuto,
"on": ModeOn,
" on ": ModeOn,
"off": ModeOff,
}
for in, want := range cases {
got, err := ParseMode(in)
if err != nil {
t.Errorf("ParseMode(%q) err = %v", in, err)
continue
}
if got != want {
t.Errorf("ParseMode(%q) = %q, want %q", in, got, want)
}
}
if _, err := ParseMode("nope"); err == nil {
t.Errorf("ParseMode(nope) should have errored")
}
}
func TestResolve(t *testing.T) {
type tc struct {
mode Mode
inTmux bool
stdoutTTY bool
want bool
wantErr error
}
cases := map[string]tc{
"off-anywhere": {ModeOff, false, false, false, nil},
"off-in-tmux-tty": {ModeOff, true, true, false, nil},
"on-in-tmux": {ModeOn, true, false, true, nil},
"on-outside-tmux-errs": {ModeOn, false, true, false, ErrNoTmuxForced},
"auto-no-tmux": {ModeAuto, false, true, false, nil},
"auto-tmux-no-tty": {ModeAuto, true, false, false, nil},
"auto-tmux-and-tty": {ModeAuto, true, true, true, nil},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
d, err := Resolve(c.mode, c.inTmux, c.stdoutTTY)
if c.wantErr != nil {
if !errors.Is(err, c.wantErr) {
t.Fatalf("err = %v, want %v", err, c.wantErr)
}
return
}
if err != nil {
t.Fatalf("err = %v", err)
}
if d.ShouldPreview != c.want {
t.Errorf("ShouldPreview = %v, want %v (reason: %s)", d.ShouldPreview, c.want, d.Reason)
}
})
}
}
func TestSpawn_BuildsCorrectCommand(t *testing.T) {
var captured *exec.Cmd
s := &Spawner{
LookPath: func(name string) (string, error) {
switch name {
case "tmux":
return "/usr/bin/tmux", nil
case "tmux-img":
return "/home/m/.local/bin/tmux-img", nil
}
return "", exec.ErrNotFound
},
Run: func(c *exec.Cmd) error {
captured = c
return nil
},
}
if err := s.Spawn("/tmp/imagen/cat.png", "cat-in-a-fishbowl"); err != nil {
t.Fatalf("Spawn: %v", err)
}
if captured == nil {
t.Fatal("Run was not called")
}
if captured.Path != "/usr/bin/tmux" {
t.Errorf("Path = %q, want /usr/bin/tmux", captured.Path)
}
args := captured.Args
if len(args) < 6 {
t.Fatalf("args = %v (need at least 6)", args)
}
// tmux new-window -d -n img:<slug> '<shell-cmd>'
if args[1] != "new-window" {
t.Errorf("args[1] = %q, want new-window", args[1])
}
if args[2] != "-d" {
t.Errorf("args[2] = %q, want -d", args[2])
}
if args[3] != "-n" {
t.Errorf("args[3] = %q, want -n", args[3])
}
if args[4] != "img:cat-in-a-fishbowl" {
t.Errorf("args[4] = %q, want img:cat-in-a-fishbowl", args[4])
}
shellCmd := args[5]
if !strings.Contains(shellCmd, "tmux-img") || !strings.Contains(shellCmd, "--hold") || !strings.Contains(shellCmd, "/tmp/imagen/cat.png") {
t.Errorf("shell cmd %q missing expected pieces", shellCmd)
}
}
func TestSpawn_PathWithSpacesAndQuotes(t *testing.T) {
var captured *exec.Cmd
s := &Spawner{
LookPath: func(name string) (string, error) {
if name == "tmux" {
return "/usr/bin/tmux", nil
}
if name == "tmux-img" {
return "/usr/local/bin/tmux-img", nil
}
return "", exec.ErrNotFound
},
Run: func(c *exec.Cmd) error { captured = c; return nil },
}
weird := "/tmp/imagen/o'malley's cat.png"
if err := s.Spawn(weird, "slug"); err != nil {
t.Fatalf("Spawn: %v", err)
}
shellCmd := captured.Args[5]
// Single-quoted with the embedded apostrophe escaped via the
// '\'' shell idiom — confirm we did not just splice the raw path.
if strings.Contains(shellCmd, "o'malley's") {
t.Errorf("shell cmd %q contains unescaped apostrophes", shellCmd)
}
}
func TestSpawn_MissingTmux(t *testing.T) {
s := &Spawner{
LookPath: func(string) (string, error) { return "", exec.ErrNotFound },
Run: func(*exec.Cmd) error { return nil },
}
err := s.Spawn("/x.png", "s")
if !errors.Is(err, ErrTmuxMissing) {
t.Errorf("err = %v, want ErrTmuxMissing", err)
}
}
func TestSpawn_MissingTmuxImg(t *testing.T) {
s := &Spawner{
LookPath: func(name string) (string, error) {
if name == "tmux" {
return "/usr/bin/tmux", nil
}
return "", exec.ErrNotFound
},
Run: func(*exec.Cmd) error { return nil },
}
err := s.Spawn("/x.png", "s")
if !errors.Is(err, ErrTmuxImgMissing) {
t.Errorf("err = %v, want ErrTmuxImgMissing", err)
}
}

160
internal/usage/usage.go Normal file
View File

@@ -0,0 +1,160 @@
// Package usage records per-call cost-tracking rows for the imagen CLI
// to mai.imagen_usage on Supabase. The writer is best-effort by design —
// the calling adapter logs failures and proceeds, because the image
// itself has already landed on disk by the time we record.
package usage
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
"mgit.msbls.de/m/ImaGen/internal/backend"
)
// Default REST schema is the mai schema where mai.imagen_usage lives.
const supabaseSchema = "mai"
// SupabaseSink writes rows via PostgREST. It uses Accept-Profile/
// Content-Profile headers to target the mai schema instead of public.
type SupabaseSink struct {
URL string // SUPABASE_URL — e.g. https://msup.msbls.de
APIKey string // SUPABASE_SERVICE_KEY
HTTP *http.Client
}
// NewSupabaseSinkFromEnv reads SUPABASE_URL and SUPABASE_SERVICE_KEY
// (falling back to MAI_SUPABASE_KEY) and returns a sink ready to use.
// Returns nil + ok=false if the env vars are not configured — the CLI
// uses that to skip cost-tracking gracefully.
func NewSupabaseSinkFromEnv() (*SupabaseSink, bool) {
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
if u == "" {
return nil, false
}
key := os.Getenv("SUPABASE_SERVICE_KEY")
if key == "" {
key = os.Getenv("MAI_SUPABASE_KEY")
}
if key == "" {
return nil, false
}
return &SupabaseSink{
URL: u,
APIKey: key,
HTTP: &http.Client{Timeout: 10 * time.Second},
}, true
}
type supabaseRow struct {
Backend string `json:"backend"`
Model string `json:"model"`
Seed *int64 `json:"seed,omitempty"`
PromptHash string `json:"prompt_hash"`
LatencyMs int `json:"latency_ms"`
CostUSDEstimate *float64 `json:"cost_usd_estimate,omitempty"`
Caller string `json:"caller,omitempty"`
}
// Record inserts one row into mai.imagen_usage.
func (s *SupabaseSink) Record(ctx context.Context, row backend.UsageRow) error {
body, err := json.Marshal(supabaseRow{
Backend: row.Backend,
Model: row.Model,
Seed: row.Seed,
PromptHash: row.PromptHash,
LatencyMs: row.LatencyMs,
CostUSDEstimate: row.CostUSDEstimate,
Caller: row.Caller,
})
if err != nil {
return fmt.Errorf("usage: marshal: %w", err)
}
endpoint := s.URL + "/rest/v1/imagen_usage"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Profile", supabaseSchema)
req.Header.Set("Content-Profile", supabaseSchema)
req.Header.Set("Prefer", "return=minimal")
resp, err := s.HTTP.Do(req)
if err != nil {
return fmt.Errorf("usage: POST: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("usage: POST %d: %s", resp.StatusCode, snip(respBody))
}
return nil
}
// Row is the read-side row shape (only the fields the CLI needs).
type Row struct {
CreatedAt time.Time `json:"created_at"`
Backend string `json:"backend"`
Model string `json:"model"`
Seed *int64 `json:"seed"`
PromptHash string `json:"prompt_hash"`
LatencyMs *int `json:"latency_ms"`
CostUSDEstimate *float64 `json:"cost_usd_estimate"`
Caller *string `json:"caller"`
}
// Query returns rows from mai.imagen_usage filtered by created_at >= since.
// Pass zero time to fetch the full table (capped server-side by PostgREST
// — we set a hard 5000-row limit here too).
func (s *SupabaseSink) Query(ctx context.Context, since time.Time) ([]Row, error) {
q := url.Values{}
q.Set("select", "created_at,backend,model,seed,prompt_hash,latency_ms,cost_usd_estimate,caller")
q.Set("order", "created_at.desc")
q.Set("limit", "5000")
if !since.IsZero() {
q.Set("created_at", "gte."+since.UTC().Format(time.RFC3339))
}
endpoint := s.URL + "/rest/v1/imagen_usage?" + q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Accept-Profile", supabaseSchema)
resp, err := s.HTTP.Do(req)
if err != nil {
return nil, fmt.Errorf("usage: GET: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("usage: GET %d: %s", resp.StatusCode, snip(body))
}
var rows []Row
if err := json.Unmarshal(body, &rows); err != nil {
return nil, fmt.Errorf("usage: parse rows: %w (body: %s)", err, snip(body))
}
return rows, nil
}
func snip(b []byte) string {
const max = 500
s := strings.TrimSpace(string(b))
if len(s) > max {
s = s[:max] + "..."
}
return s
}

213
internal/worker/worker.go Normal file
View File

@@ -0,0 +1,213 @@
// Package worker consumes the imagen.jobs queue. It claims pending rows via
// an UPDATE-returning lock (single source of truth, no double-claim window),
// runs the supplied generation pipeline, then writes status + image_id back.
//
// The package is DB-agnostic: it talks to two small interfaces (Queue +
// Pipeline) so unit tests can drive the claim/transition logic with no real
// Postgres connection. cmd/imagen wires the pgx implementation.
package worker
import (
"context"
"errors"
"fmt"
"sync"
"time"
)
// Job is the slice of an imagen.jobs row the worker needs to drive a
// generation. Null columns from the DB are represented as zero values; the
// pipeline treats zero values as "use backend default" (same convention as
// backend.Request).
type Job struct {
ID string
OwnerUserID string
Prompt string
Backend string
Model string
Width int
Height int
Steps int
Seed int64
Style string
}
// Outcome is what the pipeline reports back per job. ImageID is the
// imagen.images.id the cloud-sync produced. Empty ImageID with nil Err means
// the cloud-sync was skipped (config off) — we treat that as a failure for
// the worker since flexsiebels needs the image_id to render the result.
type Outcome struct {
ImageID string
Err error
}
// Queue is the persistence layer for the imagen.jobs table. Implementations
// must be safe for serialised single-worker use (concurrent claim across
// multiple worker processes is out of scope for v1 — the FOR UPDATE SKIP
// LOCKED clause in the pgx claim query covers it cheaply anyway).
type Queue interface {
// ClaimNextPending atomically marks the oldest pending row 'running' and
// returns it. Returns (nil, nil) when the queue is empty.
ClaimNextPending(ctx context.Context) (*Job, error)
// MarkDone records success: status='done', image_id, completed_at=now().
MarkDone(ctx context.Context, jobID, imageID string) error
// MarkFailed records failure: status='failed', error=msg, completed_at=now().
MarkFailed(ctx context.Context, jobID, errMsg string) error
// WaitForJob blocks until either a NOTIFY arrives on imagen_jobs, the
// timeout expires, or ctx is cancelled. Returns nil on notification or
// timeout; returns ctx.Err() on cancellation. Transient connection errors
// are returned so the caller can decide to reconnect.
WaitForJob(ctx context.Context, timeout time.Duration) error
// ResetStaleRunning marks any rows stuck in 'running' (e.g. left over
// from a crash before this process started) back to 'pending'. Called
// once at worker startup so the cold-start safety poll can pick them up.
ResetStaleRunning(ctx context.Context) error
}
// Pipeline runs one generation and reports back the imagen.images.id (or an
// error). The implementation owns backend dispatch, prompt enrichment, disk
// write, and cloud-sync; the worker only orchestrates queue state.
type Pipeline interface {
Run(ctx context.Context, job Job) Outcome
}
// Config is the runtime knob set for the worker loop.
type Config struct {
// PollInterval is the safety-poll cadence between LISTEN wakeups. Picking
// this too low wastes DB roundtrips; too high lets a dropped NOTIFY
// stall the queue. 5s is the spec'd default.
PollInterval time.Duration
// JobTimeout caps any single Pipeline.Run. A backend hang shouldn't
// freeze the queue forever.
JobTimeout time.Duration
// Logger receives one-line status events. nil means silent.
Logger func(format string, args ...any)
}
// Worker is the orchestration loop. It is not reusable across Run calls.
type Worker struct {
q Queue
p Pipeline
cfg Config
// processingMu guards the in-flight job so SIGTERM-triggered shutdown
// waits for it to complete before returning.
processingMu sync.Mutex
}
// New constructs a Worker.
func New(q Queue, p Pipeline, cfg Config) *Worker {
if cfg.PollInterval <= 0 {
cfg.PollInterval = 5 * time.Second
}
if cfg.JobTimeout <= 0 {
cfg.JobTimeout = 5 * time.Minute
}
return &Worker{q: q, p: p, cfg: cfg}
}
// Run drives the consume loop until ctx is cancelled or a fatal queue error
// (e.g. unrecoverable DB drop) is returned. A LISTEN wait can fail with a
// transient transport error; the worker logs and continues so a temporary
// network blip doesn't take it down.
func (w *Worker) Run(ctx context.Context) error {
if err := w.q.ResetStaleRunning(ctx); err != nil {
w.log("worker: reset stale running rows: %v", err)
// Don't return — a stale row will eventually be visible to the poll
// path once flexsiebels gives up and resubmits, and we'd rather keep
// serving fresh jobs than crash here.
}
for {
if err := ctx.Err(); err != nil {
return nil
}
// Drain the queue: claim and process until empty.
if err := w.drain(ctx); err != nil && !errors.Is(err, context.Canceled) {
w.log("worker: drain: %v", err)
}
if err := ctx.Err(); err != nil {
return nil
}
// Wait for the next wake. WaitForJob covers both LISTEN and the
// timeout-based poll fallback; either returns nil and we loop.
if err := w.q.WaitForJob(ctx, w.cfg.PollInterval); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
w.log("worker: wait: %v (continuing)", err)
// Pace the retries so a totally-broken DB doesn't busy-spin.
select {
case <-ctx.Done():
return nil
case <-time.After(w.cfg.PollInterval):
}
}
}
}
// drain claims and processes every currently-pending job. The job-scoped
// context is derived from context.Background() so that a SIGTERM mid-job
// still lets the pipeline finish — that's the "no half-state on shutdown"
// guarantee the issue calls for.
func (w *Worker) drain(ctx context.Context) error {
for {
if err := ctx.Err(); err != nil {
return err
}
job, err := w.q.ClaimNextPending(ctx)
if err != nil {
return fmt.Errorf("claim: %w", err)
}
if job == nil {
return nil
}
w.processOne(*job)
}
}
// processOne runs the pipeline for one already-claimed job and writes the
// outcome back to the queue. The job context is independent of the outer
// ctx so an in-flight job can finish even after SIGTERM.
func (w *Worker) processOne(job Job) {
w.processingMu.Lock()
defer w.processingMu.Unlock()
w.log("worker: processing job %s backend=%s", job.ID, job.Backend)
jobCtx, cancel := context.WithTimeout(context.Background(), w.cfg.JobTimeout)
defer cancel()
out := w.p.Run(jobCtx, job)
// Status-update uses Background ctx with a short timeout — we must
// always be able to record the outcome, otherwise the row sits in
// 'running' forever.
updCtx, updCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer updCancel()
if out.Err != nil {
w.log("worker: job %s failed: %v", job.ID, out.Err)
if err := w.q.MarkFailed(updCtx, job.ID, out.Err.Error()); err != nil {
w.log("worker: mark failed for %s: %v", job.ID, err)
}
return
}
if out.ImageID == "" {
// Pipeline reported success but no imagen.images row — treat as
// failure because flexsiebels has nothing to link.
const msg = "pipeline did not return an imagen.images id (cloud sync misconfigured?)"
w.log("worker: job %s: %s", job.ID, msg)
if err := w.q.MarkFailed(updCtx, job.ID, msg); err != nil {
w.log("worker: mark failed for %s: %v", job.ID, err)
}
return
}
if err := w.q.MarkDone(updCtx, job.ID, out.ImageID); err != nil {
w.log("worker: mark done for %s: %v", job.ID, err)
return
}
w.log("worker: job %s done image_id=%s", job.ID, out.ImageID)
}
func (w *Worker) log(format string, args ...any) {
if w.cfg.Logger != nil {
w.cfg.Logger(format, args...)
}
}

View File

@@ -0,0 +1,332 @@
package worker
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
)
// fakeQueue is a hand-rolled in-memory queue that mirrors the contract of a
// real Postgres-backed implementation: ClaimNextPending atomically takes one
// pending row and flips its status to "running", MarkDone/MarkFailed are
// idempotent terminal transitions, WaitForJob blocks until notified or until
// the timeout elapses.
type fakeQueue struct {
mu sync.Mutex
pending []Job
state map[string]string // jobID -> status
last map[string]string // jobID -> error msg or image_id
notify chan struct{}
claimErr error
doneErr error
failErr error
resetErr error
claimed int
done int
failed int
resets int
}
func newFakeQueue(jobs ...Job) *fakeQueue {
q := &fakeQueue{
state: make(map[string]string),
last: make(map[string]string),
notify: make(chan struct{}, 16),
}
for _, j := range jobs {
q.pending = append(q.pending, j)
q.state[j.ID] = "pending"
}
return q
}
func (q *fakeQueue) ClaimNextPending(ctx context.Context) (*Job, error) {
q.mu.Lock()
defer q.mu.Unlock()
if q.claimErr != nil {
return nil, q.claimErr
}
if len(q.pending) == 0 {
return nil, nil
}
j := q.pending[0]
q.pending = q.pending[1:]
q.state[j.ID] = "running"
q.claimed++
return &j, nil
}
func (q *fakeQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.doneErr != nil {
return q.doneErr
}
q.state[jobID] = "done"
q.last[jobID] = imageID
q.done++
return nil
}
func (q *fakeQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.failErr != nil {
return q.failErr
}
q.state[jobID] = "failed"
q.last[jobID] = msg
q.failed++
return nil
}
func (q *fakeQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-q.notify:
return nil
case <-time.After(timeout):
return nil
}
}
func (q *fakeQueue) ResetStaleRunning(ctx context.Context) error {
q.mu.Lock()
defer q.mu.Unlock()
q.resets++
return q.resetErr
}
// pingNotify simulates an INSERT-trigger NOTIFY by waking WaitForJob.
func (q *fakeQueue) pingNotify() {
select {
case q.notify <- struct{}{}:
default:
}
}
// stub pipeline.
type fakePipeline struct {
mu sync.Mutex
results map[string]Outcome // by job.ID; "" key = default outcome
calls int
delay time.Duration
lastJob Job
}
func (p *fakePipeline) Run(ctx context.Context, job Job) Outcome {
p.mu.Lock()
p.calls++
p.lastJob = job
delay := p.delay
out, ok := p.results[job.ID]
if !ok {
out = p.results[""]
}
p.mu.Unlock()
if delay > 0 {
select {
case <-ctx.Done():
return Outcome{Err: ctx.Err()}
case <-time.After(delay):
}
}
return out
}
func TestWorker_DonePath(t *testing.T) {
q := newFakeQueue(
Job{ID: "j1", Prompt: "a", Backend: "mock"},
)
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img-1"}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(80 * time.Millisecond)
cancel()
}()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run: %v", err)
}
if got := q.state["j1"]; got != "done" {
t.Fatalf("state=%q want done", got)
}
if got := q.last["j1"]; got != "img-1" {
t.Fatalf("image_id=%q want img-1", got)
}
if q.done != 1 || q.failed != 0 {
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
}
if p.calls != 1 {
t.Fatalf("pipeline calls=%d want 1", p.calls)
}
if q.resets != 1 {
t.Fatalf("ResetStaleRunning calls=%d want 1", q.resets)
}
}
func TestWorker_FailedPath_RecordsErrorText(t *testing.T) {
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
p := &fakePipeline{results: map[string]Outcome{"j1": {Err: errors.New("backend unreachable")}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
if got := q.state["j1"]; got != "failed" {
t.Fatalf("state=%q want failed", got)
}
if got := q.last["j1"]; got != "backend unreachable" {
t.Fatalf("error=%q want %q", got, "backend unreachable")
}
if q.done != 0 || q.failed != 1 {
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
}
}
func TestWorker_MissingImageID_TreatedAsFailure(t *testing.T) {
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
// Outcome has neither Err nor ImageID — pipeline silently swallowed
// cloud-sync. flexsiebels needs the image_id; without it, fail the job.
p := &fakePipeline{results: map[string]Outcome{"j1": {}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
if got := q.state["j1"]; got != "failed" {
t.Fatalf("state=%q want failed", got)
}
if q.last["j1"] == "" {
t.Fatalf("expected non-empty error explanation for missing image_id")
}
}
func TestWorker_DrainsMultipleBeforeWaiting(t *testing.T) {
q := newFakeQueue(
Job{ID: "j1", Backend: "mock"},
Job{ID: "j2", Backend: "mock"},
Job{ID: "j3", Backend: "mock"},
)
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 200 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(60 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
for _, id := range []string{"j1", "j2", "j3"} {
if got := q.state[id]; got != "done" {
t.Fatalf("%s state=%q want done", id, got)
}
}
if q.done != 3 {
t.Fatalf("done=%d want 3", q.done)
}
}
func TestWorker_NotifyWakesEarlierThanPoll(t *testing.T) {
q := newFakeQueue()
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
// Set poll interval high so a working LISTEN is required to see the job
// promptly. Without NOTIFY plumbing this test would time out the worker
// before drain ever runs.
w := New(q, p, Config{PollInterval: 5 * time.Second, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
go func() {
_ = w.Run(ctx)
close(done)
}()
// Append a job and ping the wake channel.
q.mu.Lock()
q.pending = append(q.pending, Job{ID: "late", Backend: "mock"})
q.state["late"] = "pending"
q.mu.Unlock()
q.pingNotify()
// Give the worker a beat to claim + process.
deadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(deadline) {
q.mu.Lock()
s := q.state["late"]
q.mu.Unlock()
if s == "done" {
cancel()
<-done
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("worker did not pick up the late job within the 500ms window — NOTIFY wake-up path is broken")
}
func TestWorker_HonoursContextCancellation(t *testing.T) {
q := newFakeQueue()
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
start := time.Now()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run: %v", err)
}
if dur := time.Since(start); dur > 200*time.Millisecond {
t.Fatalf("worker did not exit promptly on ctx cancel: %v", dur)
}
}
func TestWorker_InflightJobFinishesAfterShutdown(t *testing.T) {
q := newFakeQueue(Job{ID: "long", Backend: "mock"})
p := &fakePipeline{
results: map[string]Outcome{"long": {ImageID: "img-long"}},
delay: 120 * time.Millisecond,
}
// Short JobTimeout would also kill the in-flight job; give it enough
// budget so the test exercises the shutdown-during-job path.
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: 5 * time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() {
// Let the job start, then cancel mid-flight.
time.Sleep(30 * time.Millisecond)
cancel()
}()
_ = w.Run(ctx)
if got := q.state["long"]; got != "done" {
t.Fatalf("state=%q want done (in-flight job should finish even on shutdown)", got)
}
}
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
// First claim returns an error; the loop should log and try again on the
// next wake — it must not propagate the error and exit.
q := newFakeQueue(Job{ID: "j1", Backend: "mock"})
q.claimErr = fmt.Errorf("transient: connection reset")
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 20 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
// Heal the claim error after a beat so the second drain succeeds.
go func() {
time.Sleep(40 * time.Millisecond)
q.mu.Lock()
q.claimErr = nil
q.mu.Unlock()
}()
go func() {
time.Sleep(200 * time.Millisecond)
cancel()
}()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run returned: %v (transient claim errors should not kill the loop)", err)
}
if got := q.state["j1"]; got != "done" {
t.Fatalf("state=%q want done", got)
}
}

24
scripts/comfyui.service Normal file
View File

@@ -0,0 +1,24 @@
[Unit]
Description=ComfyUI image generation server
Documentation=https://github.com/comfyanonymous/ComfyUI
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
User=m
Group=m
WorkingDirectory=/home/m/dev/comfyui
ExecStart=/home/m/dev/comfyui/.venv/bin/python /home/m/dev/comfyui/main.py \
--listen 0.0.0.0 --port 8188 \
--output-directory /home/m/dev/comfyui/output \
--temp-directory /home/m/dev/comfyui/temp
Restart=on-failure
RestartSec=5
TimeoutStopSec=30
NoNewPrivileges=true
PrivateTmp=true
LimitNOFILE=65535
[Install]
WantedBy=multi-user.target

View File

@@ -0,0 +1,37 @@
#!/bin/bash
# Download FLUX.1 schnell + accompanying VAE/text encoders into a ComfyUI tree.
# Uses ungated mirrors — the official Black-Forest-Labs repo is gated and
# requires an HF token. See docs/setup-comfyui-mrock.md.
set -euo pipefail
ROOT="${1:-$HOME/dev/comfyui/models}"
if [ ! -d "$ROOT" ]; then
echo "models root $ROOT does not exist — pass it as the first argument" >&2
exit 1
fi
mkdir -p "$ROOT/unet" "$ROOT/vae" "$ROOT/clip"
CKPT="https://huggingface.co/Comfy-Org/flux1-schnell/resolve/main/flux1-schnell.safetensors"
VAE="https://huggingface.co/sirorable/flux-ae-vae/resolve/main/ae.safetensors"
CLIP_L="https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors"
T5="https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors"
dl() {
local url=$1 dest=$2
if [ -s "$dest" ]; then
echo "skip $dest (already present)"
return
fi
echo "downloading $url -> $dest"
curl -L --fail --retry 3 --retry-delay 5 -C - -o "$dest" "$url"
}
dl "$CKPT" "$ROOT/unet/flux1-schnell.safetensors"
dl "$VAE" "$ROOT/vae/ae.safetensors"
dl "$CLIP_L" "$ROOT/clip/clip_l.safetensors"
dl "$T5" "$ROOT/clip/t5xxl_fp8_e4m3fn.safetensors"
echo "done"

View File

@@ -0,0 +1,87 @@
{
"prompt": {
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": "a small fishbowl with a cat staring out, photo, soft light",
"clip": ["11", 0]
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["31", 0],
"vae": ["10", 0]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "imagen-poc",
"images": ["8", 0]
}
},
"10": {
"class_type": "VAELoader",
"inputs": {
"vae_name": "ae.safetensors"
}
},
"11": {
"class_type": "DualCLIPLoader",
"inputs": {
"clip_name1": "t5xxl_fp8_e4m3fn.safetensors",
"clip_name2": "clip_l.safetensors",
"type": "flux"
}
},
"12": {
"class_type": "UNETLoader",
"inputs": {
"unet_name": "flux1-schnell.safetensors",
"weight_dtype": "fp8_e4m3fn"
}
},
"13": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": "",
"clip": ["11", 0]
}
},
"27": {
"class_type": "EmptySD3LatentImage",
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
}
},
"30": {
"class_type": "ModelSamplingFlux",
"inputs": {
"model": ["12", 0],
"max_shift": 1.15,
"base_shift": 0.5,
"width": 1024,
"height": 1024
}
},
"31": {
"class_type": "KSampler",
"inputs": {
"model": ["30", 0],
"seed": 1234567,
"steps": 4,
"cfg": 1.0,
"sampler_name": "euler",
"scheduler": "simple",
"denoise": 1.0,
"positive": ["6", 0],
"negative": ["13", 0],
"latent_image": ["27", 0]
}
}
},
"client_id": "imagen-poc-001"
}

View File

@@ -0,0 +1,22 @@
# Environment for the imagen-worker.service systemd unit.
# Copy to ~/.dotfiles/.env.imagen-worker and fill in real values.
# Never commit the populated file — it carries the Supabase service-role key.
# Direct Postgres DSN for LISTEN/NOTIFY + imagen.jobs UPDATE statements.
# PostgREST cannot LISTEN, so the worker connects to Postgres directly.
# Host + port + password come from the msupabase compose env on mlake.
IMAGEN_WORKER_DATABASE_URL=postgres://postgres:CHANGE_ME@100.99.98.201:6789/postgres?sslmode=disable
# PostgREST endpoint for the imagen.images cloud-sync writer (same as
# `imagen generate`'s cloud-sync code path).
SUPABASE_URL=https://supa.flexsiebels.de
SUPABASE_SERVICE_KEY=CHANGE_ME
# Default owner_user_id. Per-job owner from the imagen.jobs row overrides
# this, so it's only used as a fallback when a job arrives with a NULL
# owner_user_id — which the schema disallows. Keep it set for safety.
IMAGEN_OWNER_USER_ID=ac6c9501-3757-4a6d-8b97-2cff4288382b
# Optional: REPLICATE_API_TOKEN if any imagen.jobs.backend may resolve to
# a Replicate adapter instance.
# REPLICATE_API_TOKEN=CHANGE_ME

View File

@@ -0,0 +1,19 @@
[Unit]
Description=ImaGen worker (consumes imagen.jobs queue)
Documentation=https://mgit.msbls.de/m/ImaGen/issues/8
Wants=network-online.target
After=network-online.target
[Service]
Type=simple
ExecStart=%h/dev/ImaGen/bin/imagen worker
WorkingDirectory=%h/dev/ImaGen
EnvironmentFile=%h/.dotfiles/.env.imagen-worker
Restart=on-failure
RestartSec=5
# Give the worker time to finish an in-flight generation on shutdown
# (FLUX dev up to ~30s, plus the cloud-sync write-back).
TimeoutStopSec=60
[Install]
WantedBy=default.target