package worker

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"testing"
	"time"
)

// fakeQueue is a hand-rolled in-memory queue that mirrors the contract of a
// real Postgres-backed implementation: ClaimNextPending atomically takes one
// pending row and flips its status to "running", MarkDone/MarkFailed are
// terminal transitions, WaitForJob blocks until notified or until the timeout
// elapses.
//
// Every field below mu is guarded by mu. Tests should read them through the
// locked accessors (status, result, counters) rather than touching the maps
// and counters directly, so the suite stays clean under -race regardless of
// which goroutine the worker uses to call into the queue.
type fakeQueue struct {
	mu      sync.Mutex
	pending []Job
	state   map[string]string // jobID -> status
	last    map[string]string // jobID -> error msg or image_id
	notify  chan struct{}

	// Injected errors: when non-nil the matching method returns the error
	// without changing any state (ResetStaleRunning still counts the call).
	claimErr error
	doneErr  error
	failErr  error
	resetErr error

	// Counters of successful calls (resets counts every call).
	claimed int
	done    int
	failed  int
	resets  int
}

// newFakeQueue builds a queue pre-loaded with jobs, each recorded with status
// "pending" and claimable in the order given.
func newFakeQueue(jobs ...Job) *fakeQueue {
	q := &fakeQueue{
		state:  make(map[string]string),
		last:   make(map[string]string),
		notify: make(chan struct{}, 16),
	}
	for _, j := range jobs {
		q.pending = append(q.pending, j)
		q.state[j.ID] = "pending"
	}
	return q
}

// ClaimNextPending pops the oldest pending job and marks it "running". It
// returns (nil, nil) when nothing is pending and (nil, claimErr) when a claim
// error has been injected.
func (q *fakeQueue) ClaimNextPending(ctx context.Context) (*Job, error) {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.claimErr != nil {
		return nil, q.claimErr
	}
	if len(q.pending) == 0 {
		return nil, nil
	}
	j := q.pending[0]
	q.pending = q.pending[1:]
	q.state[j.ID] = "running"
	q.claimed++
	return &j, nil
}

// MarkDone records jobID as "done" with imageID as its result, unless doneErr
// is injected, in which case it returns that error and records nothing.
func (q *fakeQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.doneErr != nil {
		return q.doneErr
	}
	q.state[jobID] = "done"
	q.last[jobID] = imageID
	q.done++
	return nil
}

// MarkFailed records jobID as "failed" with msg as its result, unless failErr
// is injected, in which case it returns that error and records nothing.
func (q *fakeQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.failErr != nil {
		return q.failErr
	}
	q.state[jobID] = "failed"
	q.last[jobID] = msg
	q.failed++
	return nil
}

// WaitForJob blocks until ctx is cancelled (returning ctx.Err()), a
// notification arrives via pingNotify, or timeout elapses (both returning
// nil).
func (q *fakeQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
	// An explicit timer (rather than time.After) is stopped as soon as one of
	// the other cases wins, so long poll intervals do not leave timers pending
	// after the wait has already returned.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-q.notify:
		return nil
	case <-timer.C:
		return nil
	}
}

// ResetStaleRunning counts the call and returns the injected resetErr (nil by
// default). It does not touch job state.
func (q *fakeQueue) ResetStaleRunning(ctx context.Context) error {
	q.mu.Lock()
	defer q.mu.Unlock()

	q.resets++
	return q.resetErr
}

// pingNotify simulates an INSERT-trigger NOTIFY by waking WaitForJob. The
// send is non-blocking: if the notify buffer is full the ping is dropped.
func (q *fakeQueue) pingNotify() {
	select {
	case q.notify <- struct{}{}:
	default:
	}
}

// status returns the recorded status of jobID ("" if unknown), read under the
// queue lock.
func (q *fakeQueue) status(jobID string) string {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.state[jobID]
}

