Schema (applied via migration imagen_series_init):
- imagen.series parent table (prompt + params + count CHECK 1..10 + selected_image_id)
- imagen.jobs += series_id (FK) + series_idx
- imagen.images += series_id (FK)
- Owner-scoped RLS on series (SELECT/INSERT/UPDATE) + grants
- Partial indexes WHERE series_id IS NOT NULL on both child tables

Worker pipeline:
- worker.Job += SeriesID, populated from imagen.jobs.series_id via the claim query.
- cloud.SyncRequest += SeriesID; insertRow writes series_id when non-empty, omits the key when empty so solo runs leave the column NULL.
- maybeCloudSync threads seriesID from job.SeriesID through to the cloud sink. generate.go (CLI) always passes "" — solo path unchanged.

Tests:
- worker: SeriesID propagates from Job to fakePipeline.lastJob unchanged, solo job keeps it empty.
- cloud: SyncRequest.SeriesID lands as row.series_id in the POST body; empty SeriesID omits the key entirely.

Refs ImaGen#9.
377 lines
11 KiB
Go
377 lines
11 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// fakeQueue is a hand-rolled in-memory queue that mirrors the contract of a
|
|
// real Postgres-backed implementation: ClaimNextPending atomically takes one
|
|
// pending row and flips its status to "running", MarkDone/MarkFailed are
|
|
// idempotent terminal transitions, WaitForJob blocks until notified or until
|
|
// the timeout elapses.
|
|
type fakeQueue struct {
|
|
mu sync.Mutex
|
|
pending []Job
|
|
state map[string]string // jobID -> status
|
|
last map[string]string // jobID -> error msg or image_id
|
|
notify chan struct{}
|
|
|
|
claimErr error
|
|
doneErr error
|
|
failErr error
|
|
resetErr error
|
|
|
|
claimed int
|
|
done int
|
|
failed int
|
|
resets int
|
|
}
|
|
|
|
func newFakeQueue(jobs ...Job) *fakeQueue {
|
|
q := &fakeQueue{
|
|
state: make(map[string]string),
|
|
last: make(map[string]string),
|
|
notify: make(chan struct{}, 16),
|
|
}
|
|
for _, j := range jobs {
|
|
q.pending = append(q.pending, j)
|
|
q.state[j.ID] = "pending"
|
|
}
|
|
return q
|
|
}
|
|
|
|
func (q *fakeQueue) ClaimNextPending(ctx context.Context) (*Job, error) {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
if q.claimErr != nil {
|
|
return nil, q.claimErr
|
|
}
|
|
if len(q.pending) == 0 {
|
|
return nil, nil
|
|
}
|
|
j := q.pending[0]
|
|
q.pending = q.pending[1:]
|
|
q.state[j.ID] = "running"
|
|
q.claimed++
|
|
return &j, nil
|
|
}
|
|
|
|
func (q *fakeQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
if q.doneErr != nil {
|
|
return q.doneErr
|
|
}
|
|
q.state[jobID] = "done"
|
|
q.last[jobID] = imageID
|
|
q.done++
|
|
return nil
|
|
}
|
|
|
|
func (q *fakeQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
if q.failErr != nil {
|
|
return q.failErr
|
|
}
|
|
q.state[jobID] = "failed"
|
|
q.last[jobID] = msg
|
|
q.failed++
|
|
return nil
|
|
}
|
|
|
|
func (q *fakeQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-q.notify:
|
|
return nil
|
|
case <-time.After(timeout):
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (q *fakeQueue) ResetStaleRunning(ctx context.Context) error {
|
|
q.mu.Lock()
|
|
defer q.mu.Unlock()
|
|
q.resets++
|
|
return q.resetErr
|
|
}
|
|
|
|
// pingNotify simulates an INSERT-trigger NOTIFY by waking WaitForJob.
|
|
func (q *fakeQueue) pingNotify() {
|
|
select {
|
|
case q.notify <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
// stub pipeline.
|
|
type fakePipeline struct {
|
|
mu sync.Mutex
|
|
results map[string]Outcome // by job.ID; "" key = default outcome
|
|
calls int
|
|
delay time.Duration
|
|
lastJob Job
|
|
}
|
|
|
|
func (p *fakePipeline) Run(ctx context.Context, job Job) Outcome {
|
|
p.mu.Lock()
|
|
p.calls++
|
|
p.lastJob = job
|
|
delay := p.delay
|
|
out, ok := p.results[job.ID]
|
|
if !ok {
|
|
out = p.results[""]
|
|
}
|
|
p.mu.Unlock()
|
|
if delay > 0 {
|
|
select {
|
|
case <-ctx.Done():
|
|
return Outcome{Err: ctx.Err()}
|
|
case <-time.After(delay):
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func TestWorker_DonePath(t *testing.T) {
|
|
q := newFakeQueue(
|
|
Job{ID: "j1", Prompt: "a", Backend: "mock"},
|
|
)
|
|
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img-1"}}}
|
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() {
|
|
time.Sleep(80 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
if err := w.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if got := q.state["j1"]; got != "done" {
|
|
t.Fatalf("state=%q want done", got)
|
|
}
|
|
if got := q.last["j1"]; got != "img-1" {
|
|
t.Fatalf("image_id=%q want img-1", got)
|
|
}
|
|
if q.done != 1 || q.failed != 0 {
|
|
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
|
|
}
|
|
if p.calls != 1 {
|
|
t.Fatalf("pipeline calls=%d want 1", p.calls)
|
|
}
|
|
if q.resets != 1 {
|
|
t.Fatalf("ResetStaleRunning calls=%d want 1", q.resets)
|
|
}
|
|
}
|
|
|
|
func TestWorker_FailedPath_RecordsErrorText(t *testing.T) {
|
|
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
|
|
p := &fakePipeline{results: map[string]Outcome{"j1": {Err: errors.New("backend unreachable")}}}
|
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
|
_ = w.Run(ctx)
|
|
|
|
if got := q.state["j1"]; got != "failed" {
|
|
t.Fatalf("state=%q want failed", got)
|
|
}
|
|
if got := q.last["j1"]; got != "backend unreachable" {
|
|
t.Fatalf("error=%q want %q", got, "backend unreachable")
|
|
}
|
|
if q.done != 0 || q.failed != 1 {
|
|
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
|
|
}
|
|
}
|
|
|
|
func TestWorker_MissingImageID_TreatedAsFailure(t *testing.T) {
|
|
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
|
|
// Outcome has neither Err nor ImageID — pipeline silently swallowed
|
|
// cloud-sync. flexsiebels needs the image_id; without it, fail the job.
|
|
p := &fakePipeline{results: map[string]Outcome{"j1": {}}}
|
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
|
_ = w.Run(ctx)
|
|
|
|
if got := q.state["j1"]; got != "failed" {
|
|
t.Fatalf("state=%q want failed", got)
|
|
}
|
|
if q.last["j1"] == "" {
|
|
t.Fatalf("expected non-empty error explanation for missing image_id")
|
|
}
|
|
}
|
|
|
|
func TestWorker_DrainsMultipleBeforeWaiting(t *testing.T) {
|
|
q := newFakeQueue(
|
|
Job{ID: "j1", Backend: "mock"},
|
|
Job{ID: "j2", Backend: "mock"},
|
|
Job{ID: "j3", Backend: "mock"},
|
|
)
|
|
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
|
|
w := New(q, p, Config{PollInterval: 200 * time.Millisecond, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() { time.Sleep(60 * time.Millisecond); cancel() }()
|
|
_ = w.Run(ctx)
|
|
|
|
for _, id := range []string{"j1", "j2", "j3"} {
|
|
if got := q.state[id]; got != "done" {
|
|
t.Fatalf("%s state=%q want done", id, got)
|
|
}
|
|
}
|
|
if q.done != 3 {
|
|
t.Fatalf("done=%d want 3", q.done)
|
|
}
|
|
}
|
|
|
|
func TestWorker_NotifyWakesEarlierThanPoll(t *testing.T) {
|
|
q := newFakeQueue()
|
|
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
|
|
// Set poll interval high so a working LISTEN is required to see the job
|
|
// promptly. Without NOTIFY plumbing this test would time out the worker
|
|
// before drain ever runs.
|
|
w := New(q, p, Config{PollInterval: 5 * time.Second, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
done := make(chan struct{})
|
|
go func() {
|
|
_ = w.Run(ctx)
|
|
close(done)
|
|
}()
|
|
// Append a job and ping the wake channel.
|
|
q.mu.Lock()
|
|
q.pending = append(q.pending, Job{ID: "late", Backend: "mock"})
|
|
q.state["late"] = "pending"
|
|
q.mu.Unlock()
|
|
q.pingNotify()
|
|
|
|
// Give the worker a beat to claim + process.
|
|
deadline := time.Now().Add(500 * time.Millisecond)
|
|
for time.Now().Before(deadline) {
|
|
q.mu.Lock()
|
|
s := q.state["late"]
|
|
q.mu.Unlock()
|
|
if s == "done" {
|
|
cancel()
|
|
<-done
|
|
return
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
t.Fatalf("worker did not pick up the late job within the 500ms window — NOTIFY wake-up path is broken")
|
|
}
|
|
|
|
func TestWorker_HonoursContextCancellation(t *testing.T) {
|
|
q := newFakeQueue()
|
|
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
|
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
|
defer cancel()
|
|
start := time.Now()
|
|
if err := w.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if dur := time.Since(start); dur > 200*time.Millisecond {
|
|
t.Fatalf("worker did not exit promptly on ctx cancel: %v", dur)
|
|
}
|
|
}
|
|
|
|
func TestWorker_InflightJobFinishesAfterShutdown(t *testing.T) {
|
|
q := newFakeQueue(Job{ID: "long", Backend: "mock"})
|
|
p := &fakePipeline{
|
|
results: map[string]Outcome{"long": {ImageID: "img-long"}},
|
|
delay: 120 * time.Millisecond,
|
|
}
|
|
// Short JobTimeout would also kill the in-flight job; give it enough
|
|
// budget so the test exercises the shutdown-during-job path.
|
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: 5 * time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() {
|
|
// Let the job start, then cancel mid-flight.
|
|
time.Sleep(30 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
_ = w.Run(ctx)
|
|
if got := q.state["long"]; got != "done" {
|
|
t.Fatalf("state=%q want done (in-flight job should finish even on shutdown)", got)
|
|
}
|
|
}
|
|
|
|
// TestWorker_PropagatesSeriesIDToPipeline verifies the worker hands the
|
|
// Job's SeriesID through to the pipeline unchanged. The pipeline owns the
|
|
// cloud-sync side of the propagation (cloud.SyncRequest.SeriesID lands on
|
|
// imagen.images.series_id) — see cloud_test.go for that half — so the
|
|
// worker contract is simply: don't drop or rewrite SeriesID between
|
|
// claim and Run.
|
|
func TestWorker_PropagatesSeriesIDToPipeline(t *testing.T) {
|
|
const seriesID = "11111111-1111-1111-1111-111111111111"
|
|
q := newFakeQueue(Job{
|
|
ID: "j-series",
|
|
Prompt: "p",
|
|
Backend: "mock",
|
|
SeriesID: seriesID,
|
|
})
|
|
p := &fakePipeline{results: map[string]Outcome{"j-series": {ImageID: "img-series"}}}
|
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
|
if err := w.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if got := p.lastJob.SeriesID; got != seriesID {
|
|
t.Fatalf("pipeline saw SeriesID=%q want %q", got, seriesID)
|
|
}
|
|
if got := q.state["j-series"]; got != "done" {
|
|
t.Fatalf("state=%q want done", got)
|
|
}
|
|
}
|
|
|
|
// TestWorker_SoloJobLeavesSeriesIDEmpty is the negative case — a job
|
|
// claimed with no series row keeps the field empty all the way to the
|
|
// pipeline so cloud-sync writes NULL into imagen.images.series_id.
|
|
func TestWorker_SoloJobLeavesSeriesIDEmpty(t *testing.T) {
|
|
q := newFakeQueue(Job{ID: "j-solo", Prompt: "p", Backend: "mock"})
|
|
p := &fakePipeline{results: map[string]Outcome{"j-solo": {ImageID: "img-solo"}}}
|
|
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
|
_ = w.Run(ctx)
|
|
if got := p.lastJob.SeriesID; got != "" {
|
|
t.Fatalf("solo job pipeline.lastJob.SeriesID=%q want empty", got)
|
|
}
|
|
}
|
|
|
|
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
|
|
// First claim returns an error; the loop should log and try again on the
|
|
// next wake — it must not propagate the error and exit.
|
|
q := newFakeQueue(Job{ID: "j1", Backend: "mock"})
|
|
q.claimErr = fmt.Errorf("transient: connection reset")
|
|
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img"}}}
|
|
w := New(q, p, Config{PollInterval: 20 * time.Millisecond, JobTimeout: time.Second})
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Heal the claim error after a beat so the second drain succeeds.
|
|
go func() {
|
|
time.Sleep(40 * time.Millisecond)
|
|
q.mu.Lock()
|
|
q.claimErr = nil
|
|
q.mu.Unlock()
|
|
}()
|
|
go func() {
|
|
time.Sleep(200 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
if err := w.Run(ctx); err != nil {
|
|
t.Fatalf("Run returned: %v (transient claim errors should not kill the loop)", err)
|
|
}
|
|
if got := q.state["j1"]; got != "done" {
|
|
t.Fatalf("state=%q want done", got)
|
|
}
|
|
}
|