6 Commits

Author SHA1 Message Date
mAi
2758c5a500 mAi: #8 - imagen.jobs queue + worker subcommand (flexsiebels write path)
Async write path for the flexsiebels owner-mode UI: flexsiebels INSERTs into
imagen.jobs, the worker on mRiver claims pending rows via LISTEN/NOTIFY +
5s safety poll, runs the same generate pipeline imagen generate uses, and
writes the result through internal/cloud into imagen.images.

- Schema migration imagen_jobs_init: table + status CHECK + two indexes +
  owner-scoped RLS + grants + AFTER INSERT trigger publishing on the
  imagen_jobs channel via pg_notify.
- internal/worker: DB-agnostic loop over a Queue interface. Drains the
  whole pending backlog on each wake. Job-scoped contexts are derived
  from Background so SIGTERM lets the in-flight generation finish (no
  half-state). ResetStaleRunning at startup unsticks rows left over from
  a previous crash. Eight unit tests cover the done / failed / missing-id /
  drain / NOTIFY-wake / shutdown / transient-error paths against a fake
  queue (no real Postgres in CI).
- cmd/imagen/worker.go: pgx-backed Queue (one dedicated conn for LISTEN +
  UPDATE), plus the workerPipeline that reuses buildBackend +
  attachUsageSink + prompt.Apply + buildWriter + maybeCloudSync. The
  per-job owner_user_id overrides the env-level fallback so each row in
  imagen.images is attributed correctly.
- maybeCloudSync now returns (*cloud.SyncResult, error) so the worker can
  link imagen.jobs.image_id to the inserted imagen.images row. The CLI
  generate path keeps printing its stderr summary unchanged.
- scripts/imagen-worker.service + .env.example for the systemd --user unit
  on mRiver. EnvironmentFile lives in ~/.dotfiles and is never committed.
- docs/setup-worker-mriver.md walks through installation + the spec's
  SQL-INSERT smoke; docs/architecture.md grows an "async write path"
  section.
- worker_integration_test.go (env-guarded by IMAGEN_WORKER_INTEGRATION=1)
  drives one real job through the full pipeline against msupabase using
  the mock backend, then verifies imagen.images + Storage object landed
  and the row flipped to done with image_id linked. Verified end-to-end:
  pickup latency ~7ms, total 74ms, failure path captures error text.
2026-05-11 10:23:33 +02:00
mAi
cb6656c436 Merge mai/hermes/issue-7-imagen-7-cloud: Supabase cloud-sync for flexsiebels viewer (#7) 2026-05-11 01:53:12 +02:00
mAi
e22f286024 mAi: #7 - cloud-sync to Supabase Storage + imagen.images
Every successful imagen generate now (a) uploads the PNG to the private
imagen-generated bucket and (b) inserts a row into imagen.images, the
data plane the flexsiebels owner-mode viewer reads from.

Schema, RLS, indexes, bucket and PostgREST exposure landed via four
applied migrations on msupabase: imagen_schema_init,
imagen_schema_grants, imagen_storage_policies, imagen_pgrst_expose
(authenticator role-level ALTER + reload). Owner UUID for m:
ac6c9501-3757-4a6d-8b97-2cff4288382b — documented in the config sample.

Code: new internal/cloud/ package mirroring the internal/usage/ shape.
PostgREST POST against the imagen schema (Accept-Profile + Content-
Profile headers), Storage upload via PUT with x-upsert, retry on 5xx /
transport but not 4xx, owner_user_id required (the column is NOT NULL
and the read-side RLS policy needs it).

Wiring in cmd/imagen/generate.go: --no-cloud flag, output.cloud_sync
config knob (auto|on|off mirroring --preview), $IMAGEN_CLOUD_SYNC env
override. The hook reads the just-written PNG + sidecar from disk and
calls cloud.Sync; failures emit "imagen: cloud sync: <err>" to stderr
without changing exit code, so a Supabase blip never loses the artefact.
output.Outputs grew Date/Slug/Seed fields so storage_path mirrors the
local filename's prefix exactly (no UTC-vs-local drift).

Config: owner_user_id field added; sample comment points at the
auth.users lookup. imagen config validate warns on stderr when
cloud_sync is on/auto but owner_user_id is empty.

Tests: cloud_test.go covers happy path, retry-on-5xx, no-retry-on-4xx,
missing-owner-uuid, missing-date-or-slug, signed URL, and the partial-
success case where the upload landed but the DB insert failed.
generate_test.go covers the precedence chain for cloud-sync mode
resolution. Build + tests clean across the tree.

Real smoke against mRock: generation through flux-schnell-local writes
the local PNG + sidecar AND uploads to imagen-generated/2026-05-11/...
AND inserts into imagen.images. Signed URL round-trips the same bytes.
--no-cloud verified to skip both Storage and DB.
2026-05-11 01:51:09 +02:00
mAi
2d5896e27d Merge mai/hermes/issue-3-imagen-3: Replicate API backend + cost-tracking + usage CLI (#3) 2026-05-08 17:32:09 +02:00
mAi
b282325663 mAi: #3 - Replicate adapter, mai.imagen_usage cost-tracking, usage CLI
Implements the Replicate API backend (FLUX schnell / FLUX dev) per ImaGen
issue #3:

- internal/backend/replicate.go — Backend adapter. Supports model
  refs as "owner/name" (uses /v1/models/{owner}/{name}/predictions) and
  "owner/name:hash" (uses /v1/predictions with explicit version). Polls
  /v1/predictions/{id} every 500ms with model-aware timeout (60s schnell,
  120s dev). Resilience: a 401 error names the api_token_env variable,
  429 retries up to 3 times with exponential backoff (honouring
  Retry-After), 5xx retries once, and the image download retries once
  on transient failure.
- internal/backend/replicate_pricing.go — hardcoded per-image USD rates
  for known FLUX models, snapshotted from replicate.com/pricing with a
  refresh TODO.
- internal/backend/replicate_test.go — mocked-HTTP unit tests covering
  happy path (model + version-pinned), 401, 429 retry policy, failed
  prediction, poll timeout, image-download retry, ctx cancel, BackendOpts
  passthrough, default_steps, aspect-ratio reduction, sha256 prompt hash.
- internal/usage/usage.go — Supabase REST sink + read-side query for
  mai.imagen_usage. Adapter writes are best-effort: failures warn but
  the image still lands.
- cmd/imagen/usage.go — `imagen usage [--since DATE] [--raw]` reads
  the table and prints a tab-aligned grouped or raw table with totals.
- cmd/imagen/backends.go — instances of type=replicate now report
  "ok" or "not configured (set REPLICATE_API_TOKEN)" depending on env.
- internal/config/config.go — sample adds flux-schnell-replicate +
  flux-dev-replicate; default_backend stays flux-schnell-local.
- Supabase migration mai.imagen_usage (id, created_at, backend, model,
  seed, prompt_hash, latency_ms, cost_usd_estimate, caller) + indexes
  on (created_at DESC) and (caller). The raw prompt is never stored.
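
The two model-ref forms in the first bullet map to two endpoints. A minimal sketch of that mapping, assuming a hypothetical resolveModelRef helper and predictionTarget type (the adapter in internal/backend/replicate.go may parse refs differently):

```go
package main

import (
	"fmt"
	"strings"
)

// predictionTarget is a hypothetical description of where to POST.
type predictionTarget struct {
	Path    string // API path relative to the Replicate base URL
	Version string // set only for version-pinned refs ("owner/name:hash")
}

// resolveModelRef maps "owner/name" to the model-scoped endpoint and
// "owner/name:hash" to the generic predictions endpoint with a version.
func resolveModelRef(ref string) (predictionTarget, error) {
	name, version, pinned := strings.Cut(ref, ":")
	owner, model, ok := strings.Cut(name, "/")
	if !ok || owner == "" || model == "" || strings.Contains(model, "/") {
		return predictionTarget{}, fmt.Errorf("model ref %q: want owner/name or owner/name:hash", ref)
	}
	if pinned {
		if version == "" {
			return predictionTarget{}, fmt.Errorf("model ref %q: empty version hash", ref)
		}
		return predictionTarget{Path: "/v1/predictions", Version: version}, nil
	}
	return predictionTarget{Path: "/v1/models/" + owner + "/" + model + "/predictions"}, nil
}

func main() {
	// "owner/name" and "abc123" are placeholders, not real model refs.
	for _, ref := range []string{"owner/name", "owner/name:abc123", "broken"} {
		target, err := resolveModelRef(ref)
		fmt.Printf("%q -> %+v (err: %v)\n", ref, target, err)
	}
}
```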

Caller identity resolves from MAI_FROM_ID, then the tmux pane's
@mai-name option, mirroring the maimcp identity logic. Prompt hash is
sha256 of the user-facing prompt; raw prompt never reaches the table.
2026-05-08 17:28:29 +02:00
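
Storing only a digest of the prompt, as described above, could look like the sketch below. The promptHash name and the lowercase hex encoding are assumptions; the commit message only states that the hash is sha256 of the user-facing prompt.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// promptHash returns the SHA-256 digest of the prompt as lowercase hex.
// Hex is an assumption for this sketch; the stored encoding is not shown
// in the source.
func promptHash(prompt string) string {
	sum := sha256.Sum256([]byte(prompt))
	return hex.EncodeToString(sum[:])
}

func main() {
	// Equal prompts hash equally, so rows can be grouped by prompt
	// without the raw text ever reaching the table.
	fmt.Println(promptHash("abc"))
	fmt.Println(promptHash("abc") == promptHash("abc"))
}
```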
mAi
a1d0165445 Merge mai/hermes/issue-5-imagen-5-tmux: tmux-window preview for generate (#5) 2026-05-08 17:12:57 +02:00
27 changed files with 3908 additions and 19 deletions


@@ -10,13 +10,17 @@ and lifecycle of its own block in `~/.config/imagen.yaml`.
## Architecture
```
cmd/imagen/ CLI shell — generate, backends, config, serve
cmd/imagen/ CLI shell — generate, worker, backends, config, serve
internal/backend/ Backend interface + Registry + Mock reference impl
internal/prompt/ Style preset registry (embedded styles.yaml)
internal/output/ Filename templating, image writer, JSON sidecar
internal/config/ YAML loader, validation, sample generator
internal/cloud/ Supabase Storage + imagen.images writer
internal/usage/ mai.imagen_usage cost-tracking sink
internal/worker/ imagen.jobs queue consumer (DB-agnostic via Queue interface)
internal/server/ HTTP stub (not implemented yet — follow-up issue)
docs/ architecture.md, usage.md
scripts/ imagen-worker.service + env template, ComfyUI scripts
docs/ architecture.md, usage.md, setup-worker-mriver.md
```
Data flow for `imagen generate`:


@@ -10,6 +10,27 @@ import (
"mgit.msbls.de/m/ImaGen/internal/config"
)
// instanceStatus checks adapter-specific preconditions (e.g. the
// Replicate API token env var being set) and returns a short
// user-facing status string.
func instanceStatus(spec config.BackendSpec) string {
if !backend.Default.Has(spec.Type) {
return fmt.Sprintf("type %q not compiled in", spec.Type)
}
switch spec.Type {
case backend.ReplicateType:
envName, _ := spec.Raw["api_token_env"].(string)
if envName == "" {
envName = "REPLICATE_API_TOKEN"
}
if os.Getenv(envName) == "" {
return fmt.Sprintf("not configured (set %s)", envName)
}
return "ok"
}
return "registered"
}
func runBackends(args []string) error {
fs := flag.NewFlagSet("backends", flag.ContinueOnError)
var configPath string
@@ -27,10 +48,7 @@ func runBackends(args []string) error {
fmt.Fprintln(tw, "INSTANCE\tTYPE\tSTATUS")
if cfg != nil {
for name, spec := range cfg.Backends {
status := "registered"
if !backend.Default.Has(spec.Type) {
status = fmt.Sprintf("type %q not compiled in", spec.Type)
}
status := instanceStatus(spec)
marker := ""
if name == cfg.DefaultBackend {
marker = " (default)"


@@ -39,6 +39,18 @@ func runConfig(args []string) error {
}
fmt.Fprintf(os.Stdout, "OK — %d backend(s) defined, default=%q\n",
len(cfg.Backends), cfg.DefaultBackend)
// Soft warnings — surfaced on stderr so they're visible but don't
// fail the validate exit code.
cloudMode := cfg.Output.CloudSync
if cloudMode == "" {
cloudMode = "auto"
}
if cloudMode != "off" && cfg.OwnerUserID == "" {
fmt.Fprintln(os.Stderr,
"warning: cloud_sync is "+cloudMode+" but owner_user_id is empty — DB inserts will be skipped.")
fmt.Fprintln(os.Stderr,
" look it up: SELECT id FROM auth.users WHERE email = '<your-supabase-email>';")
}
return nil
default:
return userErr("unknown config subcommand %q (init|validate|path)", args[0])


@@ -2,17 +2,22 @@ package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"mgit.msbls.de/m/ImaGen/internal/backend"
"mgit.msbls.de/m/ImaGen/internal/cloud"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/preview"
"mgit.msbls.de/m/ImaGen/internal/prompt"
"mgit.msbls.de/m/ImaGen/internal/usage"
)
func runGenerate(ctx context.Context, args []string) error {
@@ -29,6 +34,7 @@ func runGenerate(ctx context.Context, args []string) error {
noSidecar bool
previewOn bool
previewOff bool
noCloud bool
)
fs.StringVar(&backendName, "backend", "", "backend instance name (default: config.default_backend)")
fs.StringVar(&size, "size", "1024x1024", "WxH, e.g. 1024x1024")
@@ -41,6 +47,7 @@ func runGenerate(ctx context.Context, args []string) error {
fs.BoolVar(&noSidecar, "no-sidecar", false, "skip the JSON sidecar even if config enables it")
fs.BoolVar(&previewOn, "preview", false, "force tmux preview window on (errors outside $TMUX)")
fs.BoolVar(&previewOff, "no-preview", false, "skip the tmux preview window")
fs.BoolVar(&noCloud, "no-cloud", false, "skip Supabase upload + imagen.images insert for this generation")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), `Usage: imagen generate "<prompt>" [flags]`)
fs.PrintDefaults()
@@ -81,6 +88,7 @@ func runGenerate(ctx context.Context, args []string) error {
if err != nil {
return err
}
attachUsageSink(be)
finalPrompt, err := prompt.Apply(rawPrompt, style)
if err != nil {
@@ -124,6 +132,13 @@ func runGenerate(ctx context.Context, args []string) error {
fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath)
}
if result, err := maybeCloudSync(ctx, cfg, noCloud, "", paths, in, res, w, h); err != nil {
// cloud-sync failures are warnings — the image already wrote.
fmt.Fprintln(os.Stderr, "imagen: cloud sync:", err)
} else if result != nil && result.ImageID != "" {
fmt.Fprintf(os.Stderr, "cloud: imagen.images.id=%s storage_path=%s\n", result.ImageID, result.StoragePath)
}
if err := maybePreview(cfg, previewOn, previewOff, paths.ImagePath, rawPrompt); err != nil {
// preview failures are warnings — the image already wrote.
fmt.Fprintln(os.Stderr, "imagen: preview:", err)
@@ -131,6 +146,174 @@ func runGenerate(ctx context.Context, args []string) error {
return nil
}
// resolveCloudSyncMode applies the precedence chain config -> env -> flag.
// Flags win, env beats config, config beats the implicit auto default.
// Mirrors resolvePreviewMode shape.
func resolveCloudSyncMode(cfg *config.Config, noCloudFlag bool, env string) (string, error) {
mode := "auto"
if cfg != nil && cfg.Output.CloudSync != "" {
mode = cfg.Output.CloudSync
}
if env != "" {
switch env {
case "auto", "on", "off":
mode = env
default:
return "", fmt.Errorf("$IMAGEN_CLOUD_SYNC = %q (must be auto|on|off)", env)
}
}
if noCloudFlag {
mode = "off"
}
return mode, nil
}
// maybeCloudSync resolves the effective mode and, if it says yes, uploads
// the PNG and inserts the row. Returns the SyncResult on success so callers
// that need the imagen.images.id (e.g. the worker linking a job row) can pick
// it up. ownerOverride, when non-empty, wins over config + env — the worker
// passes the job row's owner_user_id so each job is attributed correctly.
func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, ownerOverride string, paths *output.Outputs, in output.Inputs, res *backend.Result, width, height int) (*cloud.SyncResult, error) {
mode, err := resolveCloudSyncMode(cfg, noCloud, os.Getenv("IMAGEN_CLOUD_SYNC"))
if err != nil {
return nil, err
}
if mode == "off" {
return nil, nil
}
sink, ok := cloud.NewFromEnv()
if !ok {
if mode == "on" {
return nil, fmt.Errorf("cloud_sync=on but SUPABASE_URL / SUPABASE_SERVICE_KEY not set in env")
}
// auto + missing env = silent skip.
return nil, nil
}
switch {
case ownerOverride != "":
sink.OwnerUserID = ownerOverride
case cfg != nil && cfg.OwnerUserID != "":
// Config-supplied owner_user_id takes precedence over $IMAGEN_OWNER_USER_ID.
sink.OwnerUserID = cfg.OwnerUserID
}
if sink.OwnerUserID == "" {
if mode == "on" {
return nil, fmt.Errorf("cloud_sync=on but owner_user_id not set in config and $IMAGEN_OWNER_USER_ID is empty")
}
// auto + missing UUID = silent skip.
return nil, nil
}
pngBytes, readErr := os.ReadFile(paths.ImagePath)
if readErr != nil {
return nil, fmt.Errorf("read local image: %w", readErr)
}
// Reuse the writer's date/slug/seed so storage_path mirrors the local
// filename's prefix exactly — viewers can join `imagen.images` on
// either side without timezone drift.
date := paths.Date
slug := paths.Slug
if date == "" || slug == "" {
now := time.Now()
date = now.Format("2006-01-02")
slug = output.Slug(in.Prompt)
}
ext := in.Ext
if ext == "" {
ext = strings.TrimPrefix(filepath.Ext(paths.ImagePath), ".")
}
if ext == "" {
ext = "png"
}
// Snapshot the sidecar (if it exists) so the row carries the same
// metadata view a downstream viewer would see on disk.
var sidecar map[string]any
if paths.SidecarPath != "" {
if scBytes, err := os.ReadFile(paths.SidecarPath); err == nil {
_ = json.Unmarshal(scBytes, &sidecar)
}
}
model := metaString(res.Metadata, "model")
steps := metaInt(res.Metadata, "steps")
cost := metaFloatPtr(res.Metadata, "cost_usd_estimate")
latency := metaInt(res.Metadata, "latency_ms")
seed := paths.Seed
if seed == 0 {
seed = in.Seed
}
syncReq := cloud.SyncRequest{
Date: date,
Slug: slug,
Seed: seed,
Ext: ext,
PNG: pngBytes,
MimeType: res.MimeType,
Prompt: in.Prompt,
Backend: in.Backend,
Model: model,
Steps: steps,
Width: width,
Height: height,
LatencyMs: latency,
CostUSDEstimate: cost,
Sidecar: sidecar,
}
syncCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
return sink.Sync(syncCtx, syncReq)
}
func metaString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func metaInt(m map[string]any, key string) int {
v, ok := m[key]
if !ok {
return 0
}
switch n := v.(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
}
return 0
}
func metaFloatPtr(m map[string]any, key string) *float64 {
v, ok := m[key]
if !ok {
return nil
}
switch n := v.(type) {
case float64:
return &n
case float32:
f := float64(n)
return &f
case int:
f := float64(n)
return &f
case int64:
f := float64(n)
return &f
}
return nil
}
// resolvePreviewMode applies the precedence chain config -> env -> flag.
// Flags win, env beats config, config beats the implicit auto default.
func resolvePreviewMode(cfg *config.Config, flagOn, flagOff bool, env string) (preview.Mode, error) {
@@ -219,6 +402,21 @@ func parseSize(s string) (int, int, error) {
return w, h, nil
}
// attachUsageSink wires a Supabase cost-tracking sink into the backend
// when it accepts one and the env is configured. Adapters that record
// usage expose a public Sink field of type backend.UsageSink.
func attachUsageSink(be backend.Backend) {
r, ok := be.(*backend.Replicate)
if !ok {
return
}
sink, ok := usage.NewSupabaseSinkFromEnv()
if !ok {
return
}
r.Sink = sink
}
func buildBackend(cfg *config.Config, name string) (backend.Backend, error) {
if cfg != nil {
spec, ok := cfg.Backends[name]


@@ -48,3 +48,40 @@ func TestResolvePreviewMode(t *testing.T) {
})
}
}
func TestResolveCloudSyncMode(t *testing.T) {
type tc struct {
name string
cfg *config.Config
noCloud bool
env string
want string
wantError bool
}
cases := []tc{
{name: "all-empty-defaults-to-auto", want: "auto"},
{name: "config-on", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, want: "on"},
{name: "config-off", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "off"}}, want: "off"},
{name: "env-overrides-config", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, env: "off", want: "off"},
{name: "flag-overrides-env-and-config", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, env: "on", noCloud: true, want: "off"},
{name: "flag-overrides-config-on", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, noCloud: true, want: "off"},
{name: "bad-env-errors", env: "yes", wantError: true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := resolveCloudSyncMode(c.cfg, c.noCloud, c.env)
if c.wantError {
if err == nil {
t.Fatalf("expected error, got mode %q", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != c.want {
t.Errorf("mode = %q, want %q", got, c.want)
}
})
}
}


@@ -14,14 +14,16 @@ import (
_ "mgit.msbls.de/m/ImaGen/internal/backend"
)
const usage = `imagen — model-agnostic image generation
const helpText = `imagen — model-agnostic image generation
Usage:
imagen generate <prompt> [flags] generate one image
imagen worker [flags] consume the imagen.jobs queue (daemon)
imagen backends list registered backend types
imagen config init print a sample imagen.yaml on stdout
imagen config validate validate the active config
imagen serve [--addr :8080] (stub) start the HTTP server
imagen usage [--since DATE] [--raw] show cost-tracking rows
imagen version print version
imagen help show this help
@@ -33,7 +35,7 @@ var Version = "dev"
func main() {
if len(os.Args) < 2 {
fmt.Fprint(os.Stderr, usage)
fmt.Fprint(os.Stderr, helpText)
os.Exit(2)
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
@@ -44,18 +46,22 @@ func main() {
switch os.Args[1] {
case "generate":
err = runGenerate(ctx, args)
case "worker":
err = runWorker(ctx, args)
case "backends":
err = runBackends(args)
case "config":
err = runConfig(args)
case "serve":
err = runServe(args)
case "usage":
err = runUsage(ctx, args)
case "version", "-v", "--version":
fmt.Println(Version)
case "help", "-h", "--help":
fmt.Print(usage)
fmt.Print(helpText)
default:
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], usage)
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], helpText)
os.Exit(2)
}
if err != nil {

189
cmd/imagen/usage.go Normal file

@@ -0,0 +1,189 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"sort"
"strings"
"text/tabwriter"
"time"
"mgit.msbls.de/m/ImaGen/internal/usage"
)
// runUsage handles `imagen usage [--since DATE] [--raw]`. Reads
// mai.imagen_usage via Supabase REST and prints a tab-aligned table:
// grouped by week + backend + model + caller by default, or one line per
// row with --raw. Both forms end with a totals line.
func runUsage(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("usage", flag.ContinueOnError)
var (
since string
raw bool
)
fs.StringVar(&since, "since", "", "ISO date (YYYY-MM-DD) — only rows on/after this UTC date")
fs.BoolVar(&raw, "raw", false, "print one line per row instead of grouped")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), "Usage: imagen usage [--since YYYY-MM-DD] [--raw]")
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
return err
}
var sinceT time.Time
if since != "" {
t, err := time.Parse("2006-01-02", since)
if err != nil {
return userErr("--since must be YYYY-MM-DD: %v", err)
}
sinceT = t
}
sink, ok := usage.NewSupabaseSinkFromEnv()
if !ok {
return userErr("SUPABASE_URL and SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) must be set to read mai.imagen_usage")
}
rows, err := sink.Query(ctx, sinceT)
if err != nil {
return err
}
if raw {
printRawRows(rows)
return nil
}
printGroupedRows(rows)
return nil
}
func printRawRows(rows []usage.Row) {
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "TIME\tBACKEND\tMODEL\tCALLER\tLATENCY_MS\tCOST_USD")
var totalCost float64
for _, r := range rows {
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n",
r.CreatedAt.Local().Format("2006-01-02 15:04"),
r.Backend,
r.Model,
derefString(r.Caller),
intOrDash(r.LatencyMs),
costOrDash(r.CostUSDEstimate),
)
if r.CostUSDEstimate != nil {
totalCost += *r.CostUSDEstimate
}
}
fmt.Fprintf(tw, "\t\t\t\t%d rows\t%.4f USD\n", len(rows), totalCost)
_ = tw.Flush()
}
type group struct {
week string
backend string
model string
caller string
count int
cost float64
costSet bool
}
type groupKey struct {
week, backend, model, caller string
}
func printGroupedRows(rows []usage.Row) {
groups := map[groupKey]*group{}
for _, r := range rows {
caller := derefString(r.Caller)
k := groupKey{
week: weekStart(r.CreatedAt).Format("2006-01-02"),
backend: r.Backend,
model: r.Model,
caller: caller,
}
g, ok := groups[k]
if !ok {
g = &group{week: k.week, backend: r.Backend, model: r.Model, caller: caller}
groups[k] = g
}
g.count++
if r.CostUSDEstimate != nil {
g.cost += *r.CostUSDEstimate
g.costSet = true
}
}
keys := make([]groupKey, 0, len(groups))
for k := range groups {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
if keys[i].week != keys[j].week {
return keys[i].week > keys[j].week // newest first
}
if keys[i].backend != keys[j].backend {
return keys[i].backend < keys[j].backend
}
if keys[i].model != keys[j].model {
return keys[i].model < keys[j].model
}
return keys[i].caller < keys[j].caller
})
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "WEEK_OF\tBACKEND\tMODEL\tCALLER\tCOUNT\tCOST_USD")
var totalCount int
var totalCost float64
for _, k := range keys {
g := groups[k]
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%s\n",
g.week, g.backend, g.model, g.caller, g.count, costStr(g.cost, g.costSet),
)
totalCount += g.count
totalCost += g.cost
}
fmt.Fprintf(tw, "\t\t\tTOTAL\t%d\t%.4f USD\n", totalCount, totalCost)
_ = tw.Flush()
}
// weekStart returns the Monday of the week containing t (UTC).
func weekStart(t time.Time) time.Time {
t = t.UTC()
wd := int(t.Weekday())
if wd == 0 {
wd = 7 // shift Sunday to end-of-week
}
delta := time.Duration(wd-1) * -24 * time.Hour
d := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
return d.Add(delta)
}
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}
func intOrDash(p *int) string {
if p == nil {
return "-"
}
return fmt.Sprintf("%d", *p)
}
func costOrDash(p *float64) string {
if p == nil {
return "-"
}
return fmt.Sprintf("%.4f", *p)
}
func costStr(v float64, set bool) string {
if !set {
return "-"
}
return strings.TrimSpace(fmt.Sprintf("%.4f", v))
}

287
cmd/imagen/worker.go Normal file

@@ -0,0 +1,287 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/jackc/pgx/v5"
"mgit.msbls.de/m/ImaGen/internal/backend"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/prompt"
"mgit.msbls.de/m/ImaGen/internal/worker"
)
// runWorker is the `imagen worker` subcommand: a long-running daemon that
// consumes the imagen.jobs queue and writes results into imagen.images via
// the same cloud-sync path generate uses.
func runWorker(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("worker", flag.ContinueOnError)
var (
configPath string
pollInterval time.Duration
jobTimeout time.Duration
)
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
fs.DurationVar(&pollInterval, "poll-interval", 5*time.Second, "safety-poll cadence between LISTEN wakeups")
fs.DurationVar(&jobTimeout, "job-timeout", 5*time.Minute, "max wall-time per job before the worker marks it failed")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), `Usage: imagen worker [flags]
Long-running daemon. LISTENs on the Postgres 'imagen_jobs' channel and polls
imagen.jobs every --poll-interval as a safety net, claims pending rows, runs
the generation pipeline, then updates the row with status + image_id.
Env:
IMAGEN_WORKER_DATABASE_URL Postgres DSN for direct LISTEN + UPDATE.
Required (PostgREST cannot LISTEN).
SUPABASE_URL, SUPABASE_SERVICE_KEY, IMAGEN_OWNER_USER_ID
Reused from generate's cloud-sync path; the
worker writes imagen.images rows through the
same code path. Per-job owner_user_id from the
job row overrides IMAGEN_OWNER_USER_ID.`)
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
return err
}
cfg, cfgErr := config.Load(configPath)
if cfgErr != nil && !os.IsNotExist(cfgErr) {
return cfgErr
}
dsn := os.Getenv("IMAGEN_WORKER_DATABASE_URL")
if dsn == "" {
return userErr("IMAGEN_WORKER_DATABASE_URL not set; the worker needs a direct Postgres DSN for LISTEN/NOTIFY")
}
q, err := dialQueue(ctx, dsn)
if err != nil {
return fmt.Errorf("queue: %w", err)
}
defer q.Close()
p := &workerPipeline{cfg: cfg}
w := worker.New(q, p, worker.Config{
PollInterval: pollInterval,
JobTimeout: jobTimeout,
Logger: func(format string, a ...any) { fmt.Fprintf(os.Stderr, format+"\n", a...) },
})
fmt.Fprintln(os.Stderr, "imagen worker: ready (poll-interval", pollInterval, "job-timeout", jobTimeout, ")")
return w.Run(ctx)
}
// pgxQueue is the production Queue. It opens one dedicated connection used
// for both LISTEN (long-lived) and UPDATE operations. A second connection
// would split state needlessly — a single worker process processes one job
// at a time so the connection is never contended.
type pgxQueue struct {
conn *pgx.Conn
}
func dialQueue(ctx context.Context, dsn string) (*pgxQueue, error) {
conn, err := pgx.Connect(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("pgx.Connect: %w", err)
}
if _, err := conn.Exec(ctx, "LISTEN imagen_jobs"); err != nil {
conn.Close(ctx)
return nil, fmt.Errorf("LISTEN imagen_jobs: %w", err)
}
return &pgxQueue{conn: conn}, nil
}
func (q *pgxQueue) Close() {
if q == nil || q.conn == nil {
return
}
// Best-effort: a 5s budget is enough to send a polite TerminateMessage.
shutdown, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = q.conn.Close(shutdown)
}
// ClaimNextPending atomically marks the oldest pending row 'running' and
// returns it. FOR UPDATE SKIP LOCKED is belt + braces against a second worker
// process — out of scope for v1 but cheap insurance.
func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
const stmt = `
UPDATE imagen.jobs
SET status='running', started_at=now()
WHERE id = (
SELECT id FROM imagen.jobs
WHERE status='pending'
ORDER BY created_at
LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING id, owner_user_id, prompt, backend,
COALESCE(model,''),
COALESCE(width, 0), COALESCE(height, 0),
COALESCE(steps, 0), COALESCE(seed, 0),
COALESCE(style,'')`
var j worker.Job
err := q.conn.QueryRow(ctx, stmt).Scan(
&j.ID, &j.OwnerUserID, &j.Prompt, &j.Backend,
&j.Model, &j.Width, &j.Height, &j.Steps, &j.Seed, &j.Style,
)
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
if err != nil {
return nil, err
}
return &j, nil
}
func (q *pgxQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='done', image_id=$2, completed_at=now() WHERE id=$1`,
jobID, imageID)
return err
}
func (q *pgxQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
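// Trim outrageously long error text so a 10MB stack-trace doesn't end up
// in the row (callers see a summary, full text goes to stderr / logs).
const maxLen = 2000
if len(msg) > maxLen {
// A byte-offset cut can land inside a multi-byte rune; drop the broken
// tail so Postgres does not reject the value as invalid UTF-8.
msg = strings.ToValidUTF8(msg[:maxLen], "") + "... [truncated]"
}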
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='failed', error=$2, completed_at=now() WHERE id=$1`,
jobID, msg)
return err
}
// WaitForJob blocks until a NOTIFY arrives on imagen_jobs, the timeout fires,
// or ctx is cancelled. Notifications during a previous processJob are queued
// by pgx and delivered on the next call — we don't lose wake-ups even when
// processing took longer than poll-interval.
func (q *pgxQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
_, err := q.conn.WaitForNotification(waitCtx)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil // poll cadence fired
}
if errors.Is(err, context.Canceled) {
return context.Canceled
}
return err
}
return nil
}
// ResetStaleRunning bumps any rows stuck in 'running' back to 'pending' so
// they get re-claimed. Called once at startup. A row stuck in 'running' came
// from a previous worker crash; without this, flexsiebels would poll
// forever on a job nobody is processing.
func (q *pgxQueue) ResetStaleRunning(ctx context.Context) error {
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='pending', started_at=NULL WHERE status='running'`)
return err
}
// workerPipeline is the Pipeline implementation that drives a single job
// through buildBackend → prompt enrichment → generate → write disk →
// cloud-sync, then returns the imagen.images.id back to the worker so it
// can link the row.
type workerPipeline struct {
cfg *config.Config
}
func (p *workerPipeline) Run(ctx context.Context, job worker.Job) worker.Outcome {
if job.OwnerUserID == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: missing owner_user_id", job.ID)}
}
if job.Prompt == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: empty prompt", job.ID)}
}
if job.Backend == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: missing backend", job.ID)}
}
be, err := buildBackend(p.cfg, job.Backend)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("backend %q: %w", job.Backend, err)}
}
attachUsageSink(be)
finalPrompt, err := prompt.Apply(job.Prompt, job.Style)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("style: %w", err)}
}
req := backend.Request{
Prompt: finalPrompt,
Width: job.Width,
Height: job.Height,
Steps: job.Steps,
Seed: job.Seed,
Style: job.Style,
}
res, err := be.Generate(ctx, req)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("generate: %w", err)}
}
defer res.ImageReader.Close()
writer := buildWriter(p.cfg, false)
in := output.Inputs{
Prompt: job.Prompt,
Backend: be.Name(),
Seed: seedFromMetadata(res.Metadata, job.Seed),
Ext: extFromMime(res.MimeType),
Metadata: res.Metadata,
}
paths, err := writer.Write(res.ImageReader, in)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("write disk: %w", err)}
}
// Worker is queue-driven: cloud-sync is mandatory because flexsiebels
// needs imagen.images.id to render the result. maybeCloudSync has no
// way to force mode "on", so an explicit output.cloud_sync=off in the
// config is rejected up front. The job's owner_user_id is passed as
// ownerOverride (the fourth argument) in the call below.
if cloudModeOff(p.cfg) {
// We refuse to silently drop a queued job. If cloud sync is off in
// config, the worker can't serve flexsiebels at all.
return worker.Outcome{Err: fmt.Errorf("output.cloud_sync=off in config; the worker requires cloud_sync=on or auto")}
}
syncRes, syncErr := maybeCloudSync(ctx, p.cfg, false, job.OwnerUserID, paths, in, res, dimOrFallback(job.Width, res, "width"), dimOrFallback(job.Height, res, "height"))
if syncErr != nil {
return worker.Outcome{Err: fmt.Errorf("cloud sync: %w", syncErr)}
}
if syncRes == nil || syncRes.ImageID == "" {
return worker.Outcome{Err: fmt.Errorf("cloud sync returned no imagen.images id (check SUPABASE_URL + SUPABASE_SERVICE_KEY)")}
}
return worker.Outcome{ImageID: syncRes.ImageID}
}
func cloudModeOff(cfg *config.Config) bool {
if cfg == nil {
return false
}
return strings.EqualFold(cfg.Output.CloudSync, "off")
}
// dimOrFallback returns job.<dim> when the job specified one, otherwise the
// dimension reported by the backend's metadata. Some backends (Replicate
// when given an aspect ratio) round the requested size to their nearest
// supported value; the metadata fallback only applies when the job left the
// dimension unset, so an explicit job value is recorded as requested.
func dimOrFallback(jobDim int, res *backend.Result, key string) int {
if jobDim > 0 {
return jobDim
}
return metaInt(res.Metadata, key)
}


@@ -0,0 +1,129 @@
package main
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/worker"
)
// TestWorker_Integration_EndToEnd runs the full pipeline against a real
// msupabase instance: insert a row into imagen.jobs, let the worker claim
// it, generate via the mock backend (no Replicate spend, no ComfyUI
// dependency), write to Supabase Storage + imagen.images, then flip the job
// to 'done' with the linked image_id.
//
// Guarded by IMAGEN_WORKER_INTEGRATION=1. Required env beyond that:
//
// IMAGEN_WORKER_DATABASE_URL postgres DSN (direct, not PostgREST)
// SUPABASE_URL e.g. https://supa.flexsiebels.de
// SUPABASE_SERVICE_KEY service-role JWT
// IMAGEN_OWNER_USER_ID UUID of an auth.users row (RLS fallback)
//
// The test creates and later deletes its own job row so repeated runs don't
// leave debris.
func TestWorker_Integration_EndToEnd(t *testing.T) {
if os.Getenv("IMAGEN_WORKER_INTEGRATION") != "1" {
t.Skip("set IMAGEN_WORKER_INTEGRATION=1 to run the integration test")
}
dsn := os.Getenv("IMAGEN_WORKER_DATABASE_URL")
if dsn == "" {
t.Fatal("IMAGEN_WORKER_DATABASE_URL must be set for the integration test")
}
if os.Getenv("SUPABASE_URL") == "" || os.Getenv("SUPABASE_SERVICE_KEY") == "" {
t.Fatal("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set for the integration test")
}
owner := os.Getenv("IMAGEN_OWNER_USER_ID")
if owner == "" {
t.Fatal("IMAGEN_OWNER_USER_ID must be set for the integration test")
}
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
q, err := dialQueue(ctx, dsn)
if err != nil {
t.Fatalf("dialQueue: %v", err)
}
defer q.Close()
// Insert the test job on a separate connection (the worker's conn is
// busy LISTENing). Mock backend = no external dependency.
insertConn, err := pgx.Connect(ctx, dsn)
if err != nil {
t.Fatalf("insert conn: %v", err)
}
defer insertConn.Close(ctx)
var jobID string
prompt := fmt.Sprintf("imagen integration test %d", time.Now().UnixNano())
err = insertConn.QueryRow(ctx, `
INSERT INTO imagen.jobs (owner_user_id, prompt, backend, width, height)
VALUES ($1, $2, 'mock', 64, 64)
RETURNING id`,
owner, prompt).Scan(&jobID)
if err != nil {
t.Fatalf("insert job: %v", err)
}
t.Logf("inserted imagen.jobs id=%s", jobID)
// Tidy up at the end of the test so a re-run starts clean.
defer func() {
cleanup, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _ = insertConn.Exec(cleanup, `DELETE FROM imagen.jobs WHERE id=$1`, jobID)
}()
// Use a per-test temp dir so the generated PNG doesn't litter the repo.
tmpDir := t.TempDir()
cfg := &config.Config{Output: config.OutputConfig{Directory: tmpDir}}
p := &workerPipeline{cfg: cfg}
w := worker.New(q, p, worker.Config{
PollInterval: 1 * time.Second,
JobTimeout: 30 * time.Second,
Logger: func(format string, a ...any) { t.Logf("worker: "+format, a...) },
})
// Run the worker until it processes one job (the one we just inserted)
// or the test context times out.
runCtx, runCancel := context.WithCancel(ctx)
done := make(chan struct{})
go func() {
_ = w.Run(runCtx)
close(done)
}()
// Poll for completion.
deadline := time.Now().Add(60 * time.Second)
var status, imageID string
for time.Now().Before(deadline) {
err = insertConn.QueryRow(ctx,
`SELECT status, COALESCE(image_id::text,'') FROM imagen.jobs WHERE id=$1`,
jobID).Scan(&status, &imageID)
if err != nil {
t.Fatalf("poll: %v", err)
}
if status == "done" || status == "failed" {
break
}
time.Sleep(500 * time.Millisecond)
}
runCancel()
<-done
if status != "done" {
var errText string
_ = insertConn.QueryRow(ctx,
`SELECT COALESCE(error,'') FROM imagen.jobs WHERE id=$1`, jobID).Scan(&errText)
t.Fatalf("job not done within timeout: status=%q error=%q", status, errText)
}
if imageID == "" {
t.Fatalf("job done but image_id is empty")
}
t.Logf("job done: image_id=%s", imageID)
}


@@ -7,7 +7,7 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
```
┌───────────────────────┐
│ cmd/imagen │ CLI dispatch
│ cmd/imagen │ CLI dispatch (generate / worker / …)
│ (or HTTP server) │
└──────────┬────────────┘
@@ -16,6 +16,9 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
│ internal/output │ filename templating, sidecar
│ internal/config │ YAML loader, validation
│ internal/preview │ tmux-img window spawner
│ internal/cloud │ Supabase Storage + imagen.images
│ internal/usage │ mai.imagen_usage cost-tracking
│ internal/worker │ imagen.jobs queue consumer
└──────────┬────────────┘
┌──────────▼────────────┐
@@ -103,9 +106,37 @@ contains the prompt, backend instance name, seed, ISO timestamp, and the
- Network errors during `Generate` — wrap and return; no retry policy yet
(decide per-adapter, or move to a shared retry helper if a pattern emerges).
## Async write path: `imagen worker` + `imagen.jobs`
`imagen generate` is the synchronous CLI. For web callers (flexsiebels'
owner-mode UI) `cmd/imagen worker` runs as a daemon that consumes the
`imagen.jobs` table.
```
flexsiebels POST              imagen worker (mRiver, systemd)
  → INSERT INTO                 LISTEN imagen_jobs ◄── pg_notify trigger
    imagen.jobs(pending)        claim row (UPDATE … RETURNING)
                                dispatch through internal/backend
                                write disk + cloud-sync via internal/cloud
                                UPDATE imagen.jobs SET status='done', image_id=…
```
The queue table lives next to `imagen.images` in the same `imagen` schema.
Owner-scoped RLS lets the flexsiebels user INSERT + read their own rows;
the worker writes (status updates + image_id link) via service-role which
bypasses RLS. A 5-second safety poll runs alongside LISTEN to cover
dropped NOTIFY events and worker cold starts with a non-empty queue; each
wake-up drains the whole pending backlog. See
`docs/setup-worker-mriver.md` for the systemd installation.
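
The wake-up behaviour can be sketched roughly as follows. This is a minimal in-memory illustration, not the code in `internal/worker`: `claimFunc`, `drain`, `runLoop`, `sliceQueue`, and `processBacklog` are hypothetical names, and the real worker claims rows from Postgres rather than from a slice.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// claimFunc is a hypothetical stand-in for the claim query: it returns the
// next pending job id, or ok=false when the queue is empty.
type claimFunc func(ctx context.Context) (id string, ok bool)

// drain claims and handles jobs one at a time until the queue reports empty
// or ctx is cancelled. It returns the number of jobs handled.
func drain(ctx context.Context, claim claimFunc, handle func(id string)) int {
	n := 0
	for ctx.Err() == nil {
		id, ok := claim(ctx)
		if !ok {
			break
		}
		handle(id)
		n++
	}
	return n
}

// runLoop drains once at startup (cold start with a non-empty queue), then
// again on every NOTIFY wake-up or poll tick, whichever comes first.
func runLoop(ctx context.Context, notify <-chan struct{}, poll time.Duration, claim claimFunc, handle func(id string)) {
	ticker := time.NewTicker(poll)
	defer ticker.Stop()
	for {
		drain(ctx, claim, handle)
		select {
		case <-ctx.Done():
			return
		case <-notify: // a pg_notify arrived
		case <-ticker.C: // safety poll covers a dropped NOTIFY
		}
	}
}

// sliceQueue builds a claimFunc over a fixed backlog, for demonstration only.
func sliceQueue(backlog []string) claimFunc {
	return func(context.Context) (string, bool) {
		if len(backlog) == 0 {
			return "", false
		}
		id := backlog[0]
		backlog = backlog[1:]
		return id, true
	}
}

// processBacklog drains a fixed backlog and returns the ids in handling order.
func processBacklog(backlog []string) []string {
	var done []string
	drain(context.Background(), sliceQueue(backlog), func(id string) { done = append(done, id) })
	return done
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	queue := sliceQueue([]string{"job-1", "job-2"})
	runLoop(ctx, make(chan struct{}), 5*time.Second, queue, func(id string) {
		fmt.Println("processed", id)
	})
}
```

Handling each job before claiming the next keeps at most one row in `'running'` per worker, which matches the single-worker v1 setup.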
The worker reuses `internal/backend`, `internal/output`, and
`internal/cloud` unchanged — it is purely an orchestration layer around
the same pipeline `imagen generate` drives.
## Out of scope (today)
- Image post-processing (cropping, watermarking).
- Cost-tracking (lands with the Replicate adapter, since only API backends bill).
- Multi-image `n>1` per request — backends that support it can expose it via
`BackendOpts`; the framework doesn't have a first-class field yet.
- Job cancellation / kill switch — separate follow-up issue.
- Concurrent workers / multi-host scale-out — `FOR UPDATE SKIP LOCKED` in
the claim query makes it cheap to add, but a single worker is the v1 setup.
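
For orientation, a claim statement of the kind referred to above might look like the sketch below. It is an assumption, not the query shipped in `cmd/imagen/worker.go`: the `created_at` ordering column and the helper name `skipsLockedRows` are hypothetical, while `status` and `started_at` appear elsewhere in this change.

```go
package main

import (
	"fmt"
	"strings"
)

// claimSQL is a hypothetical claim statement. The inner SELECT picks the
// oldest pending row and skips rows another worker has already locked;
// the outer UPDATE flips it to 'running' and returns its id.
// created_at is an assumed column name.
const claimSQL = `
UPDATE imagen.jobs
   SET status = 'running', started_at = now()
 WHERE id = (
         SELECT id
           FROM imagen.jobs
          WHERE status = 'pending'
          ORDER BY created_at
          LIMIT 1
            FOR UPDATE SKIP LOCKED
       )
RETURNING id`

// skipsLockedRows reports whether a statement uses the SKIP LOCKED clause
// that makes concurrent claimers safe.
func skipsLockedRows(sql string) bool {
	return strings.Contains(strings.ToUpper(sql), "FOR UPDATE SKIP LOCKED")
}

func main() {
	fmt.Println("concurrent-safe claim:", skipsLockedRows(claimSQL))
}
```

With pgx v5 such a statement would typically be run through `QueryRow(...).Scan(&id)`, treating `pgx.ErrNoRows` as "queue empty".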


@@ -0,0 +1,97 @@
# `imagen worker` on mRiver
The worker is a long-running daemon that consumes the `imagen.jobs` queue
(written by flexsiebels' owner-mode UI) and writes the resulting image to
Supabase Storage + `imagen.images` via the same cloud-sync path the CLI
`imagen generate` uses.
## Architecture
```
flexsiebels (owner UI)
|
v INSERT INTO imagen.jobs (...)
|
msupabase Postgres
|
| AFTER INSERT trigger:
| pg_notify('imagen_jobs', NEW.id)
v
imagen worker (mRiver) ── LISTEN imagen_jobs
|
| 1. claim oldest 'pending' row (sets status='running')
| 2. dispatch to backend (FLUX schnell local / FLUX dev replicate / …)
| 3. write PNG to disk
| 4. upload to Storage + INSERT into imagen.images
| 5. UPDATE imagen.jobs SET status='done', image_id=...
v
flexsiebels polls GET .../jobs/<id> → renders the result card
```
A 5-second safety poll covers dropped NOTIFY events and worker cold starts
with a non-empty queue.
## One-time setup
```bash
# 1. Build the binary (or `task build`).
cd ~/dev/ImaGen
go build -o bin/imagen ./cmd/imagen
# 2. Write the environment file.
cp scripts/imagen-worker.env.example ~/.dotfiles/.env.imagen-worker
chmod 600 ~/.dotfiles/.env.imagen-worker
$EDITOR ~/.dotfiles/.env.imagen-worker # fill in real DSN, service key
# 3. Install the user systemd unit.
mkdir -p ~/.config/systemd/user
cp scripts/imagen-worker.service ~/.config/systemd/user/imagen-worker.service
systemctl --user daemon-reload
systemctl --user enable --now imagen-worker.service
# 4. Tail the logs.
journalctl --user -u imagen-worker -f
```
## Required env vars
See `scripts/imagen-worker.env.example` for the canonical list. Required:
- `IMAGEN_WORKER_DATABASE_URL` — direct Postgres DSN. PostgREST cannot LISTEN.
- `SUPABASE_URL`, `SUPABASE_SERVICE_KEY` — same pair `imagen generate`
reads for the cloud-sync writer.
- `IMAGEN_OWNER_USER_ID` — fallback owner UUID; per-job row's
`owner_user_id` overrides this.
Optional, depending on enabled backends:
- `REPLICATE_API_TOKEN` if any job will request a Replicate-typed backend.
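
A startup check for the required variables could look roughly like this. It is an illustrative sketch only; `requiredEnv` and `missingEnv` are hypothetical names and the worker may validate its environment differently.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// requiredEnv mirrors the "Required" list above.
var requiredEnv = []string{
	"IMAGEN_WORKER_DATABASE_URL",
	"SUPABASE_URL",
	"SUPABASE_SERVICE_KEY",
	"IMAGEN_OWNER_USER_ID",
}

// missingEnv returns the required variables for which lookup yields an
// empty string. lookup is injectable so the check is testable without
// touching the real process environment.
func missingEnv(lookup func(string) string) []string {
	var missing []string
	for _, key := range requiredEnv {
		if lookup(key) == "" {
			missing = append(missing, key)
		}
	}
	return missing
}

func main() {
	if m := missingEnv(os.Getenv); len(m) > 0 {
		fmt.Println("missing env:", strings.Join(m, ", "))
		return
	}
	fmt.Println("environment looks complete")
}
```

Reporting every missing variable at once avoids a fix-one-restart-repeat cycle when filling in the environment file.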
## Operating
```bash
systemctl --user status imagen-worker # health
systemctl --user restart imagen-worker # pick up a new binary
journalctl --user -u imagen-worker -n 200 # recent log lines
```
On startup the worker calls `ResetStaleRunning` once, flipping any rows
left in `'running'` from a previous crash back to `'pending'` so they get
re-claimed by the 5-second poll.
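
The effect of that reset can be modelled in memory as below. This is a sketch under stated assumptions: the `job` struct and `resetStaleRunning` function are hypothetical and stand in for the single `UPDATE imagen.jobs SET status='pending', started_at=NULL WHERE status='running'` statement the worker issues.

```go
package main

import (
	"fmt"
	"time"
)

// job is a hypothetical in-memory view of an imagen.jobs row.
type job struct {
	ID        string
	Status    string
	StartedAt *time.Time
}

// resetStaleRunning flips every 'running' job back to 'pending' and clears
// its start time, returning how many rows were touched. Rows in any other
// state are left alone.
func resetStaleRunning(jobs []job) int {
	n := 0
	for i := range jobs {
		if jobs[i].Status == "running" {
			jobs[i].Status = "pending"
			jobs[i].StartedAt = nil
			n++
		}
	}
	return n
}

func main() {
	now := time.Now()
	jobs := []job{
		{ID: "a", Status: "running", StartedAt: &now},
		{ID: "b", Status: "done", StartedAt: &now},
		{ID: "c", Status: "pending"},
	}
	fmt.Println("reset rows:", resetStaleRunning(jobs))
}
```

The reset is only safe because a single worker runs in the v1 setup; with several workers it would also requeue rows that another live worker is still processing.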
## Smoke test
With the worker running, INSERT a test job:
```sql
INSERT INTO imagen.jobs (owner_user_id, prompt, backend, width, height)
VALUES (
'ac6c9501-3757-4a6d-8b97-2cff4288382b',
'a tiny owl wearing wire-rim glasses, photo',
'flux-schnell-local', 1024, 1024
);
```
Within ~10 seconds the row should show `status='done'`, a populated
`image_id` linking to a real `imagen.images` row, and a Storage object at
`<YYYY-MM-DD>/<slug>-<seed>.png` in the `imagen-generated` bucket.


@@ -26,6 +26,7 @@ imagen version print version
| `--no-sidecar` | `false` | Skip the JSON sidecar even if config enables it |
| `--preview` | (auto) | Force open a tmux preview window via `tmux-img` |
| `--no-preview` | (auto) | Suppress the preview window (use for batch / CI callers) |
| `--no-cloud` | `false` | Skip Supabase upload + `imagen.images` insert for this call |
| `--config` | `~/.config/imagen.yaml` | Override config path |
### Preview window
@@ -91,3 +92,53 @@ API-backed adapters read tokens from env vars referenced by the config
export REPLICATE_API_TOKEN=...
imagen generate "a cat" --backend flux-dev-replicate
```
## Cost-tracking (Replicate)
Successful generations through the Replicate adapter write one row to
`mai.imagen_usage` on Supabase: backend, model, latency, per-image cost
estimate, prompt sha256 hash (never the prompt itself), and the caller
identity (resolved from `MAI_FROM_ID` or the tmux pane's `@mai-name`).
The writer is best-effort. If `SUPABASE_URL` / `SUPABASE_SERVICE_KEY` are
unset, or the database write fails, the image still lands and the CLI
prints a warning to stderr.
Inspect spend:
```sh
imagen usage # all rows, grouped by week + backend + model + caller
imagen usage --since 2026-05-01 # only rows on/after a UTC date
imagen usage --since 2026-05-01 --raw
```
Per-model rates live in `internal/backend/replicate_pricing.go` — they
are snapshotted from <https://replicate.com/pricing> and refreshed on a
quarterly cadence.
## Cloud-sync (Supabase)
Successful generations also upload the PNG to the private Supabase
Storage bucket `imagen-generated` (path: `<YYYY-MM-DD>/<slug>-<seed>.png`)
and insert a row into `imagen.images`. The row carries the prompt,
sha256-hashed prompt, backend, model, seed/steps/width/height, latency,
cost estimate, the full local sidecar JSON, and an empty `tags` array
ready for the flexsiebels viewer to fill in.
Configuration:
- `owner_user_id` in `imagen.yaml` — m's `auth.users.id`. Empty disables
inserts (the column is `NOT NULL`).
- `output.cloud_sync` in `imagen.yaml`: `auto` (default — on iff
SUPABASE creds + `owner_user_id` are set), `on` (errors if either is
missing), `off`.
- `IMAGEN_CLOUD_SYNC=auto|on|off` overrides config.
- `--no-cloud` overrides everything for one call.
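
The precedence in the list above can be sketched as a small resolver. This is a hedged illustration, not the logic inside `maybeCloudSync`: `resolveCloudSync` and its parameters are hypothetical, and rejecting unknown mode strings is an assumption.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// resolveCloudSync decides whether to sync one call.
// Precedence: --no-cloud, then IMAGEN_CLOUD_SYNC (envMode), then the
// output.cloud_sync config value (cfgMode), then the "auto" default.
func resolveCloudSync(noCloud bool, envMode, cfgMode string, haveCreds, haveOwner bool) (bool, error) {
	if noCloud {
		return false, nil
	}
	mode := cfgMode
	if envMode != "" {
		mode = envMode
	}
	if mode == "" {
		mode = "auto"
	}
	ready := haveCreds && haveOwner
	switch strings.ToLower(mode) {
	case "off":
		return false, nil
	case "auto":
		return ready, nil
	case "on":
		if !ready {
			return false, errors.New("cloud_sync=on requires Supabase credentials and owner_user_id")
		}
		return true, nil
	default:
		// Assumption: unknown values are an error rather than silently "auto".
		return false, fmt.Errorf("unknown cloud sync mode %q", mode)
	}
}

func main() {
	on, err := resolveCloudSync(false, "", "auto", true, true)
	fmt.Println(on, err)
}
```

Keeping the decision in one pure function makes each precedence rule easy to check in isolation.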
Reuses the same Supabase env (`SUPABASE_URL` + `SUPABASE_SERVICE_KEY` or
`MAI_SUPABASE_KEY`) as cost-tracking. Service-role bypasses RLS for
inserts; the `owner_user_id = auth.uid()` policy on the table gates the
read path the flexsiebels viewer hits.
Failures (Storage 5xx, DB unreachable) emit `imagen: cloud sync: <err>`
to stderr and the local PNG + sidecar stay put. Exit code is unchanged.

go.mod

@@ -1,5 +1,16 @@
module mgit.msbls.de/m/ImaGen
go 1.24
go 1.25.0
require gopkg.in/yaml.v3 v3.0.1
require (
github.com/jackc/pgx/v5 v5.9.2
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
golang.org/x/text v0.29.0 // indirect
)

go.sum

@@ -1,4 +1,35 @@
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@@ -0,0 +1,567 @@
package backend
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"time"
)
// ReplicateType is the type-name adapters register under for Replicate
// instances.
const ReplicateType = "replicate"
// Replicate is the Replicate API adapter. It speaks the public REST API
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
// to submit, then polls /v1/predictions/{id} until the prediction
// settles, then downloads the produced image.
//
// Concurrency: a single Replicate is safe to share across goroutines.
type Replicate struct {
instance string
apiBase string
apiToken string
tokenEnv string
model string // "owner/name" or "owner/name:version-hash"
owner string
name string
version string // optional; empty means "use the model-based predictions endpoint"
defaultSteps int
defaultAspect string
httpClient *http.Client
pollInterval time.Duration
pollTimeout time.Duration
// Hooks for tests; production paths use the package-level defaults.
randSeed func() int64
initialBackoff time.Duration
// Sink is where successful generations are recorded for cost-tracking.
// nil means "do not record". The framework wires this up in the CLI;
// adapter tests inject a fake.
Sink UsageSink
}
// UsageSink writes one row per successful generation. Implementations
// should treat write failures as warnings, not errors — the image has
// already landed on disk; failing the call would lose the artefact.
type UsageSink interface {
Record(ctx context.Context, row UsageRow) error
}
// UsageRow is the cost-tracking row stored in mai.imagen_usage. Note the
// prompt itself is intentionally NOT included — only the sha256 hash.
type UsageRow struct {
Backend string
Model string
Seed *int64
PromptHash string
LatencyMs int
CostUSDEstimate *float64
Caller string
}
// NewReplicate is the registry constructor. cfg is the adapter's slice
// of imagen.yaml.
func NewReplicate(name string, cfg map[string]any) (Backend, error) {
if name == "" {
return nil, fmt.Errorf("replicate: empty instance name")
}
tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN")
model := getString(cfg, "model", "")
if model == "" {
return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name)
}
owner, modelName, version, err := parseModelRef(model)
if err != nil {
return nil, fmt.Errorf("replicate[%s]: %w", name, err)
}
apiBase := strings.TrimRight(getString(cfg, "api_base", "https://api.replicate.com"), "/")
pollTimeout := timeoutForModel(modelName)
r := &Replicate{
instance: name,
apiBase: apiBase,
tokenEnv: tokenEnv,
apiToken: os.Getenv(tokenEnv),
model: model,
owner: owner,
name: modelName,
version: version,
defaultSteps: getInt(cfg, "default_steps", 0),
defaultAspect: getString(cfg, "default_aspect_ratio", "1:1"),
httpClient: &http.Client{Timeout: 30 * time.Second},
pollInterval: 500 * time.Millisecond,
pollTimeout: pollTimeout,
randSeed: cryptoSeed,
initialBackoff: time.Second,
}
return r, nil
}
// Name returns the instance name from imagen.yaml.
func (r *Replicate) Name() string { return r.instance }
// Generate submits one prediction and returns the resulting PNG.
func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) {
if r.apiToken == "" {
return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv)
}
width := orDefaultInt(req.Width, 1024)
height := orDefaultInt(req.Height, 1024)
aspect := computeAspectRatio(width, height, r.defaultAspect)
steps := orDefaultInt(req.Steps, r.defaultSteps)
seed := req.Seed
if seed == 0 {
seed = r.randSeed()
}
input := map[string]any{
"prompt": req.Prompt,
"aspect_ratio": aspect,
"num_outputs": 1,
"output_format": "png",
"seed": seed,
}
if steps > 0 {
input["num_inference_steps"] = steps
}
if req.NegativePrompt != "" {
input["negative_prompt"] = req.NegativePrompt
}
maps.Copy(input, req.BackendOpts)
start := time.Now()
pred, err := r.submitPrediction(ctx, input)
if err != nil {
return nil, err
}
final, err := r.waitForCompletion(ctx, pred.ID)
if err != nil {
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
}
imgURL, err := pickFirstOutputURL(final.Output)
if err != nil {
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
}
imgBytes, mime, err := r.fetchImage(ctx, imgURL)
if err != nil {
return nil, err
}
latencyMs := int(time.Since(start).Milliseconds())
predictTime := final.Metrics.PredictTime
costEst, costKnown := replicatePerImageUSD(r.model)
meta := map[string]any{
"backend": r.instance,
"backend_type": ReplicateType,
"model": r.model,
"model_version": final.Version,
"prediction_id": pred.ID,
"seed": seed,
"width": width,
"height": height,
"aspect_ratio": aspect,
"latency_ms": int64(latencyMs),
}
if predictTime > 0 {
meta["predict_time_seconds"] = predictTime
}
if costKnown {
meta["cost_usd_estimate"] = costEst
}
if r.Sink != nil {
row := UsageRow{
Backend: r.instance,
Model: r.model,
PromptHash: hashPrompt(req.Prompt),
LatencyMs: latencyMs,
Caller: ResolveCaller(),
}
row.Seed = new(int64)
*row.Seed = seed
if costKnown {
c := costEst
row.CostUSDEstimate = &c
}
if err := r.Sink.Record(ctx, row); err != nil {
fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err)
}
}
return &Result{
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
MimeType: mime,
Metadata: meta,
}, nil
}
// replicatePrediction is what the REST API returns under /v1/predictions
// and /v1/models/{owner}/{name}/predictions. Only the fields we use are
// declared.
type replicatePrediction struct {
ID string `json:"id"`
Status string `json:"status"`
Version string `json:"version"`
Error json.RawMessage `json:"error"`
Output json.RawMessage `json:"output"`
Metrics struct {
PredictTime float64 `json:"predict_time"`
} `json:"metrics"`
}
// submitPrediction creates a prediction and returns it. Picks the
// model-based endpoint when no version was given (recommended for
// Replicate's official models), and the legacy /v1/predictions otherwise.
func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) {
var (
reqURL string
body []byte
err error
)
if r.version == "" {
reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions", r.apiBase, r.owner, r.name)
body, err = json.Marshal(map[string]any{"input": input})
} else {
reqURL = r.apiBase + "/v1/predictions"
body, err = json.Marshal(map[string]any{
"version": r.version,
"input": input,
})
}
if err != nil {
return nil, fmt.Errorf("replicate: marshal prediction body: %w", err)
}
respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body)
if err != nil {
return nil, err
}
var pred replicatePrediction
if err := json.Unmarshal(respBody, &pred); err != nil {
return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody))
}
if pred.ID == "" {
return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody))
}
return &pred, nil
}
// waitForCompletion polls /v1/predictions/{id} until the status is a
// terminal value (succeeded, failed, canceled) or the timeout fires.
func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) {
deadline := time.Now().Add(r.pollTimeout)
getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id)
start := time.Now()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if time.Now().After(deadline) {
return nil, fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond))
}
body, err := r.doWithRetry(ctx, http.MethodGet, getURL, nil)
if err != nil {
return nil, err
}
var pred replicatePrediction
if err := json.Unmarshal(body, &pred); err != nil {
return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body))
}
switch pred.Status {
case "succeeded":
return &pred, nil
case "failed":
return nil, fmt.Errorf("status=failed: %s", snip(pred.Error))
case "canceled":
return nil, fmt.Errorf("status=canceled")
case "starting", "processing", "":
default:
return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(r.pollInterval):
}
}
}
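// doWithRetry executes a Replicate API request and applies the resilience
// policy: 401 surfaces a clean message naming the env var; 429 retries
// with exponential backoff up to three times (honouring Retry-After);
// transport errors and 5xx responses retry once; other statuses surface
// immediately. The attempt counter is shared, so a retry of one kind
// counts against the budget of the others.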
func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
const max429Retries = 3
backoff := r.initialBackoff
if backoff <= 0 {
backoff = time.Second
}
var lastErr error
for attempt := 0; ; attempt++ {
req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Token "+r.apiToken)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := r.httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
if attempt >= 1 {
return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err)
}
lastErr = err
if !sleepCtx(ctx, backoff) {
return nil, ctx.Err()
}
continue
}
respBody, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
switch {
case resp.StatusCode >= 200 && resp.StatusCode < 300:
return respBody, nil
case resp.StatusCode == http.StatusUnauthorized:
return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, resp.StatusCode, r.tokenEnv)
case resp.StatusCode == http.StatusTooManyRequests:
if attempt >= max429Retries {
return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody))
}
wait := backoffFor429(resp.Header.Get("Retry-After"), backoff)
lastErr = fmt.Errorf("429 (retry %d): %s", attempt+1, snip(respBody))
if !sleepCtx(ctx, wait) {
return nil, ctx.Err()
}
backoff *= 2
continue
case resp.StatusCode >= 500:
if attempt >= 1 {
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
}
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(respBody))
if !sleepCtx(ctx, backoff) {
return nil, ctx.Err()
}
continue
default:
_ = lastErr
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
}
}
}
// fetchImage downloads the rendered image from the Replicate-provided
// CDN URL. One retry on a network error, a body-read error, or a 5xx
// response; other non-2xx statuses fail immediately.
func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) {
var lastErr error
for attempt := range 2 {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil)
if err != nil {
return nil, "", err
}
resp, err := r.httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, "", err
}
lastErr = err
continue
}
body, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
lastErr = readErr
continue
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if resp.StatusCode >= 500 && attempt == 0 {
lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body))
continue
}
return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body))
}
mime := resp.Header.Get("Content-Type")
if mime == "" {
mime = "image/png"
}
return body, mime, nil
}
return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr)
}
// parseModelRef accepts "owner/name" or "owner/name:version-hash".
func parseModelRef(ref string) (owner, name, version string, err error) {
rest := ref
if i := strings.IndexByte(rest, ':'); i >= 0 {
version = rest[i+1:]
rest = rest[:i]
}
parts := strings.SplitN(rest, "/", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
}
return parts[0], parts[1], version, nil
}
// timeoutForModel picks the polling timeout. FLUX dev takes notably longer
// than schnell; everything else gets the dev timeout for safety.
func timeoutForModel(name string) time.Duration {
switch strings.ToLower(name) {
case "flux-schnell":
return 60 * time.Second
default:
return 120 * time.Second
}
}
// pickFirstOutputURL extracts the first image URL from the output field.
// Replicate returns either a single string or an array of strings for image
// models — we accept both.
func pickFirstOutputURL(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", fmt.Errorf("output is empty")
}
var s string
if err := json.Unmarshal(raw, &s); err == nil && s != "" {
return s, nil
}
var arr []string
if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 && arr[0] != "" {
return arr[0], nil
}
return "", fmt.Errorf("output is not a string or non-empty string array (got: %s)", snip(raw))
}
// computeAspectRatio reduces width:height to a Replicate-supported aspect
// ratio when the reduction lands on one of the canonical values; otherwise
// returns fallback.
func computeAspectRatio(w, h int, fallback string) string {
if w <= 0 || h <= 0 {
return fallback
}
g := gcd(w, h)
a, b := w/g, h/g
s := fmt.Sprintf("%d:%d", a, b)
if isReplicateAspectRatio(s) {
return s
}
return fallback
}
func isReplicateAspectRatio(s string) bool {
switch s {
case "1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21":
return true
}
return false
}
func gcd(a, b int) int {
for b != 0 {
a, b = b, a%b
}
if a < 0 {
return -a
}
return a
}
// hashPrompt returns the sha256 hex digest of the prompt. The raw prompt
// is intentionally never written to the cost-tracking table.
func hashPrompt(p string) string {
sum := sha256.Sum256([]byte(p))
return hex.EncodeToString(sum[:])
}
// ResolveCaller returns the agent identity the cost-tracking row is
// attributed to. Order of resolution mirrors the maimcp identity logic:
// MAI_FROM_ID env var first, then the tmux pane's @mai-name option, then
// "unknown".
func ResolveCaller() string {
if v := strings.TrimSpace(os.Getenv("MAI_FROM_ID")); v != "" {
return v
}
if pane := os.Getenv("TMUX_PANE"); pane != "" {
out, err := exec.Command("tmux", "display-message", "-p", "-t", pane, "#{@mai-name}").Output()
if err == nil {
if name := strings.TrimSpace(string(out)); name != "" {
return name
}
}
}
return "unknown"
}
// backoffFor429 honours a Retry-After header (in seconds) when present
// and within reason, otherwise falls back to the caller's backoff.
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
if retryAfter == "" {
return fallback
}
d, err := time.ParseDuration(retryAfter + "s")
if err != nil || d <= 0 {
return fallback
}
if d > 30*time.Second {
return 30 * time.Second
}
return d
}
func sleepCtx(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(d):
return true
}
}
func bytesReader(b []byte) io.Reader {
if b == nil {
return nil
}
return bytes.NewReader(b)
}
// shortPath strips the host so error messages don't leak the API base.
func shortPath(u string) string {
if i := strings.Index(u, "/v1/"); i >= 0 {
return u[i:]
}
return u
}
func init() {
Register(ReplicateType, NewReplicate)
}
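
The aspect-ratio reduction and the Retry-After handling above can be tried in isolation. The following is a minimal sketch, not the adapter's code: `reduceRatio` and `retryAfterOr` are hypothetical stand-ins for `computeAspectRatio` and `backoffFor429`, and the allowed-ratio set and the ceiling are passed in as parameters purely for illustration.

```go
package main

import (
	"fmt"
	"time"
)

// gcd returns the non-negative greatest common divisor (Euclid's algorithm).
func gcd(a, b int) int {
	for b != 0 {
		a, b = b, a%b
	}
	if a < 0 {
		return -a
	}
	return a
}

// reduceRatio (hypothetical name) reduces w:h to lowest terms and returns it
// only when the reduced form is in the allowed set; otherwise the fallback.
func reduceRatio(w, h int, allowed map[string]bool, fallback string) string {
	if w <= 0 || h <= 0 {
		return fallback
	}
	g := gcd(w, h)
	s := fmt.Sprintf("%d:%d", w/g, h/g)
	if allowed[s] {
		return s
	}
	return fallback
}

// retryAfterOr (hypothetical name) treats the header as a number of seconds.
// Empty, unparseable (for example an HTTP-date) or non-positive values yield
// the fallback; values above the ceiling are clamped to it.
func retryAfterOr(header string, fallback, ceiling time.Duration) time.Duration {
	if header == "" {
		return fallback
	}
	d, err := time.ParseDuration(header + "s")
	if err != nil || d <= 0 {
		return fallback
	}
	if d > ceiling {
		return ceiling
	}
	return d
}

func main() {
	allowed := map[string]bool{"1:1": true, "16:9": true, "4:3": true}
	fmt.Println(reduceRatio(1920, 1080, allowed, "1:1")) // 16:9
	fmt.Println(reduceRatio(1000, 1234, allowed, "1:1")) // 500:617 is not allowed, so the fallback
	fmt.Println(retryAfterOr("7", time.Second, 30*time.Second))
	fmt.Println(retryAfterOr("Wed, 21 Oct 2015 07:28:00 GMT", time.Second, 30*time.Second))
}
```

Note that a Retry-After value given as an HTTP-date is not parsed here; it simply falls through to the caller's backoff, which matches the behaviour of the seconds-only parser above.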


@@ -0,0 +1,42 @@
package backend
import "strings"
// Replicate pricing snapshot.
//
// Source: https://replicate.com/pricing and the per-model "Run" tab on
// each model page. Replicate bills per second of GPU time, but the
// black-forest-labs FLUX models also publish a flat per-image price for
// the typical settings — that flat number is what we hardcode here.
//
// Snapshot date: 2026-05-08. TODO(refresh): re-check quarterly. If the
// rates drift more than ~10%, update the table and bump
// replicatePricingSnapshotDate.
const replicatePricingSnapshotDate = "2026-05-08"
// replicatePerImageUSD is the per-image cost estimate keyed by Replicate
// model identifier ("owner/name", with any ":version" trimmed). Returns
// the rate and true if the model is known, 0 and false otherwise — an
// unknown model writes a row with NULL cost rather than a wrong number.
func replicatePerImageUSD(model string) (float64, bool) {
key := normalisePricingKey(model)
switch key {
case "black-forest-labs/flux-schnell":
return 0.003, true
case "black-forest-labs/flux-dev":
return 0.025, true
case "black-forest-labs/flux-pro":
return 0.055, true
case "black-forest-labs/flux-1.1-pro":
return 0.040, true
}
return 0, false
}
// normalisePricingKey strips the optional ":version" suffix and lowercases
// the owner/name pair. "Owner/Name:hash" → "owner/name".
func normalisePricingKey(model string) string {
if i := strings.IndexByte(model, ':'); i >= 0 {
model = model[:i]
}
return strings.ToLower(model)
}
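
How a caller might consume a "rate, known" lookup of this kind is sketched below. This is an illustration under assumptions: the map-based table, `normaliseKey`, and `estimate` (including its `images` multiplier) are hypothetical names, and the two rates are copied from the snapshot above rather than from any authoritative source.

```go
package main

import (
	"fmt"
	"strings"
)

// perImageUSD is an illustrative rate table keyed by lowercase "owner/name".
var perImageUSD = map[string]float64{
	"black-forest-labs/flux-schnell": 0.003,
	"black-forest-labs/flux-dev":     0.025,
}

// normaliseKey drops an optional ":version" suffix and lowercases the rest.
func normaliseKey(model string) string {
	if i := strings.IndexByte(model, ':'); i >= 0 {
		model = model[:i]
	}
	return strings.ToLower(model)
}

// estimate (hypothetical helper) returns nil for unknown models so a caller
// can store NULL instead of a made-up number.
func estimate(model string, images int) *float64 {
	rate, ok := perImageUSD[normaliseKey(model)]
	if !ok {
		return nil
	}
	total := rate * float64(images)
	return &total
}

func main() {
	for _, m := range []string{"Black-Forest-Labs/FLUX-Dev:abc123", "stability-ai/sdxl"} {
		if c := estimate(m, 4); c != nil {
			fmt.Printf("%s: ~$%.3f for 4 images\n", m, *c)
		} else {
			fmt.Printf("%s: cost unknown\n", m)
		}
	}
}
```

Returning a pointer keeps "unknown" distinct from "free", which is the same distinction the `(float64, bool)` pair draws.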


@@ -0,0 +1,675 @@
package backend
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// fakeReplicate is a programmable mock of the Replicate REST API.
type fakeReplicate struct {
t *testing.T
mu sync.Mutex
// Response for the prediction-create endpoints. The model-based path
// (/v1/models/{owner}/{name}/predictions) and the version-pinned path
// (/v1/predictions) share it; createCalls counts hits on either.
createStatus int
createBody []byte
createCalls atomic.Int32
// Auth-fail policy: if true, every request to the prediction endpoints
// returns 401 with a stock body.
auth401 bool
// 429-then-OK policy: first N create calls return 429.
create429Until int32
retryAfter string
// Sequence of /predictions/{id} responses, walked in order; once
// exhausted, the last entry is returned indefinitely.
pollResponses []string
pollIdx atomic.Int32
pollCalls atomic.Int32
// Image download policy.
imageStatus int
imageBody []byte
image5xxFirst int32
imageCalls atomic.Int32
server *httptest.Server
}
func (f *fakeReplicate) handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"):
f.handleCreate(w, r)
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
f.handleCreate(w, r)
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
f.handlePoll(w, r)
case r.Method == http.MethodGet && r.URL.Path == "/img":
f.handleImage(w, r)
default:
f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path)
http.NotFound(w, r)
}
})
}
func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) {
n := f.createCalls.Add(1)
if f.auth401 {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"detail":"Invalid token"}`))
return
}
if n <= f.create429Until {
if f.retryAfter != "" {
w.Header().Set("Retry-After", f.retryAfter)
}
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"detail":"too many requests"}`))
return
}
w.WriteHeader(f.createStatus)
_, _ = w.Write(f.createBody)
}
func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) {
f.pollCalls.Add(1)
idx := int(f.pollIdx.Add(1)) - 1
if idx >= len(f.pollResponses) {
idx = len(f.pollResponses) - 1
}
if idx < 0 {
http.Error(w, "no poll response configured", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(f.pollResponses[idx]))
}
func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) {
n := f.imageCalls.Add(1)
if n <= f.image5xxFirst {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte("upstream unavailable"))
return
}
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(f.imageStatus)
_, _ = w.Write(f.imageBody)
}
func (f *fakeReplicate) start() {
f.server = httptest.NewServer(f.handler())
f.t.Cleanup(f.server.Close)
}
func (f *fakeReplicate) imageURL() string { return f.server.URL + "/img" }
func newFakeReplicate(t *testing.T) *fakeReplicate {
t.Helper()
f := &fakeReplicate{
t: t,
createStatus: http.StatusCreated,
imageStatus: http.StatusOK,
imageBody: mustPNG(t, 8, 8),
}
f.start()
// Default happy-path responses now that the server URL is known.
f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`)
f.pollResponses = []string{
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
`{"id":"pred-abc","status":"processing","version":"v1","output":null}`,
fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()),
}
return f
}
func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate {
t.Helper()
be, err := NewReplicate("flux-test", map[string]any{
"api_token_env": "TEST_REPLICATE_TOKEN",
"model": model,
"api_base": f.server.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake-token"
r.pollInterval = time.Millisecond
r.pollTimeout = 5 * time.Second
r.initialBackoff = time.Millisecond
r.randSeed = func() int64 { return 42 }
return r
}
func TestReplicateConstructorRejectsBadInputs(t *testing.T) {
if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil {
t.Errorf("expected error for empty instance name")
}
if _, err := NewReplicate("x", map[string]any{}); err == nil {
t.Errorf("expected error for missing model")
}
if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil {
t.Errorf("expected error for malformed model")
}
}
func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.apiToken = "" // simulate the env var being unset
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error when token is missing")
}
if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") {
t.Errorf("error should name the env var: %v", err)
}
}
func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) {
f := newFakeReplicate(t)
sink := &recordingSink{}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.Sink = sink
res, err := r.Generate(context.Background(), Request{
Prompt: "a tiny dragon",
Width: 1024, Height: 1024,
Seed: 0,
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
defer res.ImageReader.Close()
body, _ := io.ReadAll(res.ImageReader)
if !bytes.Equal(body, f.imageBody) {
t.Errorf("image body did not round-trip")
}
if mime := res.MimeType; mime != "image/png" {
t.Errorf("mime = %q", mime)
}
if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" {
t.Errorf("metadata model = %v", got)
}
if got, _ := res.Metadata["model_version"].(string); got != "v1" {
t.Errorf("metadata model_version = %v", got)
}
if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 {
t.Errorf("metadata predict_time_seconds = %v", got)
}
if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 {
t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok)
}
if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" {
t.Errorf("aspect_ratio = %q", got)
}
if got := f.pollCalls.Load(); got < 3 {
t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got)
}
if len(sink.rows) != 1 {
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
}
row := sink.rows[0]
if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" {
t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model)
}
if row.PromptHash == "" || row.PromptHash == "a tiny dragon" {
t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash)
}
if len(row.PromptHash) != 64 {
t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash))
}
if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 {
t.Errorf("sink row cost = %v", row.CostUSDEstimate)
}
if row.LatencyMs <= 0 {
t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs)
}
}
func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) {
f := newFakeReplicate(t)
var sentBody []byte
var sentPath string
var mu sync.Mutex
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
mu.Lock()
sentBody, _ = io.ReadAll(r.Body)
sentPath = r.URL.Path
mu.Unlock()
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`))
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
_, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL())
default:
f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-dev:abc123",
"api_base": srv.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
res, err := r.Generate(context.Background(), Request{Prompt: "x"})
if err != nil {
t.Fatalf("Generate: %v", err)
}
res.ImageReader.Close()
mu.Lock()
defer mu.Unlock()
if sentPath != "/v1/predictions" {
t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath)
}
var body map[string]any
if err := json.Unmarshal(sentBody, &body); err != nil {
t.Fatalf("unmarshal sent body: %v", err)
}
if body["version"] != "abc123" {
t.Errorf("expected version=abc123 in body, got %v", body["version"])
}
}
func TestReplicate401SurfacesEnvHint(t *testing.T) {
f := newFakeReplicate(t)
f.auth401 = true
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected 401 to surface as error")
}
msg := err.Error()
if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") {
t.Errorf("error should name env var: %v", err)
}
if !strings.Contains(msg, "401") {
t.Errorf("error should mention 401: %v", err)
}
}
func TestReplicate429RetriesThenSucceeds(t *testing.T) {
f := newFakeReplicate(t)
f.create429Until = 2 // first two calls 429 then succeed
f.retryAfter = "" // force the adapter's exp-backoff path
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = time.Millisecond
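// newReplicate already shrinks initialBackoff to 1ms; the explicit client
// timeout just bounds each HTTP call so a stuck request cannot hang the test.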
r.httpClient = &http.Client{Timeout: 5 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := r.Generate(ctx, Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate after 429s: %v", err)
}
res.ImageReader.Close()
if got := f.createCalls.Load(); got != 3 {
t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got)
}
}
func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) {
f := newFakeReplicate(t)
f.create429Until = 99 // every call 429
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
_, err := r.Generate(ctx, Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error after sustained 429s")
}
if !strings.Contains(err.Error(), "429") {
t.Errorf("expected 429 in error: %v", err)
}
// max429Retries=3 → 1 initial + 3 retries = 4 total
if got := f.createCalls.Load(); got != 4 {
t.Errorf("expected 4 create calls (1+3 retries), got %d", got)
}
}
func TestReplicateFailedPredictionSurfacesError(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
`{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`,
}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error for failed prediction")
}
if !strings.Contains(err.Error(), "failed") {
t.Errorf("error should mention failure: %v", err)
}
if !strings.Contains(err.Error(), "NSFW") {
t.Errorf("error should include the API error message: %v", err)
}
}
func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) {
f := newFakeReplicate(t)
// Always return processing → adapter times out.
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = 5 * time.Millisecond
r.pollTimeout = 30 * time.Millisecond
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected timeout error")
}
if !strings.Contains(err.Error(), "did not complete") {
t.Errorf("expected 'did not complete' in error, got %v", err)
}
if !strings.Contains(err.Error(), "waited") {
t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err)
}
}
func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) {
f := newFakeReplicate(t)
f.image5xxFirst = 1 // first download 502, second OK
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (download retry): %v", err)
}
res.ImageReader.Close()
if got := f.imageCalls.Load(); got != 2 {
t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got)
}
}
func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) {
f := newFakeReplicate(t)
f.image5xxFirst = 99 // every download fails
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error after sustained image-download 5xx")
}
if got := f.imageCalls.Load(); got != 2 {
t.Errorf("expected 2 image fetches (no further retries), got %d", got)
}
}
func TestReplicateContextCancelStopsPolling(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = 5 * time.Millisecond
r.pollTimeout = 5 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
_, err := r.Generate(ctx, Request{Prompt: "p"})
if err == nil {
t.Fatal("expected ctx error")
}
if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") {
t.Errorf("expected deadline exceeded, got %v", err)
}
}
func TestReplicateBackendOptsMergedIntoInput(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
body, _ := io.ReadAll(r.Body)
var top struct {
Input map[string]any `json:"input"`
}
_ = json.Unmarshal(body, &top)
captured = top.Input
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-schnell",
"api_base": srv.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
// We expect this to fail at "pickFirstOutputURL" (empty output) but the
// captured input should still have been recorded.
_, _ = r.Generate(context.Background(), Request{
Prompt: "p",
Width: 1024, Height: 1024,
BackendOpts: map[string]any{
"output_quality": 90,
"go_fast": true,
},
})
if captured == nil {
t.Fatal("create endpoint not hit")
}
if captured["output_quality"] != float64(90) {
t.Errorf("output_quality not threaded: %v", captured["output_quality"])
}
if captured["go_fast"] != true {
t.Errorf("go_fast not threaded: %v", captured["go_fast"])
}
if captured["prompt"] != "p" {
t.Errorf("prompt not threaded: %v", captured["prompt"])
}
}
func TestReplicateOutputArrayAccepted(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{
fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()),
}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (output as array): %v", err)
}
res.ImageReader.Close()
}
func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "stability-ai/sdxl")
sink := &recordingSink{}
r.Sink = sink
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (unknown model): %v", err)
}
res.ImageReader.Close()
if _, present := res.Metadata["cost_usd_estimate"]; present {
t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"])
}
if len(sink.rows) != 1 {
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
}
if sink.rows[0].CostUSDEstimate != nil {
t.Errorf("expected nil cost in sink row for unknown model")
}
}
func TestReplicateSinkFailureIsWarningNotError(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") })
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("sink failure should not fail Generate: %v", err)
}
res.ImageReader.Close()
}
func TestReplicateDefaultStepsApplied(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
body, _ := io.ReadAll(r.Body)
var top struct {
Input map[string]any `json:"input"`
}
_ = json.Unmarshal(body, &top)
captured = top.Input
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-dev",
"api_base": srv.URL,
"default_steps": 28,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
_, _ = r.Generate(context.Background(), Request{Prompt: "p"})
if captured == nil {
t.Fatal("create endpoint not hit")
}
if captured["num_inference_steps"] != float64(28) {
t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"])
}
}
func TestComputeAspectRatio(t *testing.T) {
cases := []struct {
w, h int
fallback string
want string
}{
{1024, 1024, "1:1", "1:1"},
{1920, 1080, "1:1", "16:9"},
{2560, 1440, "1:1", "16:9"},
{1024, 768, "1:1", "4:3"},
{1024, 1280, "1:1", "4:5"},
{1000, 1234, "1:1", "1:1"}, // reduces to 500:617, which is not a supported ratio, so it falls back
{0, 1024, "1:1", "1:1"},
}
for _, c := range cases {
got := computeAspectRatio(c.w, c.h, c.fallback)
if got != c.want {
t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want)
}
}
}
func TestParseModelRef(t *testing.T) {
owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell")
if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" {
t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err)
}
owner, name, ver, err = parseModelRef("owner/name:hash123")
if err != nil || owner != "owner" || name != "name" || ver != "hash123" {
t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err)
}
if _, _, _, err := parseModelRef("noslash"); err == nil {
t.Errorf("expected error for malformed ref")
}
}
func TestHashPromptStable(t *testing.T) {
a := hashPrompt("hello")
b := hashPrompt("hello")
c := hashPrompt("hello!")
if a != b {
t.Errorf("hashPrompt should be deterministic")
}
if a == c {
t.Errorf("different prompts should hash differently")
}
if len(a) != 64 {
t.Errorf("sha256 hex should be 64 chars, got %d", len(a))
}
}
func TestReplicatePricingKnownModels(t *testing.T) {
if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 {
t.Errorf("schnell rate = %v (ok=%v)", v, ok)
}
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 {
t.Errorf("dev rate = %v (ok=%v)", v, ok)
}
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 {
t.Errorf("versioned ref should resolve to base price: %v %v", v, ok)
}
if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok {
t.Errorf("unknown model should report ok=false")
}
}
func TestReplicateTypeIsRegistered(t *testing.T) {
if !Default.Has(ReplicateType) {
t.Errorf("replicate type not registered in Default")
}
}
// recordingSink captures rows for assertion.
type recordingSink struct {
mu sync.Mutex
rows []UsageRow
}
func (s *recordingSink) Record(_ context.Context, row UsageRow) error {
s.mu.Lock()
defer s.mu.Unlock()
s.rows = append(s.rows, row)
return nil
}
type sinkFunc func(context.Context, UsageRow) error
func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) }
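
The "first N calls return 429, then succeed" policy used by the fake above can be reduced to a few lines. The sketch below is a standalone illustration, not the test suite's code: `flakyServer` and `getWithRetry` are hypothetical names, and the retry loop is a simplified stand-in for whatever the adapter does internally.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"time"
)

// flakyServer answers the first failN requests with 429 and every later
// request with 200. It returns the server and a counter of requests seen.
func flakyServer(failN int32) (*httptest.Server, *atomic.Int32) {
	calls := new(atomic.Int32)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if calls.Add(1) <= failN {
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	return srv, calls
}

// getWithRetry retries on 429 up to maxRetries times, doubling the wait
// between attempts. It returns the last status code it saw.
func getWithRetry(url string, maxRetries int, backoff time.Duration) (int, error) {
	for attempt := 0; ; attempt++ {
		resp, err := http.Get(url)
		if err != nil {
			return 0, err
		}
		resp.Body.Close()
		if resp.StatusCode != http.StatusTooManyRequests || attempt == maxRetries {
			return resp.StatusCode, nil
		}
		time.Sleep(backoff)
		backoff *= 2
	}
}

func main() {
	srv, calls := flakyServer(2)
	defer srv.Close()
	status, err := getWithRetry(srv.URL, 3, time.Millisecond)
	fmt.Println(status, err, calls.Load()) // 200 <nil> 3
}
```

Counting requests on the server side, rather than in the client, is what lets a test assert "1 initial + N retries" without reaching into the code under test.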

internal/cloud/cloud.go Normal file

@@ -0,0 +1,356 @@
// Package cloud syncs a generated image to Supabase Storage and inserts
// a row into imagen.images. Both steps are best-effort: callers log the
// returned error and proceed, because the local PNG + sidecar are already
// on disk by the time Sync runs and a cloud blip should not lose the
// artefact.
//
// The single source of truth for the row schema is the imagen_schema_init
// migration — see internal docs in the issue body for #7.
package cloud
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"
)
// supabaseSchema is the PostgREST profile header value the imagen schema
// is exposed under (see ALTER ROLE authenticator SET pgrst.db_schemas).
const supabaseSchema = "imagen"
// bucketName is the Supabase Storage bucket all generated images land in.
const bucketName = "imagen-generated"
// Sink writes one PNG + one row per generation. It is safe to share
// across goroutines.
type Sink struct {
// URL is SUPABASE_URL — e.g. https://supa.flexsiebels.de.
URL string
// APIKey is the service-role key (SUPABASE_SERVICE_KEY). Storage uploads
// and DB inserts both bypass RLS with this key — the policies on the
// table + bucket are the contract for the read side.
APIKey string
// OwnerUserID is m's auth.users.id. It populates owner_user_id on every
// row. Empty means the sink refuses to insert (the column is NOT NULL
// and the user-mode reader needs it for the RLS policy).
OwnerUserID string
// HTTP is the http client; tests inject one pointing at httptest.
HTTP *http.Client
// MaxRetries is the number of additional attempts after the first
// failure for retryable (5xx) responses. Zero means single-shot.
MaxRetries int
// InitialBackoff is the wait before the first retry; doubles per attempt.
// Set very small in tests.
InitialBackoff time.Duration
}
// NewFromEnv returns a sink populated from SUPABASE_URL +
// SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) + IMAGEN_OWNER_USER_ID.
// Returns ok=false if the URL or key is missing; the caller treats that
// as "cloud-sync disabled by environment".
func NewFromEnv() (*Sink, bool) {
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
if u == "" {
return nil, false
}
key := os.Getenv("SUPABASE_SERVICE_KEY")
if key == "" {
key = os.Getenv("MAI_SUPABASE_KEY")
}
if key == "" {
return nil, false
}
return &Sink{
URL: u,
APIKey: key,
OwnerUserID: os.Getenv("IMAGEN_OWNER_USER_ID"),
HTTP: &http.Client{Timeout: 30 * time.Second},
MaxRetries: 2,
InitialBackoff: time.Second,
}, true
}
// SyncRequest is the cross-backend ingredient set Sync needs. Date is
// formatted as YYYY-MM-DD; Slug + Seed are reused from the local
// filename so storage_path mirrors disk layout.
type SyncRequest struct {
Date string
Slug string
Seed int64
Ext string // "png", "jpg", "webp" — no leading dot
PNG []byte
MimeType string
Prompt string
Backend string
Model string
Steps int
Width int
Height int
LatencyMs int
CostUSDEstimate *float64
Sidecar map[string]any
}
// SyncResult tells the caller what landed where.
type SyncResult struct {
StoragePath string // e.g. "2026-05-11/lighthouse-42.png"
ImageID string // imagen.images.id (UUID)
}
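// Sync uploads the bytes and inserts the metadata row. Returns the row's
// id and storage_path on success. If the upload succeeds but the insert
// fails, the result still carries StoragePath (with an empty ImageID)
// alongside the error. The caller surfaces any non-nil error as
// "imagen: cloud sync: <err>" and otherwise ignores it.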
func (s *Sink) Sync(ctx context.Context, req SyncRequest) (*SyncResult, error) {
if s == nil {
return nil, fmt.Errorf("cloud sink not configured")
}
if s.OwnerUserID == "" {
return nil, fmt.Errorf("owner_user_id not set (config or $IMAGEN_OWNER_USER_ID); refusing to insert NULL into imagen.images")
}
if req.Date == "" || req.Slug == "" {
return nil, fmt.Errorf("date and slug are required for storage_path")
}
ext := req.Ext
if ext == "" {
ext = "png"
}
storagePath := fmt.Sprintf("%s/%s-%d.%s", req.Date, req.Slug, req.Seed, ext)
if err := s.upload(ctx, storagePath, req.PNG, req.MimeType); err != nil {
return nil, fmt.Errorf("storage upload: %w", err)
}
id, err := s.insertRow(ctx, storagePath, req)
if err != nil {
return &SyncResult{StoragePath: storagePath}, fmt.Errorf("db insert: %w", err)
}
return &SyncResult{StoragePath: storagePath, ImageID: id}, nil
}
// upload PUTs the image bytes into the imagen-generated bucket. We set
// Content-Type so signed URLs render in the browser without a download
// prompt. POST would error on second-write; PUT (with x-upsert: true) is
// idempotent for re-runs of the same date+slug+seed.
func (s *Sink) upload(ctx context.Context, storagePath string, body []byte, mime string) error {
if mime == "" {
mime = "image/png"
}
endpoint := fmt.Sprintf("%s/storage/v1/object/%s/%s", s.URL, bucketName, pathEscape(storagePath))
return s.doRetry(ctx, func(ctx context.Context) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", mime)
req.Header.Set("x-upsert", "true")
return s.HTTP.Do(req)
})
}
// insertRow POSTs to PostgREST against the imagen schema. Prefer:
// return=representation gives us the inserted id back without a second
// round-trip.
func (s *Sink) insertRow(ctx context.Context, storagePath string, req SyncRequest) (string, error) {
row := map[string]any{
"owner_user_id": s.OwnerUserID,
"prompt": req.Prompt,
"prompt_hash": hashPrompt(req.Prompt),
"backend": req.Backend,
"storage_path": storagePath,
}
if req.Model != "" {
row["model"] = req.Model
}
if req.Seed != 0 {
row["seed"] = req.Seed
}
if req.Steps != 0 {
row["steps"] = req.Steps
}
if req.Width != 0 {
row["width"] = req.Width
}
if req.Height != 0 {
row["height"] = req.Height
}
if req.LatencyMs != 0 {
row["latency_ms"] = req.LatencyMs
}
if req.CostUSDEstimate != nil {
row["cost_usd_estimate"] = *req.CostUSDEstimate
}
if len(req.Sidecar) > 0 {
row["sidecar"] = req.Sidecar
}
body, err := json.Marshal(row)
if err != nil {
return "", fmt.Errorf("marshal row: %w", err)
}
endpoint := s.URL + "/rest/v1/images"
respBody, err := s.doRetryRead(ctx, func(ctx context.Context) (*http.Response, error) {
// Named httpReq so it does not shadow the SyncRequest parameter req.
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("apikey", s.APIKey)
httpReq.Header.Set("Authorization", "Bearer "+s.APIKey)
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept-Profile", supabaseSchema)
httpReq.Header.Set("Content-Profile", supabaseSchema)
httpReq.Header.Set("Prefer", "return=representation")
return s.HTTP.Do(httpReq)
})
if err != nil {
return "", err
}
var rows []struct {
ID string `json:"id"`
}
if err := json.Unmarshal(respBody, &rows); err != nil {
return "", fmt.Errorf("parse insert response: %w (body: %s)", err, snip(respBody))
}
if len(rows) == 0 {
return "", fmt.Errorf("insert returned 0 rows (body: %s)", snip(respBody))
}
return rows[0].ID, nil
}
// SignedURL asks the Storage API for a time-limited URL. ttlSeconds is
// the validity window. Returned URL is host-qualified and ready to hand
// to a browser.
func (s *Sink) SignedURL(ctx context.Context, storagePath string, ttlSeconds int) (string, error) {
if s == nil {
return "", fmt.Errorf("cloud sink not configured")
}
if ttlSeconds <= 0 {
ttlSeconds = 3600
}
endpoint := fmt.Sprintf("%s/storage/v1/object/sign/%s/%s", s.URL, bucketName, pathEscape(storagePath))
body, err := json.Marshal(map[string]any{"expiresIn": ttlSeconds})
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
resp, err := s.HTTP.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read sign response: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("sign %d: %s", resp.StatusCode, snip(respBody))
}
var parsed struct {
SignedURL string `json:"signedURL"`
}
if err := json.Unmarshal(respBody, &parsed); err != nil {
return "", fmt.Errorf("parse sign response: %w (body: %s)", err, snip(respBody))
}
if parsed.SignedURL == "" {
return "", fmt.Errorf("empty signedURL in response: %s", snip(respBody))
}
full := parsed.SignedURL
if strings.HasPrefix(full, "/") {
full = s.URL + full
}
return full, nil
}
// doRetry runs op up to MaxRetries+1 times. 5xx and transport errors are
// retried with exponential backoff; 4xx surfaces immediately as a
// permanent error (caller's bug in the row, not a network blip).
func (s *Sink) doRetry(ctx context.Context, op func(context.Context) (*http.Response, error)) error {
_, err := s.doRetryRead(ctx, op)
return err
}
// doRetryRead is the read-the-body variant. Returns the 2xx response
// body bytes; non-2xx is wrapped in an error. Same retry semantics as
// doRetry: 5xx/transport retries with exponential backoff, 4xx is fatal.
func (s *Sink) doRetryRead(ctx context.Context, op func(context.Context) (*http.Response, error)) ([]byte, error) {
backoff := s.InitialBackoff
if backoff == 0 {
backoff = time.Second
}
attempts := s.MaxRetries + 1
if attempts < 1 {
attempts = 1
}
var lastErr error
for i := 0; i < attempts; i++ {
if i > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(backoff):
}
backoff *= 2
}
resp, err := op(ctx)
if err != nil {
lastErr = err
continue
}
body, readErr := io.ReadAll(resp.Body)
resp.Body.Close()
if readErr != nil {
lastErr = fmt.Errorf("read body: %w", readErr)
continue
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return body, nil
}
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, snip(body))
}
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(body))
}
return nil, lastErr
}
func hashPrompt(p string) string {
sum := sha256.Sum256([]byte(p))
return hex.EncodeToString(sum[:])
}
// pathEscape encodes each path segment but keeps the slashes — the
// Storage API treats the part after the bucket name as a virtual file
// path with directory separators.
func pathEscape(p string) string {
parts := strings.Split(p, "/")
for i, seg := range parts {
parts[i] = url.PathEscape(seg)
}
return path.Join(parts...)
}
func snip(b []byte) string {
const max = 500
s := strings.TrimSpace(string(b))
if len(s) > max {
s = s[:max] + "..."
}
return s
}


@@ -0,0 +1,326 @@
package cloud
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"
)
// fakeSupabase is a tiny stand-in for Supabase Storage + PostgREST. It
// records what came in and returns canned responses based on path.
type fakeSupabase struct {
t *testing.T
mux *http.ServeMux
server *httptest.Server
uploadCalls int32
insertCalls int32
uploadBytes []byte
uploadHdr http.Header
insertBody []byte
insertHdr http.Header
}
func newFakeSupabase(t *testing.T, opts ...func(*fakeSupabase)) *fakeSupabase {
f := &fakeSupabase{t: t}
f.mux = http.NewServeMux()
// Storage upload — anything under /storage/v1/object/<bucket>/...
f.mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&f.uploadCalls, 1)
body, _ := io.ReadAll(r.Body)
f.uploadBytes = body
f.uploadHdr = r.Header.Clone()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"Key":"imagen-generated/somepath"}`))
})
// Storage sign URL
f.mux.HandleFunc("/storage/v1/object/sign/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"signedURL":"/storage/v1/object/sign/imagen-generated/some.png?token=abc"}`))
})
// PostgREST insert
f.mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&f.insertCalls, 1)
body, _ := io.ReadAll(r.Body)
f.insertBody = body
f.insertHdr = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`[{"id":"00000000-0000-0000-0000-000000000abc"}]`))
})
for _, opt := range opts {
opt(f)
}
f.server = httptest.NewServer(f.mux)
t.Cleanup(f.server.Close)
return f
}
func newSink(server *httptest.Server) *Sink {
return &Sink{
URL: server.URL,
APIKey: "fake-service-key",
OwnerUserID: "00000000-0000-0000-0000-000000000001",
HTTP: server.Client(),
MaxRetries: 2,
InitialBackoff: time.Millisecond,
}
}
func TestSyncHappyPath(t *testing.T) {
f := newFakeSupabase(t)
s := newSink(f.server)
cost := 0.003
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11",
Slug: "lighthouse",
Seed: 42,
Ext: "png",
PNG: []byte("PNGbytes"),
MimeType: "image/png",
Prompt: "a tiny lighthouse on a stormy cliff",
Backend: "flux-schnell-local",
Model: "flux1-schnell",
Steps: 4,
Width: 1024,
Height: 1024,
LatencyMs: 1500,
CostUSDEstimate: &cost,
Sidecar: map[string]any{
"timestamp": "2026-05-11T01:30:00Z",
"backend": "flux-schnell-local",
},
})
if err != nil {
t.Fatalf("Sync: %v", err)
}
if res.StoragePath != "2026-05-11/lighthouse-42.png" {
t.Errorf("storage_path = %q", res.StoragePath)
}
if res.ImageID != "00000000-0000-0000-0000-000000000abc" {
t.Errorf("image_id = %q", res.ImageID)
}
if got := atomic.LoadInt32(&f.uploadCalls); got != 1 {
t.Errorf("upload calls = %d, want 1", got)
}
if got := atomic.LoadInt32(&f.insertCalls); got != 1 {
t.Errorf("insert calls = %d, want 1", got)
}
if !bytes.Equal(f.uploadBytes, []byte("PNGbytes")) {
t.Errorf("uploaded bytes = %q", f.uploadBytes)
}
// Verify the row payload carries the prompt + computed hash + non-zero
// metadata. Empty fields should be omitted from the JSON body so RLS
// won't see surprise keys.
var row map[string]any
if err := json.Unmarshal(f.insertBody, &row); err != nil {
t.Fatalf("insert body parse: %v\n%s", err, f.insertBody)
}
if row["prompt"] != "a tiny lighthouse on a stormy cliff" {
t.Errorf("row.prompt = %v", row["prompt"])
}
if row["owner_user_id"] != "00000000-0000-0000-0000-000000000001" {
t.Errorf("row.owner_user_id = %v", row["owner_user_id"])
}
if row["storage_path"] != "2026-05-11/lighthouse-42.png" {
t.Errorf("row.storage_path = %v", row["storage_path"])
}
hash, _ := row["prompt_hash"].(string)
if len(hash) != 64 {
t.Errorf("prompt_hash should be 64-char sha256 hex, got %q", hash)
}
if row["backend"] != "flux-schnell-local" {
t.Errorf("row.backend = %v", row["backend"])
}
if row["seed"].(float64) != 42 {
t.Errorf("row.seed = %v", row["seed"])
}
if row["latency_ms"].(float64) != 1500 {
t.Errorf("row.latency_ms = %v", row["latency_ms"])
}
if row["cost_usd_estimate"].(float64) != 0.003 {
t.Errorf("row.cost = %v", row["cost_usd_estimate"])
}
if row["sidecar"] == nil {
t.Errorf("row.sidecar missing")
}
// PostgREST schema headers — hardcoded to "imagen".
if got := f.insertHdr.Get("Accept-Profile"); got != "imagen" {
t.Errorf("Accept-Profile = %q", got)
}
if got := f.insertHdr.Get("Content-Profile"); got != "imagen" {
t.Errorf("Content-Profile = %q", got)
}
if got := f.insertHdr.Get("Authorization"); !strings.HasPrefix(got, "Bearer ") {
t.Errorf("Authorization = %q", got)
}
// Storage upsert should be set so re-runs of the same date+slug+seed
// don't fail with 409.
if got := f.uploadHdr.Get("x-upsert"); got != "true" {
t.Errorf("x-upsert = %q", got)
}
}
func TestSyncRetryOn5xx(t *testing.T) {
var uploadAttempts int32
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&uploadAttempts, 1)
// Two 503s, then OK.
if n < 3 {
http.Error(w, "service unavailable", http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`[{"id":"row-id"}]`))
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err != nil {
t.Fatalf("Sync (with retry): %v", err)
}
if got := atomic.LoadInt32(&uploadAttempts); got != 3 {
t.Errorf("upload attempts = %d, want 3", got)
}
if res.ImageID != "row-id" {
t.Errorf("image_id = %q", res.ImageID)
}
}
func TestSyncNoRetryOn4xx(t *testing.T) {
var uploadAttempts int32
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&uploadAttempts, 1)
http.Error(w, `{"message":"bad request"}`, http.StatusBadRequest)
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
_, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error on 400")
}
if !strings.Contains(err.Error(), "400") {
t.Errorf("error should mention 400 status: %v", err)
}
if got := atomic.LoadInt32(&uploadAttempts); got != 1 {
t.Errorf("upload attempts = %d, want 1 (no retry on 4xx)", got)
}
}
func TestSyncMissingOwnerUserID(t *testing.T) {
srv := httptest.NewServer(http.NewServeMux())
defer srv.Close()
s := &Sink{
URL: srv.URL,
APIKey: "k",
// OwnerUserID intentionally empty.
HTTP: srv.Client(),
InitialBackoff: time.Millisecond,
}
_, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error when owner_user_id unset")
}
if !strings.Contains(err.Error(), "owner_user_id") {
t.Errorf("error should mention owner_user_id: %v", err)
}
}
func TestSyncRequiresDateAndSlug(t *testing.T) {
srv := httptest.NewServer(http.NewServeMux())
defer srv.Close()
s := newSink(srv)
_, err := s.Sync(context.Background(), SyncRequest{
Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error for missing date")
}
}
func TestSignedURL(t *testing.T) {
f := newFakeSupabase(t)
s := newSink(f.server)
got, err := s.SignedURL(context.Background(), "2026-05-11/x.png", 60)
if err != nil {
t.Fatalf("SignedURL: %v", err)
}
want := f.server.URL + "/storage/v1/object/sign/imagen-generated/some.png?token=abc"
if got != want {
t.Errorf("signed URL = %q, want %q", got, want)
}
}
func TestSyncDBFailureSurfacesPathOnError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "schema cache miss", http.StatusInternalServerError)
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 9, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error from DB insert failure")
}
// Storage upload succeeded — caller can still see the upload landed.
if res == nil || res.StoragePath != "2026-05-11/x-9.png" {
t.Errorf("expected storage_path on partial success, got %+v", res)
}
}
func TestPathEscape(t *testing.T) {
cases := map[string]string{
"2026-05-11/lighthouse-42.png": "2026-05-11/lighthouse-42.png",
"2026-05-11/two words.png": "2026-05-11/two%20words.png",
"with#hash/and?query.png": "with%23hash/and%3Fquery.png",
}
for in, want := range cases {
got := pathEscape(in)
if got != want {
t.Errorf("pathEscape(%q) = %q, want %q", in, got, want)
}
// Sanity: every escaped segment must be accepted by url.PathUnescape.
for _, seg := range strings.Split(got, "/") {
if _, err := url.PathUnescape(seg); err != nil {
t.Errorf("segment %q failed unescape: %v", seg, err)
}
}
}
}
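
One detail the TestPathEscape cases do not cover: pathEscape rejoins the segments with path.Join, which also cleans the result (empty segments and ".." are collapsed). The sketch below compares that with a plain strings.Join. The helper names `escapeEach`, `escapeSegments`, and `escapeAndClean` are hypothetical; which behaviour is wanted for Storage paths is left as an open question here.

```go
package main

import (
	"fmt"
	"net/url"
	"path"
	"strings"
)

// escapeEach percent-encodes every "/"-separated segment of p.
func escapeEach(p string) []string {
	parts := strings.Split(p, "/")
	for i, seg := range parts {
		parts[i] = url.PathEscape(seg)
	}
	return parts
}

// escapeSegments keeps the separators exactly as they were in the input.
func escapeSegments(p string) string {
	return strings.Join(escapeEach(p), "/")
}

// escapeAndClean mirrors the path.Join approach: the result is also
// passed through path.Clean, so "a//b" and "a/../b" are rewritten.
func escapeAndClean(p string) string {
	return path.Join(escapeEach(p)...)
}

func main() {
	for _, in := range []string{
		"2026-05-11/two words.png",
		"a//b.png",
		"a/../b.png",
	} {
		fmt.Printf("%q join=%q clean=%q\n", in, escapeSegments(in), escapeAndClean(in))
	}
}
```

For the date/slug/seed paths the writer produces, both variants give the same string; they only differ on inputs with empty or dot segments.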


@@ -15,6 +15,10 @@ import (
// Config is the top-level shape of imagen.yaml.
type Config struct {
DefaultBackend string `yaml:"default_backend"`
// OwnerUserID is m's auth.users.id on msupabase. The cloud-sync writer
// uses it to populate imagen.images.owner_user_id (NOT NULL, owns RLS).
// Empty disables DB inserts even when cloud_sync is on.
OwnerUserID string `yaml:"owner_user_id"`
Output OutputConfig `yaml:"output"`
Backends map[string]BackendSpec `yaml:"backends"`
}
@@ -29,6 +33,11 @@ type OutputConfig struct {
// Empty / unset is treated as "auto". $IMAGEN_PREVIEW and the
// --preview/--no-preview flags override this in turn.
Preview string `yaml:"preview"`
// CloudSync controls whether successful generations also upload to
// Supabase Storage and insert into imagen.images. Tri-state mirroring
// Preview: "auto" (default — on when SUPABASE_URL + SUPABASE_SERVICE_KEY
// are set), "on" (errors if env unset), "off". --no-cloud overrides.
CloudSync string `yaml:"cloud_sync"`
}
// BackendSpec is one entry under `backends:`. Type identifies the adapter;
@@ -88,6 +97,11 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("output.preview = %q (must be auto|on|off)", c.Output.Preview)
}
switch c.Output.CloudSync {
case "", "auto", "on", "off":
default:
return fmt.Errorf("output.cloud_sync = %q (must be auto|on|off)", c.Output.CloudSync)
}
for name, spec := range c.Backends {
if name == "" {
return errors.New("empty backend name")
@@ -107,6 +121,11 @@ const Sample = `# imagen.yaml — config for the imagen CLI.
default_backend: flux-schnell-local
# Owner UUID for the cloud-sync row in imagen.images. Look up via:
# SELECT id FROM auth.users WHERE email = '<your-supabase-email>';
# Empty disables imagen.images inserts even when cloud_sync is on.
owner_user_id: ""
output:
directory: ~/Pictures/imagen
naming: "{date}-{slug}-{seed}.png"
@@ -116,6 +135,13 @@ output:
# on: always preview (errors outside a tmux session).
# off: never preview (use this for batch / CI callers).
preview: auto
# Sync the PNG to Supabase Storage (bucket: imagen-generated) and insert
# a row into imagen.images. Reads SUPABASE_URL + SUPABASE_SERVICE_KEY
# from env (same as mai.imagen_usage cost-tracking).
# auto (default): on iff env is configured AND owner_user_id is set.
# on: always upload (errors if env or owner_user_id is missing).
# off: never upload. --no-cloud also forces off per-call.
cloud_sync: auto
backends:
flux-schnell-local:
@@ -131,11 +157,19 @@ backends:
mock:
type: mock
flux-schnell-replicate:
type: replicate
api_token_env: REPLICATE_API_TOKEN
model: black-forest-labs/flux-schnell
default_steps: 4
default_aspect_ratio: "1:1"
flux-dev-replicate:
type: replicate
api_token_env: REPLICATE_API_TOKEN
model: black-forest-labs/flux-dev
default_steps: 28
default_aspect_ratio: "1:1"
dalle3:
type: openai
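
The cloud_sync comments in the sample config above describe a tri-state switch. A minimal sketch of how such a switch could be resolved follows, assuming the rules as written in the sample (auto is on only when env and owner_user_id are both present, on errors when either is missing, --no-cloud forces off). `resolveCloudSync` and its parameters are hypothetical; the real decision lives in code not shown in this diff.

```go
package main

import (
	"errors"
	"fmt"
)

// resolveCloudSync decides whether to upload. envOK stands for "SUPABASE_URL
// and SUPABASE_SERVICE_KEY are set", ownerSet for "owner_user_id is
// non-empty", noCloud for the --no-cloud flag. All names are assumptions.
func resolveCloudSync(mode string, envOK, ownerSet, noCloud bool) (bool, error) {
	if noCloud {
		return false, nil // per-call override wins over the config value
	}
	switch mode {
	case "off":
		return false, nil
	case "", "auto":
		return envOK && ownerSet, nil
	case "on":
		if !envOK {
			return false, errors.New("cloud_sync=on but Supabase env is not set")
		}
		if !ownerSet {
			return false, errors.New("cloud_sync=on but owner_user_id is empty")
		}
		return true, nil
	default:
		return false, fmt.Errorf("output.cloud_sync = %q (must be auto|on|off)", mode)
	}
}

func main() {
	for _, mode := range []string{"", "auto", "on", "off"} {
		on, err := resolveCloudSync(mode, true, false, false)
		fmt.Printf("mode=%q env=yes owner=no -> sync=%v err=%v\n", mode, on, err)
	}
}
```

Keeping the empty string equal to "auto" matches what Validate accepts, so an imagen.yaml written before this option existed behaves like the default.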


@@ -73,6 +73,39 @@ func TestValidatePreviewMode(t *testing.T) {
}
}
func TestValidateCloudSyncMode(t *testing.T) {
for _, mode := range []string{"", "auto", "on", "off"} {
c := &Config{Output: OutputConfig{CloudSync: mode}}
if err := c.Validate(); err != nil {
t.Errorf("cloud_sync=%q: unexpected error %v", mode, err)
}
}
bad := &Config{Output: OutputConfig{CloudSync: "yes"}}
if err := bad.Validate(); err == nil {
t.Errorf("expected error for invalid cloud_sync value")
}
}
func TestSampleParsesCloudSyncAuto(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "imagen.yaml")
if err := os.WriteFile(path, []byte(Sample), 0o644); err != nil {
t.Fatalf("write sample: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.Output.CloudSync != "auto" {
t.Errorf("Output.CloudSync = %q, want auto", cfg.Output.CloudSync)
}
// owner_user_id is intentionally empty in the sample — operators fill
// it in after looking up their auth.users.id.
if cfg.OwnerUserID != "" {
t.Errorf("Sample OwnerUserID should be empty, got %q", cfg.OwnerUserID)
}
}
func TestSampleParsesPreviewAuto(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "imagen.yaml")


@@ -35,6 +35,13 @@ type Inputs struct {
type Outputs struct {
ImagePath string
SidecarPath string
// Date is the YYYY-MM-DD the writer used for the filename. Cloud sync
// reuses this so storage_path matches the local filename's date.
Date string
// Slug is the filename-safe prompt fragment the writer used.
Slug string
// Seed is the seed value baked into the filename.
Seed int64
}
// Write streams img to disk and, if enabled, writes a sidecar. The image
@@ -50,10 +57,12 @@ func (w *Writer) Write(img io.Reader, in Inputs) (*Outputs, error) {
if tmpl == "" {
tmpl = "{date}-{slug}-{seed}.{ext}"
}
date := now.Format("2006-01-02")
slug := Slug(in.Prompt)
name := renderTemplate(tmpl, map[string]string{
"date": date,
"time": now.Format("150405"),
"slug": slug,
"seed": fmt.Sprintf("%d", in.Seed),
"backend": in.Backend,
"ext": strings.TrimPrefix(ext, "."),
@@ -80,7 +89,7 @@ func (w *Writer) Write(img io.Reader, in Inputs) (*Outputs, error) {
return nil, fmt.Errorf("close %s: %w", imagePath, err)
}
out := &Outputs{ImagePath: imagePath, Date: date, Slug: slug, Seed: in.Seed}
if w.WriteSidecar {
sidecar := imagePath + ".json"
@@ -122,7 +131,7 @@ func (w *Writer) WriteToPath(img io.Reader, path string, in Inputs) (*Outputs, e
if err := f.Close(); err != nil {
return nil, fmt.Errorf("close %s: %w", path, err)
}
out := &Outputs{ImagePath: path, Date: now.Format("2006-01-02"), Slug: Slug(in.Prompt), Seed: in.Seed}
if w.WriteSidecar {
sidecar := path + ".json"
body := map[string]any{

internal/usage/usage.go Normal file

@@ -0,0 +1,160 @@
// Package usage records per-call cost-tracking rows for the imagen CLI
// to mai.imagen_usage on Supabase. The writer is best-effort by design —
// the calling adapter logs failures and proceeds, because the image
// itself has already landed on disk by the time we record.
package usage
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
"mgit.msbls.de/m/ImaGen/internal/backend"
)
// supabaseSchema is the PostgREST schema profile to target: the mai schema,
// where mai.imagen_usage lives.
const supabaseSchema = "mai"
// SupabaseSink writes rows via PostgREST. It uses Accept-Profile/
// Content-Profile headers to target the mai schema instead of public.
type SupabaseSink struct {
URL string // SUPABASE_URL — e.g. https://msup.msbls.de
APIKey string // SUPABASE_SERVICE_KEY
HTTP *http.Client
}
// NewSupabaseSinkFromEnv reads SUPABASE_URL and SUPABASE_SERVICE_KEY
// (falling back to MAI_SUPABASE_KEY) and returns a sink ready to use.
// Returns nil + ok=false if the env vars are not configured — the CLI
// uses that to skip cost-tracking gracefully.
func NewSupabaseSinkFromEnv() (*SupabaseSink, bool) {
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
if u == "" {
return nil, false
}
key := os.Getenv("SUPABASE_SERVICE_KEY")
if key == "" {
key = os.Getenv("MAI_SUPABASE_KEY")
}
if key == "" {
return nil, false
}
return &SupabaseSink{
URL: u,
APIKey: key,
HTTP: &http.Client{Timeout: 10 * time.Second},
}, true
}
type supabaseRow struct {
Backend string `json:"backend"`
Model string `json:"model"`
Seed *int64 `json:"seed,omitempty"`
PromptHash string `json:"prompt_hash"`
LatencyMs int `json:"latency_ms"`
CostUSDEstimate *float64 `json:"cost_usd_estimate,omitempty"`
Caller string `json:"caller,omitempty"`
}
// Record inserts one row into mai.imagen_usage.
func (s *SupabaseSink) Record(ctx context.Context, row backend.UsageRow) error {
body, err := json.Marshal(supabaseRow{
Backend: row.Backend,
Model: row.Model,
Seed: row.Seed,
PromptHash: row.PromptHash,
LatencyMs: row.LatencyMs,
CostUSDEstimate: row.CostUSDEstimate,
Caller: row.Caller,
})
if err != nil {
return fmt.Errorf("usage: marshal: %w", err)
}
endpoint := s.URL + "/rest/v1/imagen_usage"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Profile", supabaseSchema)
req.Header.Set("Content-Profile", supabaseSchema)
req.Header.Set("Prefer", "return=minimal")
resp, err := s.HTTP.Do(req)
if err != nil {
return fmt.Errorf("usage: POST: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("usage: POST %d: %s", resp.StatusCode, snip(respBody))
}
return nil
}
// Row is the read-side row shape (only the fields the CLI needs).
type Row struct {
CreatedAt time.Time `json:"created_at"`
Backend string `json:"backend"`
Model string `json:"model"`
Seed *int64 `json:"seed"`
PromptHash string `json:"prompt_hash"`
LatencyMs *int `json:"latency_ms"`
CostUSDEstimate *float64 `json:"cost_usd_estimate"`
Caller *string `json:"caller"`
}
// Query returns rows from mai.imagen_usage filtered by created_at >= since.
// Pass zero time to fetch the full table (capped server-side by PostgREST
// — we set a hard 5000-row limit here too).
func (s *SupabaseSink) Query(ctx context.Context, since time.Time) ([]Row, error) {
q := url.Values{}
q.Set("select", "created_at,backend,model,seed,prompt_hash,latency_ms,cost_usd_estimate,caller")
q.Set("order", "created_at.desc")
q.Set("limit", "5000")
if !since.IsZero() {
q.Set("created_at", "gte."+since.UTC().Format(time.RFC3339))
}
endpoint := s.URL + "/rest/v1/imagen_usage?" + q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Accept-Profile", supabaseSchema)
resp, err := s.HTTP.Do(req)
if err != nil {
return nil, fmt.Errorf("usage: GET: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("usage: GET %d: %s", resp.StatusCode, snip(body))
}
var rows []Row
if err := json.Unmarshal(body, &rows); err != nil {
return nil, fmt.Errorf("usage: parse rows: %w (body: %s)", err, snip(body))
}
return rows, nil
}
func snip(b []byte) string {
const max = 500
s := strings.TrimSpace(string(b))
if len(s) > max {
s = s[:max] + "..."
}
return s
}

internal/worker/worker.go Normal file

@@ -0,0 +1,213 @@
// Package worker consumes the imagen.jobs queue. It claims pending rows via
// an UPDATE-returning lock (single source of truth, no double-claim window),
// runs the supplied generation pipeline, then writes status + image_id back.
//
// The package is DB-agnostic: it talks to two small interfaces (Queue +
// Pipeline) so unit tests can drive the claim/transition logic with no real
// Postgres connection. cmd/imagen wires the pgx implementation.
package worker
import (
"context"
"errors"
"fmt"
"sync"
"time"
)
// Job is the slice of an imagen.jobs row the worker needs to drive a
// generation. Null columns from the DB are represented as zero values; the
// pipeline treats zero values as "use backend default" (same convention as
// backend.Request).
type Job struct {
ID string
OwnerUserID string
Prompt string
Backend string
Model string
Width int
Height int
Steps int
Seed int64
Style string
}
// Outcome is what the pipeline reports back per job. ImageID is the
// imagen.images.id the cloud-sync produced. Empty ImageID with nil Err means
// the cloud-sync was skipped (config off) — we treat that as a failure for
// the worker since flexsiebels needs the image_id to render the result.
type Outcome struct {
ImageID string
Err error
}
// Queue is the persistence layer for the imagen.jobs table. Implementations
// must be safe for serialised single-worker use (concurrent claim across
// multiple worker processes is out of scope for v1 — the FOR UPDATE SKIP
// LOCKED clause in the pgx claim query covers it cheaply anyway).
type Queue interface {
// ClaimNextPending atomically marks the oldest pending row 'running' and
// returns it. Returns (nil, nil) when the queue is empty.
ClaimNextPending(ctx context.Context) (*Job, error)
// MarkDone records success: status='done', image_id, completed_at=now().
MarkDone(ctx context.Context, jobID, imageID string) error
// MarkFailed records failure: status='failed', error=msg, completed_at=now().
MarkFailed(ctx context.Context, jobID, errMsg string) error
// WaitForJob blocks until either a NOTIFY arrives on imagen_jobs, the
// timeout expires, or ctx is cancelled. Returns nil on notification or
// timeout; returns ctx.Err() on cancellation. Transient connection errors
// are returned so the caller can decide to reconnect.
WaitForJob(ctx context.Context, timeout time.Duration) error
// ResetStaleRunning marks any rows stuck in 'running' (e.g. left over
// from a crash before this process started) back to 'pending'. Called
// once at worker startup so the cold-start safety poll can pick them up.
ResetStaleRunning(ctx context.Context) error
}
// Pipeline runs one generation and reports back the imagen.images.id (or an
// error). The implementation owns backend dispatch, prompt enrichment, disk
// write, and cloud-sync; the worker only orchestrates queue state.
type Pipeline interface {
Run(ctx context.Context, job Job) Outcome
}
// Config is the runtime knob set for the worker loop.
type Config struct {
// PollInterval is the safety-poll cadence between LISTEN wakeups. Picking
// this too low wastes DB roundtrips; too high lets a dropped NOTIFY
// stall the queue. 5s is the spec'd default.
PollInterval time.Duration
// JobTimeout caps any single Pipeline.Run. A backend hang shouldn't
// freeze the queue forever.
JobTimeout time.Duration
// Logger receives one-line status events. nil means silent.
Logger func(format string, args ...any)
}
// Worker is the orchestration loop. It is not reusable across Run calls.
type Worker struct {
q Queue
p Pipeline
cfg Config
// processingMu guards the in-flight job so SIGTERM-triggered shutdown
// waits for it to complete before returning.
processingMu sync.Mutex
}
// New constructs a Worker.
func New(q Queue, p Pipeline, cfg Config) *Worker {
if cfg.PollInterval <= 0 {
cfg.PollInterval = 5 * time.Second
}
if cfg.JobTimeout <= 0 {
cfg.JobTimeout = 5 * time.Minute
}
return &Worker{q: q, p: p, cfg: cfg}
}
// Run drives the consume loop until ctx is cancelled, then returns nil.
// Queue errors are never returned to the caller: a failed claim or a failed
// LISTEN wait (for example a transient transport error) is logged and the
// loop continues, so a temporary network blip doesn't take the worker down.
func (w *Worker) Run(ctx context.Context) error {
if err := w.q.ResetStaleRunning(ctx); err != nil {
w.log("worker: reset stale running rows: %v", err)
// Don't return — a stale row will eventually be visible to the poll
// path once flexsiebels gives up and resubmits, and we'd rather keep
// serving fresh jobs than crash here.
}
for {
if err := ctx.Err(); err != nil {
return nil
}
// Drain the queue: claim and process until empty.
if err := w.drain(ctx); err != nil && !errors.Is(err, context.Canceled) {
w.log("worker: drain: %v", err)
}
if err := ctx.Err(); err != nil {
return nil
}
// Wait for the next wake. WaitForJob covers both LISTEN and the
// timeout-based poll fallback; either returns nil and we loop.
if err := w.q.WaitForJob(ctx, w.cfg.PollInterval); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
w.log("worker: wait: %v (continuing)", err)
// Pace the retries so a totally-broken DB doesn't busy-spin.
select {
case <-ctx.Done():
return nil
case <-time.After(w.cfg.PollInterval):
}
}
}
}
// drain claims and processes every currently-pending job. The job-scoped
// context is derived from context.Background() so that a SIGTERM mid-job
// still lets the pipeline finish — that's the "no half-state on shutdown"
// guarantee the issue calls for.
func (w *Worker) drain(ctx context.Context) error {
for {
if err := ctx.Err(); err != nil {
return err
}
job, err := w.q.ClaimNextPending(ctx)
if err != nil {
return fmt.Errorf("claim: %w", err)
}
if job == nil {
return nil
}
w.processOne(*job)
}
}
// processOne runs the pipeline for one already-claimed job and writes the
// outcome back to the queue. The job context is independent of the outer
// ctx so an in-flight job can finish even after SIGTERM.
func (w *Worker) processOne(job Job) {
w.processingMu.Lock()
defer w.processingMu.Unlock()
w.log("worker: processing job %s backend=%s", job.ID, job.Backend)
jobCtx, cancel := context.WithTimeout(context.Background(), w.cfg.JobTimeout)
defer cancel()
out := w.p.Run(jobCtx, job)
// Status-update uses Background ctx with a short timeout — we must
// always be able to record the outcome, otherwise the row sits in
// 'running' forever.
updCtx, updCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer updCancel()
if out.Err != nil {
w.log("worker: job %s failed: %v", job.ID, out.Err)
if err := w.q.MarkFailed(updCtx, job.ID, out.Err.Error()); err != nil {
w.log("worker: mark failed for %s: %v", job.ID, err)
}
return
}
if out.ImageID == "" {
// Pipeline reported success but no imagen.images row — treat as
// failure because flexsiebels has nothing to link.
const msg = "pipeline did not return an imagen.images id (cloud sync misconfigured?)"
w.log("worker: job %s: %s", job.ID, msg)
if err := w.q.MarkFailed(updCtx, job.ID, msg); err != nil {
w.log("worker: mark failed for %s: %v", job.ID, err)
}
return
}
if err := w.q.MarkDone(updCtx, job.ID, out.ImageID); err != nil {
w.log("worker: mark done for %s: %v", job.ID, err)
return
}
w.log("worker: job %s done image_id=%s", job.ID, out.ImageID)
}
func (w *Worker) log(format string, args ...any) {
if w.cfg.Logger != nil {
w.cfg.Logger(format, args...)
}
}


@@ -0,0 +1,332 @@
package worker
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
)
// fakeQueue is a hand-rolled in-memory queue that mirrors the contract of a
// real Postgres-backed implementation: ClaimNextPending atomically takes one
// pending row and flips its status to "running", MarkDone/MarkFailed are
// idempotent terminal transitions, WaitForJob blocks until notified or until
// the timeout elapses.
type fakeQueue struct {
mu sync.Mutex
pending []Job
state map[string]string // jobID -> status
last map[string]string // jobID -> error msg or image_id
notify chan struct{}
claimErr error
doneErr error
failErr error
resetErr error
claimed int
done int
failed int
resets int
}
func newFakeQueue(jobs ...Job) *fakeQueue {
q := &fakeQueue{
state: make(map[string]string),
last: make(map[string]string),
notify: make(chan struct{}, 16),
}
for _, j := range jobs {
q.pending = append(q.pending, j)
q.state[j.ID] = "pending"
}
return q
}
func (q *fakeQueue) ClaimNextPending(ctx context.Context) (*Job, error) {
q.mu.Lock()
defer q.mu.Unlock()
if q.claimErr != nil {
return nil, q.claimErr
}
if len(q.pending) == 0 {
return nil, nil
}
j := q.pending[0]
q.pending = q.pending[1:]
q.state[j.ID] = "running"
q.claimed++
return &j, nil
}
func (q *fakeQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.doneErr != nil {
return q.doneErr
}
q.state[jobID] = "done"
q.last[jobID] = imageID
q.done++
return nil
}
func (q *fakeQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.failErr != nil {
return q.failErr
}
q.state[jobID] = "failed"
q.last[jobID] = msg
q.failed++
return nil
}
func (q *fakeQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-q.notify:
return nil
case <-time.After(timeout):
return nil
}
}
func (q *fakeQueue) ResetStaleRunning(ctx context.Context) error {
q.mu.Lock()
defer q.mu.Unlock()
q.resets++
return q.resetErr
}
// pingNotify simulates an INSERT-trigger NOTIFY by waking WaitForJob.
func (q *fakeQueue) pingNotify() {
select {
case q.notify <- struct{}{}:
default:
}
}
// stub pipeline.
type fakePipeline struct {
mu sync.Mutex
results map[string]Outcome // by job.ID; "" key = default outcome
calls int
delay time.Duration
lastJob Job
}
func (p *fakePipeline) Run(ctx context.Context, job Job) Outcome {
p.mu.Lock()
p.calls++
p.lastJob = job
delay := p.delay
out, ok := p.results[job.ID]
if !ok {
out = p.results[""]
}
p.mu.Unlock()
if delay > 0 {
select {
case <-ctx.Done():
return Outcome{Err: ctx.Err()}
case <-time.After(delay):
}
}
return out
}
func TestWorker_DonePath(t *testing.T) {
q := newFakeQueue(
Job{ID: "j1", Prompt: "a", Backend: "mock"},
)
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img-1"}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(80 * time.Millisecond)
cancel()
}()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run: %v", err)
}
if got := q.state["j1"]; got != "done" {
t.Fatalf("state=%q want done", got)
}
if got := q.last["j1"]; got != "img-1" {
t.Fatalf("image_id=%q want img-1", got)
}
if q.done != 1 || q.failed != 0 {
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
}
if p.calls != 1 {
t.Fatalf("pipeline calls=%d want 1", p.calls)
}
if q.resets != 1 {
t.Fatalf("ResetStaleRunning calls=%d want 1", q.resets)
}
}
func TestWorker_FailedPath_RecordsErrorText(t *testing.T) {
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
p := &fakePipeline{results: map[string]Outcome{"j1": {Err: errors.New("backend unreachable")}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
if got := q.state["j1"]; got != "failed" {
t.Fatalf("state=%q want failed", got)
}
if got := q.last["j1"]; got != "backend unreachable" {
t.Fatalf("error=%q want %q", got, "backend unreachable")
}
if q.done != 0 || q.failed != 1 {
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
}
}
func TestWorker_MissingImageID_TreatedAsFailure(t *testing.T) {
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
// Outcome has neither Err nor ImageID, i.e. the pipeline silently swallowed
// a cloud-sync failure. flexsiebels needs the image_id; without it, the
// worker must mark the job as failed.
p := &fakePipeline{results: map[string]Outcome{"j1": {}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
if got := q.state["j1"]; got != "failed" {
t.Fatalf("state=%q want failed", got)
}
if q.last["j1"] == "" {
t.Fatalf("expected non-empty error explanation for missing image_id")
}
}
func TestWorker_DrainsMultipleBeforeWaiting(t *testing.T) {
q := newFakeQueue(
Job{ID: "j1", Backend: "mock"},
Job{ID: "j2", Backend: "mock"},
Job{ID: "j3", Backend: "mock"},
)
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 200 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(60 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
for _, id := range []string{"j1", "j2", "j3"} {
if got := q.state[id]; got != "done" {
t.Fatalf("%s state=%q want done", id, got)
}
}
if q.done != 3 {
t.Fatalf("done=%d want 3", q.done)
}
}
func TestWorker_NotifyWakesEarlierThanPoll(t *testing.T) {
q := newFakeQueue()
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
// Set the poll interval high so that only the NOTIFY wake-up can surface
// the job promptly. If the wake-up path were broken, the worker would sit
// in its 5s poll wait and the 500ms deadline below would expire first.
w := New(q, p, Config{PollInterval: 5 * time.Second, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
go func() {
_ = w.Run(ctx)
close(done)
}()
// Append a job and ping the wake channel.
q.mu.Lock()
q.pending = append(q.pending, Job{ID: "late", Backend: "mock"})
q.state["late"] = "pending"
q.mu.Unlock()
q.pingNotify()
// Give the worker a beat to claim + process.
deadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(deadline) {
q.mu.Lock()
s := q.state["late"]
q.mu.Unlock()
if s == "done" {
cancel()
<-done
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("worker did not pick up the late job within the 500ms window — NOTIFY wake-up path is broken")
}
func TestWorker_HonoursContextCancellation(t *testing.T) {
q := newFakeQueue()
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
start := time.Now()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run: %v", err)
}
if dur := time.Since(start); dur > 200*time.Millisecond {
t.Fatalf("worker did not exit promptly on ctx cancel: %v", dur)
}
}
func TestWorker_InflightJobFinishesAfterShutdown(t *testing.T) {
q := newFakeQueue(Job{ID: "long", Backend: "mock"})
p := &fakePipeline{
results: map[string]Outcome{"long": {ImageID: "img-long"}},
delay: 120 * time.Millisecond,
}
// Short JobTimeout would also kill the in-flight job; give it enough
// budget so the test exercises the shutdown-during-job path.
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: 5 * time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() {
// Let the job start, then cancel mid-flight.
time.Sleep(30 * time.Millisecond)
cancel()
}()
_ = w.Run(ctx)
if got := q.state["long"]; got != "done" {
t.Fatalf("state=%q want done (in-flight job should finish even on shutdown)", got)
}
}
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
// First claim returns an error; the loop should log and try again on the
// next wake — it must not propagate the error and exit.
q := newFakeQueue(Job{ID: "j1", Backend: "mock"})
q.claimErr = fmt.Errorf("transient: connection reset")
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 20 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
// Heal the claim error after a beat so the second drain succeeds.
go func() {
time.Sleep(40 * time.Millisecond)
q.mu.Lock()
q.claimErr = nil
q.mu.Unlock()
}()
go func() {
time.Sleep(200 * time.Millisecond)
cancel()
}()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run returned: %v (transient claim errors should not kill the loop)", err)
}
if got := q.state["j1"]; got != "done" {
t.Fatalf("state=%q want done", got)
}
}


@@ -0,0 +1,22 @@
# Environment for the imagen-worker.service systemd unit.
# Copy to ~/.dotfiles/.env.imagen-worker and fill in real values.
# Never commit the populated file — it carries the Supabase service-role key.
# Direct Postgres DSN for LISTEN/NOTIFY + imagen.jobs UPDATE statements.
# PostgREST cannot LISTEN, so the worker connects to Postgres directly.
# Host + port + password come from the msupabase compose env on mlake.
IMAGEN_WORKER_DATABASE_URL=postgres://postgres:CHANGE_ME@100.99.98.201:6789/postgres?sslmode=disable
# PostgREST endpoint for the imagen.images cloud-sync writer (same as
# `imagen generate`'s cloud-sync code path).
SUPABASE_URL=https://supa.flexsiebels.de
SUPABASE_SERVICE_KEY=CHANGE_ME
# Default owner_user_id. The per-job owner from the imagen.jobs row overrides
# this, so it only applies if a job arrives with a NULL owner_user_id, which
# the schema disallows. Keep it set as a safety net.
IMAGEN_OWNER_USER_ID=ac6c9501-3757-4a6d-8b97-2cff4288382b
# Optional: REPLICATE_API_TOKEN if any imagen.jobs.backend may resolve to
# a Replicate adapter instance.
# REPLICATE_API_TOKEN=CHANGE_ME
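
Because the example file ships with CHANGE_ME placeholders, a startup check that rejects them can catch a half-filled copy early. The sketch below is an illustration only: checkWorkerEnv and requiredVars are hypothetical names, and treating all four variables as mandatory is an assumption, not something this commit confirms.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// requiredVars lists the variable names from the example file that this
// sketch treats as mandatory (an assumption, not confirmed by the source).
var requiredVars = []string{
	"IMAGEN_WORKER_DATABASE_URL",
	"SUPABASE_URL",
	"SUPABASE_SERVICE_KEY",
	"IMAGEN_OWNER_USER_ID",
}

// checkWorkerEnv returns an error naming every required variable that is
// empty or still contains the CHANGE_ME placeholder.
func checkWorkerEnv(getenv func(string) string) error {
	var bad []string
	for _, name := range requiredVars {
		v := getenv(name)
		if v == "" || strings.Contains(v, "CHANGE_ME") {
			bad = append(bad, name)
		}
	}
	if len(bad) > 0 {
		return fmt.Errorf("unset or placeholder variables: %s", strings.Join(bad, ", "))
	}
	return nil
}

func main() {
	if err := checkWorkerEnv(os.Getenv); err != nil {
		fmt.Fprintln(os.Stderr, "imagen worker:", err)
		os.Exit(1)
	}
	fmt.Println("worker environment looks complete")
}
```

Passing the lookup function as a parameter keeps the check testable without touching the process environment.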


@@ -0,0 +1,19 @@
[Unit]
Description=ImaGen worker (consumes imagen.jobs queue)
Documentation=https://mgit.msbls.de/m/ImaGen/issues/8
Wants=network-online.target
After=network-online.target
[Service]
Type=simple
ExecStart=%h/dev/ImaGen/bin/imagen worker
WorkingDirectory=%h/dev/ImaGen
EnvironmentFile=%h/.dotfiles/.env.imagen-worker
Restart=on-failure
RestartSec=5
# Give the worker time to finish an in-flight generation on shutdown
# (FLUX dev up to ~30s, plus the cloud-sync write-back).
TimeoutStopSec=60
[Install]
WantedBy=default.target