// result returns the image ID or error message last recorded for jobID ("" if
// none), read under the queue lock.
func (q *fakeQueue) result(jobID string) string {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.last[jobID]
}

// counters returns a consistent snapshot of the MarkDone, MarkFailed and
// ResetStaleRunning call counters.
func (q *fakeQueue) counters() (done, failed, resets int) {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.done, q.failed, q.resets
}

// fakePipeline is a stub pipeline that returns canned outcomes and records
// what it was called with. All fields below mu are guarded by mu.
type fakePipeline struct {
	mu      sync.Mutex
	results map[string]Outcome // by job.ID; "" key = default outcome
	calls   int
	delay   time.Duration
	lastJob Job
}

// Run records the call, then returns the canned Outcome for job.ID (falling
// back to the "" entry). When delay is positive it first waits for the delay
// to elapse; if ctx is cancelled during that wait it returns an Outcome
// carrying ctx.Err() instead.
func (p *fakePipeline) Run(ctx context.Context, job Job) Outcome {
	p.mu.Lock()
	p.calls++
	p.lastJob = job
	delay := p.delay
	out, ok := p.results[job.ID]
	if !ok {
		out = p.results[""]
	}
	p.mu.Unlock()

	if delay <= 0 {
		return out
	}

	// Stop the timer on the cancellation path so it does not linger.
	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return Outcome{Err: ctx.Err()}
	case <-timer.C:
		return out
	}
}

// snapshot returns the number of Run calls and the most recent Job passed to
// Run, read under the pipeline lock.
func (p *fakePipeline) snapshot() (calls int, last Job) {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.calls, p.lastJob
}

// TestWorker_DonePath checks that a job whose pipeline outcome carries an
// image ID ends up "done" with that ID recorded, that the pipeline ran exactly
// once, and that ResetStaleRunning was called exactly once.
func TestWorker_DonePath(t *testing.T) {
	q := newFakeQueue(
		Job{ID: "j1", Prompt: "a", Backend: "mock"},
	)
	p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img-1"}}}
	w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		time.Sleep(80 * time.Millisecond)
		cancel()
	}()

	if err := w.Run(ctx); err != nil {
		t.Fatalf("Run: %v", err)
	}

	if got := q.status("j1"); got != "done" {
		t.Fatalf("state=%q want done", got)
	}
	if got := q.result("j1"); got != "img-1" {
		t.Fatalf("image_id=%q want img-1", got)
	}
	done, failed, resets := q.counters()
	if done != 1 || failed != 0 {
		t.Fatalf("counts: done=%d failed=%d", done, failed)
	}
	if calls, _ := p.snapshot(); calls != 1 {
		t.Fatalf("pipeline calls=%d want 1", calls)
	}
	if resets != 1 {
		t.Fatalf("ResetStaleRunning calls=%d want 1", resets)
	}
}

// TestWorker_FailedPath_RecordsErrorText checks that a pipeline error marks
// the job "failed" and that the error text is what gets recorded.
func TestWorker_FailedPath_RecordsErrorText(t *testing.T) {
	q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
	p := &fakePipeline{results: map[string]Outcome{"j1": {Err: errors.New("backend unreachable")}}}
	w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() { time.Sleep(80 * time.Millisecond); cancel() }()

	// Run's return value is not under test here; only the recorded state is.
	_ = w.Run(ctx)

	if got := q.status("j1"); got != "failed" {
		t.Fatalf("state=%q want failed", got)
	}
	if got := q.result("j1"); got != "backend unreachable" {
		t.Fatalf("error=%q want %q", got, "backend unreachable")
	}
	if done, failed, _ := q.counters(); done != 0 || failed != 1 {
		t.Fatalf("counts: done=%d failed=%d", done, failed)
	}
}

