Merge mai/hermes/issue-9-imagen-9-imagen: imagen.series + series_id propagation (#9)
This commit is contained in:
@@ -132,7 +132,7 @@ func runGenerate(ctx context.Context, args []string) error {
|
|||||||
fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath)
|
fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result, err := maybeCloudSync(ctx, cfg, noCloud, "", paths, in, res, w, h); err != nil {
|
if result, err := maybeCloudSync(ctx, cfg, noCloud, "", "", paths, in, res, w, h); err != nil {
|
||||||
// cloud-sync failures are warnings — the image already wrote.
|
// cloud-sync failures are warnings — the image already wrote.
|
||||||
fmt.Fprintln(os.Stderr, "imagen: cloud sync:", err)
|
fmt.Fprintln(os.Stderr, "imagen: cloud sync:", err)
|
||||||
} else if result != nil && result.ImageID != "" {
|
} else if result != nil && result.ImageID != "" {
|
||||||
@@ -173,7 +173,10 @@ func resolveCloudSyncMode(cfg *config.Config, noCloudFlag bool, env string) (str
|
|||||||
// that need the imagen.images.id (e.g. the worker linking a job row) can pick
|
// that need the imagen.images.id (e.g. the worker linking a job row) can pick
|
||||||
// it up. ownerOverride, when non-empty, wins over config + env — the worker
|
// it up. ownerOverride, when non-empty, wins over config + env — the worker
|
||||||
// passes the job row's owner_user_id so each job is attributed correctly.
|
// passes the job row's owner_user_id so each job is attributed correctly.
|
||||||
func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, ownerOverride string, paths *output.Outputs, in output.Inputs, res *backend.Result, width, height int) (*cloud.SyncResult, error) {
|
// seriesID, when non-empty, lands on imagen.images.series_id so the
|
||||||
|
// list-page query (`WHERE series_id IS NULL`) hides series members from
|
||||||
|
// the flat grid; empty means solo run.
|
||||||
|
func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, ownerOverride, seriesID string, paths *output.Outputs, in output.Inputs, res *backend.Result, width, height int) (*cloud.SyncResult, error) {
|
||||||
mode, err := resolveCloudSyncMode(cfg, noCloud, os.Getenv("IMAGEN_CLOUD_SYNC"))
|
mode, err := resolveCloudSyncMode(cfg, noCloud, os.Getenv("IMAGEN_CLOUD_SYNC"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -262,6 +265,7 @@ func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, owner
|
|||||||
LatencyMs: latency,
|
LatencyMs: latency,
|
||||||
CostUSDEstimate: cost,
|
CostUSDEstimate: cost,
|
||||||
Sidecar: sidecar,
|
Sidecar: sidecar,
|
||||||
|
SeriesID: seriesID,
|
||||||
}
|
}
|
||||||
syncCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
syncCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -112,6 +112,9 @@ func (q *pgxQueue) Close() {
|
|||||||
// returns it. FOR UPDATE SKIP LOCKED is belt + braces against a second worker
|
// returns it. FOR UPDATE SKIP LOCKED is belt + braces against a second worker
|
||||||
// process — out of scope for v1 but cheap insurance.
|
// process — out of scope for v1 but cheap insurance.
|
||||||
func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
|
func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
|
||||||
|
// series_id is nullable on imagen.jobs (solo run when NULL); cast to text
|
||||||
|
// with COALESCE so pgx scans into a plain Go string. Empty string =
|
||||||
|
// solo run; the pipeline skips series propagation in that case.
|
||||||
const stmt = `
|
const stmt = `
|
||||||
UPDATE imagen.jobs
|
UPDATE imagen.jobs
|
||||||
SET status='running', started_at=now()
|
SET status='running', started_at=now()
|
||||||
@@ -126,11 +129,13 @@ func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
|
|||||||
COALESCE(model,''),
|
COALESCE(model,''),
|
||||||
COALESCE(width, 0), COALESCE(height, 0),
|
COALESCE(width, 0), COALESCE(height, 0),
|
||||||
COALESCE(steps, 0), COALESCE(seed, 0),
|
COALESCE(steps, 0), COALESCE(seed, 0),
|
||||||
COALESCE(style,'')`
|
COALESCE(style,''),
|
||||||
|
COALESCE(series_id::text, '')`
|
||||||
var j worker.Job
|
var j worker.Job
|
||||||
err := q.conn.QueryRow(ctx, stmt).Scan(
|
err := q.conn.QueryRow(ctx, stmt).Scan(
|
||||||
&j.ID, &j.OwnerUserID, &j.Prompt, &j.Backend,
|
&j.ID, &j.OwnerUserID, &j.Prompt, &j.Backend,
|
||||||
&j.Model, &j.Width, &j.Height, &j.Steps, &j.Seed, &j.Style,
|
&j.Model, &j.Width, &j.Height, &j.Steps, &j.Seed, &j.Style,
|
||||||
|
&j.SeriesID,
|
||||||
)
|
)
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -258,7 +263,7 @@ func (p *workerPipeline) Run(ctx context.Context, job worker.Job) worker.Outcome
|
|||||||
// config, the worker can't serve flexsiebels at all.
|
// config, the worker can't serve flexsiebels at all.
|
||||||
return worker.Outcome{Err: fmt.Errorf("output.cloud_sync=off in config; the worker requires cloud_sync=on or auto")}
|
return worker.Outcome{Err: fmt.Errorf("output.cloud_sync=off in config; the worker requires cloud_sync=on or auto")}
|
||||||
}
|
}
|
||||||
syncRes, syncErr := maybeCloudSync(ctx, p.cfg, false, job.OwnerUserID, paths, in, res, dimOrFallback(job.Width, res, "width"), dimOrFallback(job.Height, res, "height"))
|
syncRes, syncErr := maybeCloudSync(ctx, p.cfg, false, job.OwnerUserID, job.SeriesID, paths, in, res, dimOrFallback(job.Width, res, "width"), dimOrFallback(job.Height, res, "height"))
|
||||||
if syncErr != nil {
|
if syncErr != nil {
|
||||||
return worker.Outcome{Err: fmt.Errorf("cloud sync: %w", syncErr)}
|
return worker.Outcome{Err: fmt.Errorf("cloud sync: %w", syncErr)}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -100,6 +100,12 @@ type SyncRequest struct {
|
|||||||
LatencyMs int
|
LatencyMs int
|
||||||
CostUSDEstimate *float64
|
CostUSDEstimate *float64
|
||||||
Sidecar map[string]any
|
Sidecar map[string]any
|
||||||
|
|
||||||
|
// SeriesID is the parent imagen.series row when this image is one of
|
||||||
|
// N tries in a batch. Empty means a solo run — the column stays NULL,
|
||||||
|
// which keeps the row visible on the main list-page query
|
||||||
|
// (`WHERE series_id IS NULL`).
|
||||||
|
SeriesID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncResult tells the caller what landed where.
|
// SyncResult tells the caller what landed where.
|
||||||
@@ -195,6 +201,9 @@ func (s *Sink) insertRow(ctx context.Context, storagePath string, req SyncReques
|
|||||||
if len(req.Sidecar) > 0 {
|
if len(req.Sidecar) > 0 {
|
||||||
row["sidecar"] = req.Sidecar
|
row["sidecar"] = req.Sidecar
|
||||||
}
|
}
|
||||||
|
if req.SeriesID != "" {
|
||||||
|
row["series_id"] = req.SeriesID
|
||||||
|
}
|
||||||
|
|
||||||
body, err := json.Marshal(row)
|
body, err := json.Marshal(row)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -305,6 +305,54 @@ func TestSyncDBFailureSurfacesPathOnError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSyncWritesSeriesID is the second half of the ImaGen#9 propagation
|
||||||
|
// contract: when SeriesID is non-empty, the POST body to imagen.images
|
||||||
|
// carries `series_id`. When empty, the key is omitted entirely so the
|
||||||
|
// row's series_id stays NULL (solo-run path, list-page query
|
||||||
|
// `WHERE series_id IS NULL` keeps showing it).
|
||||||
|
func TestSyncWritesSeriesID(t *testing.T) {
|
||||||
|
const seriesID = "22222222-2222-2222-2222-222222222222"
|
||||||
|
f := newFakeSupabase(t)
|
||||||
|
s := newSink(f.server)
|
||||||
|
|
||||||
|
_, err := s.Sync(context.Background(), SyncRequest{
|
||||||
|
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
|
||||||
|
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||||
|
SeriesID: seriesID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Sync: %v", err)
|
||||||
|
}
|
||||||
|
var row map[string]any
|
||||||
|
if err := json.Unmarshal(f.insertBody, &row); err != nil {
|
||||||
|
t.Fatalf("parse insert body: %v\n%s", err, f.insertBody)
|
||||||
|
}
|
||||||
|
if row["series_id"] != seriesID {
|
||||||
|
t.Fatalf("row.series_id = %v want %q", row["series_id"], seriesID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncOmitsSeriesIDWhenEmpty(t *testing.T) {
|
||||||
|
f := newFakeSupabase(t)
|
||||||
|
s := newSink(f.server)
|
||||||
|
|
||||||
|
_, err := s.Sync(context.Background(), SyncRequest{
|
||||||
|
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
|
||||||
|
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||||
|
// SeriesID intentionally empty.
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Sync: %v", err)
|
||||||
|
}
|
||||||
|
var row map[string]any
|
||||||
|
if err := json.Unmarshal(f.insertBody, &row); err != nil {
|
||||||
|
t.Fatalf("parse insert body: %v\n%s", err, f.insertBody)
|
||||||
|
}
|
||||||
|
if _, present := row["series_id"]; present {
|
||||||
|
t.Fatalf("solo run should omit series_id from POST body, got %v", row["series_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPathEscape(t *testing.T) {
|
func TestPathEscape(t *testing.T) {
|
||||||
cases := map[string]string{
|
cases := map[string]string{
|
||||||
"2026-05-11/lighthouse-42.png": "2026-05-11/lighthouse-42.png",
|
"2026-05-11/lighthouse-42.png": "2026-05-11/lighthouse-42.png",
|
||||||
|
|||||||
@@ -30,6 +30,10 @@ type Job struct {
|
|||||||
Steps int
|
Steps int
|
||||||
Seed int64
|
Seed int64
|
||||||
Style string
|
Style string
|
||||||
|
// SeriesID is the parent imagen.series row when this job is one of N
|
||||||
|
// tries in a batch. Empty means a solo run — the pipeline must not
|
||||||
|
// propagate a series_id onto the resulting imagen.images row.
|
||||||
|
SeriesID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Outcome is what the pipeline reports back per job. ImageID is the
|
// Outcome is what the pipeline reports back per job. ImageID is the
|
||||||
|
|||||||
@@ -303,6 +303,50 @@ func TestWorker_InflightJobFinishesAfterShutdown(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestWorker_PropagatesSeriesIDToPipeline verifies the worker hands the
|
||||||
|
// Job's SeriesID through to the pipeline unchanged. The pipeline owns the
|
||||||
|
// cloud-sync side of the propagation (cloud.SyncRequest.SeriesID lands on
|
||||||
|
// imagen.images.series_id) — see cloud_test.go for that half — so the
|
||||||
|
// worker contract is simply: don't drop or rewrite SeriesID between
|
||||||
|
// claim and Run.
|
||||||
|
func TestWorker_PropagatesSeriesIDToPipeline(t *testing.T) {
|
||||||
|
const seriesID = "11111111-1111-1111-1111-111111111111"
|
||||||
|
q := newFakeQueue(Job{
|
||||||
|
ID: "j-series",
|
||||||
|
Prompt: "p",
|
||||||
|
Backend: "mock",
|
||||||
|
SeriesID: seriesID,
|
||||||
|
})
|
||||||
|
p := &fakePipeline{results: map[string]Outcome{"j-series": {ImageID: "img-series"}}}
|
||||||
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
||||||
|
if err := w.Run(ctx); err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
if got := p.lastJob.SeriesID; got != seriesID {
|
||||||
|
t.Fatalf("pipeline saw SeriesID=%q want %q", got, seriesID)
|
||||||
|
}
|
||||||
|
if got := q.state["j-series"]; got != "done" {
|
||||||
|
t.Fatalf("state=%q want done", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWorker_SoloJobLeavesSeriesIDEmpty is the negative case — a job
|
||||||
|
// claimed with no series row keeps the field empty all the way to the
|
||||||
|
// pipeline so cloud-sync writes NULL into imagen.images.series_id.
|
||||||
|
func TestWorker_SoloJobLeavesSeriesIDEmpty(t *testing.T) {
|
||||||
|
q := newFakeQueue(Job{ID: "j-solo", Prompt: "p", Backend: "mock"})
|
||||||
|
p := &fakePipeline{results: map[string]Outcome{"j-solo": {ImageID: "img-solo"}}}
|
||||||
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
||||||
|
_ = w.Run(ctx)
|
||||||
|
if got := p.lastJob.SeriesID; got != "" {
|
||||||
|
t.Fatalf("solo job pipeline.lastJob.SeriesID=%q want empty", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
|
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
|
||||||
// First claim returns an error; the loop should log and try again on the
|
// First claim returns an error; the loop should log and try again on the
|
||||||
// next wake — it must not propagate the error and exit.
|
// next wake — it must not propagate the error and exit.
|
||||||
|
|||||||
Reference in New Issue
Block a user