Merge mai/hermes/issue-9-imagen-9-imagen: imagen.series + series_id propagation (#9)

This commit is contained in:
mAi
2026-05-11 10:50:54 +02:00
6 changed files with 118 additions and 4 deletions

View File

@@ -132,7 +132,7 @@ func runGenerate(ctx context.Context, args []string) error {
fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath) fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath)
} }
if result, err := maybeCloudSync(ctx, cfg, noCloud, "", paths, in, res, w, h); err != nil { if result, err := maybeCloudSync(ctx, cfg, noCloud, "", "", paths, in, res, w, h); err != nil {
// cloud-sync failures are warnings — the image already wrote. // cloud-sync failures are warnings — the image already wrote.
fmt.Fprintln(os.Stderr, "imagen: cloud sync:", err) fmt.Fprintln(os.Stderr, "imagen: cloud sync:", err)
} else if result != nil && result.ImageID != "" { } else if result != nil && result.ImageID != "" {
@@ -173,7 +173,10 @@ func resolveCloudSyncMode(cfg *config.Config, noCloudFlag bool, env string) (str
// that need the imagen.images.id (e.g. the worker linking a job row) can pick // that need the imagen.images.id (e.g. the worker linking a job row) can pick
// it up. ownerOverride, when non-empty, wins over config + env — the worker // it up. ownerOverride, when non-empty, wins over config + env — the worker
// passes the job row's owner_user_id so each job is attributed correctly. // passes the job row's owner_user_id so each job is attributed correctly.
func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, ownerOverride string, paths *output.Outputs, in output.Inputs, res *backend.Result, width, height int) (*cloud.SyncResult, error) { // seriesID, when non-empty, lands on imagen.images.series_id so the
// list-page query (`WHERE series_id IS NULL`) hides series members from
// the flat grid; empty means solo run.
func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, ownerOverride, seriesID string, paths *output.Outputs, in output.Inputs, res *backend.Result, width, height int) (*cloud.SyncResult, error) {
mode, err := resolveCloudSyncMode(cfg, noCloud, os.Getenv("IMAGEN_CLOUD_SYNC")) mode, err := resolveCloudSyncMode(cfg, noCloud, os.Getenv("IMAGEN_CLOUD_SYNC"))
if err != nil { if err != nil {
return nil, err return nil, err
@@ -262,6 +265,7 @@ func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, owner
LatencyMs: latency, LatencyMs: latency,
CostUSDEstimate: cost, CostUSDEstimate: cost,
Sidecar: sidecar, Sidecar: sidecar,
SeriesID: seriesID,
} }
syncCtx, cancel := context.WithTimeout(ctx, 60*time.Second) syncCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel() defer cancel()

View File

@@ -112,6 +112,9 @@ func (q *pgxQueue) Close() {
// returns it. FOR UPDATE SKIP LOCKED is belt + braces against a second worker // returns it. FOR UPDATE SKIP LOCKED is belt + braces against a second worker
// process — out of scope for v1 but cheap insurance. // process — out of scope for v1 but cheap insurance.
func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) { func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
// series_id is nullable on imagen.jobs (solo run when NULL); cast to text
// with COALESCE so pgx scans into a plain Go string. Empty string =
// solo run; the pipeline skips series propagation in that case.
const stmt = ` const stmt = `
UPDATE imagen.jobs UPDATE imagen.jobs
SET status='running', started_at=now() SET status='running', started_at=now()
@@ -126,11 +129,13 @@ func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
COALESCE(model,''), COALESCE(model,''),
COALESCE(width, 0), COALESCE(height, 0), COALESCE(width, 0), COALESCE(height, 0),
COALESCE(steps, 0), COALESCE(seed, 0), COALESCE(steps, 0), COALESCE(seed, 0),
COALESCE(style,'')` COALESCE(style,''),
COALESCE(series_id::text, '')`
var j worker.Job var j worker.Job
err := q.conn.QueryRow(ctx, stmt).Scan( err := q.conn.QueryRow(ctx, stmt).Scan(
&j.ID, &j.OwnerUserID, &j.Prompt, &j.Backend, &j.ID, &j.OwnerUserID, &j.Prompt, &j.Backend,
&j.Model, &j.Width, &j.Height, &j.Steps, &j.Seed, &j.Style, &j.Model, &j.Width, &j.Height, &j.Steps, &j.Seed, &j.Style,
&j.SeriesID,
) )
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return nil, nil return nil, nil
@@ -258,7 +263,7 @@ func (p *workerPipeline) Run(ctx context.Context, job worker.Job) worker.Outcome
// config, the worker can't serve flexsiebels at all. // config, the worker can't serve flexsiebels at all.
return worker.Outcome{Err: fmt.Errorf("output.cloud_sync=off in config; the worker requires cloud_sync=on or auto")} return worker.Outcome{Err: fmt.Errorf("output.cloud_sync=off in config; the worker requires cloud_sync=on or auto")}
} }
syncRes, syncErr := maybeCloudSync(ctx, p.cfg, false, job.OwnerUserID, paths, in, res, dimOrFallback(job.Width, res, "width"), dimOrFallback(job.Height, res, "height")) syncRes, syncErr := maybeCloudSync(ctx, p.cfg, false, job.OwnerUserID, job.SeriesID, paths, in, res, dimOrFallback(job.Width, res, "width"), dimOrFallback(job.Height, res, "height"))
if syncErr != nil { if syncErr != nil {
return worker.Outcome{Err: fmt.Errorf("cloud sync: %w", syncErr)} return worker.Outcome{Err: fmt.Errorf("cloud sync: %w", syncErr)}
} }

View File

@@ -100,6 +100,12 @@ type SyncRequest struct {
LatencyMs int LatencyMs int
CostUSDEstimate *float64 CostUSDEstimate *float64
Sidecar map[string]any Sidecar map[string]any
// SeriesID is the parent imagen.series row when this image is one of
// N tries in a batch. Empty means a solo run — the column stays NULL,
// which keeps the row visible on the main list-page query
// (`WHERE series_id IS NULL`).
SeriesID string
} }
// SyncResult tells the caller what landed where. // SyncResult tells the caller what landed where.
@@ -195,6 +201,9 @@ func (s *Sink) insertRow(ctx context.Context, storagePath string, req SyncReques
if len(req.Sidecar) > 0 { if len(req.Sidecar) > 0 {
row["sidecar"] = req.Sidecar row["sidecar"] = req.Sidecar
} }
if req.SeriesID != "" {
row["series_id"] = req.SeriesID
}
body, err := json.Marshal(row) body, err := json.Marshal(row)
if err != nil { if err != nil {

View File

@@ -305,6 +305,54 @@ func TestSyncDBFailureSurfacesPathOnError(t *testing.T) {
} }
} }
// TestSyncWritesSeriesID covers the cloud-sync half of the ImaGen#9
// propagation contract: a non-empty SyncRequest.SeriesID must appear as
// `series_id` in the JSON body POSTed for the imagen.images insert. The
// solo-run counterpart (empty SeriesID, key omitted so the column stays
// NULL and the `WHERE series_id IS NULL` list-page query keeps showing the
// row) is exercised by TestSyncOmitsSeriesIDWhenEmpty.
func TestSyncWritesSeriesID(t *testing.T) {
	const wantSeries = "22222222-2222-2222-2222-222222222222"

	fake := newFakeSupabase(t)
	sink := newSink(fake.server)

	req := SyncRequest{
		Date:     "2026-05-11",
		Slug:     "x",
		Seed:     1,
		Ext:      "png",
		PNG:      []byte("p"),
		Prompt:   "p",
		Backend:  "b",
		SeriesID: wantSeries,
	}
	if _, err := sink.Sync(context.Background(), req); err != nil {
		t.Fatalf("Sync: %v", err)
	}

	// Decode what the fake server captured and inspect the single key we
	// care about.
	var posted map[string]any
	if err := json.Unmarshal(fake.insertBody, &posted); err != nil {
		t.Fatalf("parse insert body: %v\n%s", err, fake.insertBody)
	}
	if got := posted["series_id"]; got != wantSeries {
		t.Fatalf("row.series_id = %v want %q", got, wantSeries)
	}
}
// TestSyncOmitsSeriesIDWhenEmpty checks the solo-run path: when
// SyncRequest.SeriesID is left at its zero value, the insert body must not
// contain a `series_id` key at all.
func TestSyncOmitsSeriesIDWhenEmpty(t *testing.T) {
	fake := newFakeSupabase(t)
	sink := newSink(fake.server)

	// SeriesID is deliberately not set on this request.
	req := SyncRequest{
		Date:    "2026-05-11",
		Slug:    "x",
		Seed:    1,
		Ext:     "png",
		PNG:     []byte("p"),
		Prompt:  "p",
		Backend: "b",
	}
	if _, err := sink.Sync(context.Background(), req); err != nil {
		t.Fatalf("Sync: %v", err)
	}

	var posted map[string]any
	if err := json.Unmarshal(fake.insertBody, &posted); err != nil {
		t.Fatalf("parse insert body: %v\n%s", err, fake.insertBody)
	}
	if val, found := posted["series_id"]; found {
		t.Fatalf("solo run should omit series_id from POST body, got %v", val)
	}
}
func TestPathEscape(t *testing.T) { func TestPathEscape(t *testing.T) {
cases := map[string]string{ cases := map[string]string{
"2026-05-11/lighthouse-42.png": "2026-05-11/lighthouse-42.png", "2026-05-11/lighthouse-42.png": "2026-05-11/lighthouse-42.png",

View File

@@ -30,6 +30,10 @@ type Job struct {
Steps int Steps int
Seed int64 Seed int64
Style string Style string
// SeriesID is the parent imagen.series row when this job is one of N
// tries in a batch. Empty means a solo run — the pipeline must not
// propagate a series_id onto the resulting imagen.images row.
SeriesID string
} }
// Outcome is what the pipeline reports back per job. ImageID is the // Outcome is what the pipeline reports back per job. ImageID is the

View File

@@ -303,6 +303,50 @@ func TestWorker_InflightJobFinishesAfterShutdown(t *testing.T) {
} }
} }
// TestWorker_PropagatesSeriesIDToPipeline pins the worker's side of the
// series propagation: whatever SeriesID the claimed Job carries must reach
// the pipeline untouched. Writing that value onto
// imagen.images.series_id (via cloud.SyncRequest.SeriesID) is the
// pipeline's job and is tested in cloud_test.go; here the only requirement
// is that the worker neither drops nor rewrites the field between claim
// and Run.
func TestWorker_PropagatesSeriesIDToPipeline(t *testing.T) {
	const (
		jobID      = "j-series"
		wantSeries = "11111111-1111-1111-1111-111111111111"
	)

	queue := newFakeQueue(Job{
		ID:       jobID,
		Prompt:   "p",
		Backend:  "mock",
		SeriesID: wantSeries,
	})
	pipe := &fakePipeline{
		results: map[string]Outcome{jobID: {ImageID: "img-series"}},
	}
	wk := New(queue, pipe, Config{
		PollInterval: 10 * time.Millisecond,
		JobTimeout:   time.Second,
	})

	// Let the loop run for a short while, then cancel it so Run returns.
	ctx, stop := context.WithCancel(context.Background())
	time.AfterFunc(80*time.Millisecond, stop)

	if err := wk.Run(ctx); err != nil {
		t.Fatalf("Run: %v", err)
	}

	if seen := pipe.lastJob.SeriesID; seen != wantSeries {
		t.Fatalf("pipeline saw SeriesID=%q want %q", seen, wantSeries)
	}
	if st := queue.state[jobID]; st != "done" {
		t.Fatalf("state=%q want done", st)
	}
}
// TestWorker_SoloJobLeavesSeriesIDEmpty is the negative case: a job claimed
// without a series keeps SeriesID empty all the way to the pipeline, so
// cloud-sync writes NULL into imagen.images.series_id.
//
// The zero value of SeriesID is also "", so the emptiness check alone
// would pass even if the worker never ran the job. The test therefore
// first asserts that Run returned cleanly and that the queue recorded the
// job as "done", mirroring TestWorker_PropagatesSeriesIDToPipeline.
func TestWorker_SoloJobLeavesSeriesIDEmpty(t *testing.T) {
	q := newFakeQueue(Job{ID: "j-solo", Prompt: "p", Backend: "mock"})
	p := &fakePipeline{results: map[string]Outcome{"j-solo": {ImageID: "img-solo"}}}
	w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})

	// Let the loop run for a short while, then cancel it so Run returns.
	ctx, cancel := context.WithCancel(context.Background())
	go func() { time.Sleep(80 * time.Millisecond); cancel() }()

	if err := w.Run(ctx); err != nil {
		t.Fatalf("Run: %v", err)
	}
	// Prove the job actually went through the pipeline before trusting the
	// empty-SeriesID observation below.
	if got := q.state["j-solo"]; got != "done" {
		t.Fatalf("state=%q want done", got)
	}
	if got := p.lastJob.SeriesID; got != "" {
		t.Fatalf("solo job pipeline.lastJob.SeriesID=%q want empty", got)
	}
}
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) { func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
// First claim returns an error; the loop should log and try again on the // First claim returns an error; the loop should log and try again on the
// next wake — it must not propagate the error and exit. // next wake — it must not propagate the error and exit.