// TestWorker_MissingImageID_TreatedAsFailure checks that an empty Outcome is
// recorded as a failure with a non-empty explanation.
func TestWorker_MissingImageID_TreatedAsFailure(t *testing.T) {
	q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
	// Outcome has neither Err nor ImageID — pipeline silently swallowed
	// cloud-sync. flexsiebels needs the image_id; without it, fail the job.
	p := &fakePipeline{results: map[string]Outcome{"j1": {}}}
	w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() { time.Sleep(80 * time.Millisecond); cancel() }()

	// Run's return value is not under test here; only the recorded state is.
	_ = w.Run(ctx)

	if got := q.status("j1"); got != "failed" {
		t.Fatalf("state=%q want failed", got)
	}
	if q.result("j1") == "" {
		t.Fatalf("expected non-empty error explanation for missing image_id")
	}
}

// TestWorker_DrainsMultipleBeforeWaiting checks that three pre-queued jobs are
// all completed even though the context is cancelled (60ms) well before the
// first poll interval (200ms) would elapse.
func TestWorker_DrainsMultipleBeforeWaiting(t *testing.T) {
	q := newFakeQueue(
		Job{ID: "j1", Backend: "mock"},
		Job{ID: "j2", Backend: "mock"},
		Job{ID: "j3", Backend: "mock"},
	)
	p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
	w := New(q, p, Config{PollInterval: 200 * time.Millisecond, JobTimeout: time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() { time.Sleep(60 * time.Millisecond); cancel() }()

	// Run's return value is not under test here; only the recorded state is.
	_ = w.Run(ctx)

	for _, id := range []string{"j1", "j2", "j3"} {
		if got := q.status(id); got != "done" {
			t.Fatalf("%s state=%q want done", id, got)
		}
	}
	if done, _, _ := q.counters(); done != 3 {
		t.Fatalf("done=%d want 3", done)
	}
}

// TestWorker_NotifyWakesEarlierThanPoll checks that a job enqueued after the
// worker has started is processed within 500ms when a notification is sent,
// despite a 5s poll interval.
func TestWorker_NotifyWakesEarlierThanPoll(t *testing.T) {
	q := newFakeQueue()
	p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
	// Set poll interval high so a working LISTEN is required to see the job
	// promptly. Without NOTIFY plumbing this test would time out the worker
	// before drain ever runs.
	w := New(q, p, Config{PollInterval: 5 * time.Second, JobTimeout: time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	done := make(chan struct{})
	go func() {
		_ = w.Run(ctx)
		close(done)
	}()

	// Append a job and ping the wake channel.
	q.mu.Lock()
	q.pending = append(q.pending, Job{ID: "late", Backend: "mock"})
	q.state["late"] = "pending"
	q.mu.Unlock()
	q.pingNotify()

	// Give the worker a beat to claim + process.
	deadline := time.Now().Add(500 * time.Millisecond)
	for time.Now().Before(deadline) {
		if q.status("late") == "done" {
			cancel()
			<-done
			return
		}
		time.Sleep(5 * time.Millisecond)
	}
	t.Fatalf("worker did not pick up the late job within the 500ms window — NOTIFY wake-up path is broken")
}

// TestWorker_HonoursContextCancellation checks that Run on an empty queue
// returns nil within 200ms of a 30ms context timeout.
func TestWorker_HonoursContextCancellation(t *testing.T) {
	q := newFakeQueue()
	p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
	w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
	defer cancel()

	start := time.Now()
	if err := w.Run(ctx); err != nil {
		t.Fatalf("Run: %v", err)
	}
	if dur := time.Since(start); dur > 200*time.Millisecond {
		t.Fatalf("worker did not exit promptly on ctx cancel: %v", dur)
	}
}

// TestWorker_InflightJobFinishesAfterShutdown checks that a job already
// running when the worker context is cancelled is still recorded as "done".
func TestWorker_InflightJobFinishesAfterShutdown(t *testing.T) {
	q := newFakeQueue(Job{ID: "long", Backend: "mock"})
	p := &fakePipeline{
		results: map[string]Outcome{"long": {ImageID: "img-long"}},
		delay:   120 * time.Millisecond,
	}
	// Short JobTimeout would also kill the in-flight job; give it enough
	// budget so the test exercises the shutdown-during-job path.
	w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: 5 * time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		// Let the job start, then cancel mid-flight.
		time.Sleep(30 * time.Millisecond)
		cancel()
	}()

	// Run's return value is not under test here; only the recorded state is.
	_ = w.Run(ctx)

	if got := q.status("long"); got != "done" {
		t.Fatalf("state=%q want done (in-flight job should finish even on shutdown)", got)
	}
}

// TestWorker_PropagatesSeriesIDToPipeline verifies the worker hands the
// Job's SeriesID through to the pipeline unchanged. The pipeline owns the
// cloud-sync side of the propagation (cloud.SyncRequest.SeriesID lands on
// imagen.images.series_id) — see cloud_test.go for that half — so the
// worker contract is simply: don't drop or rewrite SeriesID between
// claim and Run.
func TestWorker_PropagatesSeriesIDToPipeline(t *testing.T) {
	const seriesID = "11111111-1111-1111-1111-111111111111"
	q := newFakeQueue(Job{
		ID:       "j-series",
		Prompt:   "p",
		Backend:  "mock",
		SeriesID: seriesID,
	})
	p := &fakePipeline{results: map[string]Outcome{"j-series": {ImageID: "img-series"}}}
	w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() { time.Sleep(80 * time.Millisecond); cancel() }()

	if err := w.Run(ctx); err != nil {
		t.Fatalf("Run: %v", err)
	}

	_, last := p.snapshot()
	if got := last.SeriesID; got != seriesID {
		t.Fatalf("pipeline saw SeriesID=%q want %q", got, seriesID)
	}
	if got := q.status("j-series"); got != "done" {
		t.Fatalf("state=%q want done", got)
	}
}

// TestWorker_SoloJobLeavesSeriesIDEmpty is the negative case — a job
// claimed with no series row keeps the field empty all the way to the
// pipeline so cloud-sync writes NULL into imagen.images.series_id.
func TestWorker_SoloJobLeavesSeriesIDEmpty(t *testing.T) {
	q := newFakeQueue(Job{ID: "j-solo", Prompt: "p", Backend: "mock"})
	p := &fakePipeline{results: map[string]Outcome{"j-solo": {ImageID: "img-solo"}}}
	w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() { time.Sleep(80 * time.Millisecond); cancel() }()

	// Run's return value is not under test here; only what the pipeline saw is.
	_ = w.Run(ctx)

	_, last := p.snapshot()
	if got := last.SeriesID; got != "" {
		t.Fatalf("solo job pipeline.lastJob.SeriesID=%q want empty", got)
	}
}

// TestWorker_TransientClaimErrorDoesNotKillLoop checks that Run keeps going
// through a claim error that later clears, and still completes the job.
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
	// First claim returns an error; the loop should log and try again on the
	// next wake — it must not propagate the error and exit.
	q := newFakeQueue(Job{ID: "j1", Backend: "mock"})
	q.claimErr = fmt.Errorf("transient: connection reset")
	p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img"}}}
	w := New(q, p, Config{PollInterval: 20 * time.Millisecond, JobTimeout: time.Second})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Heal the claim error after a beat so the second drain succeeds.
	go func() {
		time.Sleep(40 * time.Millisecond)
		q.mu.Lock()
		q.claimErr = nil
		q.mu.Unlock()
	}()
	go func() {
		time.Sleep(200 * time.Millisecond)
		cancel()
	}()

	if err := w.Run(ctx); err != nil {
		t.Fatalf("Run returned: %v (transient claim errors should not kill the loop)", err)
	}
	if got := q.status("j1"); got != "done" {
		t.Fatalf("state=%q want done", got)
	}
}