package server

import (
	"context"
	"encoding/json"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"mgit.msbls.de/m/mGPUmanager/internal/config"
	"mgit.msbls.de/m/mGPUmanager/internal/gpu"
	"mgit.msbls.de/m/mGPUmanager/internal/registry"
	"mgit.msbls.de/m/mGPUmanager/internal/scheduler"
)

// silentLogger returns a logger that discards everything, keeping test
// output clean.
func silentLogger() *slog.Logger {
	return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
}

// fakeMVoice spins up an httptest.Server that mimics the relevant mvoice
// endpoints just enough to exercise the broker's proxy behaviour:
//
//   - GET  /api/health          reports a ready, loaded model.
//   - POST /api/synthesize      echoes the raw request body back as JSON.
//   - GET  /api/audio/test.wav  serves a fixed fake WAV payload.
//
// The returned counter is incremented once per /api/synthesize request.
// The caller is responsible for closing the returned server.
func fakeMVoice(t *testing.T) (*httptest.Server, *atomic.Int32) {
	t.Helper()
	calls := &atomic.Int32{}
	mux := http.NewServeMux()
	mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"status":"ready","loaded":true,"gpu_resident_mib":2800}`))
	})
	mux.HandleFunc("POST /api/synthesize", func(w http.ResponseWriter, r *http.Request) {
		calls.Add(1)
		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Let encoding/json escape the echoed body. Hand-escaping only the
		// double quotes yields invalid JSON for bodies that contain
		// backslashes or control characters.
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(struct {
			AudioURL    string `json:"audio_url"`
			PayloadEcho string `json:"payload_echo"`
		}{
			AudioURL:    "/api/audio/test.wav",
			PayloadEcho: string(body),
		})
	})
	mux.HandleFunc("GET /api/audio/test.wav", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "audio/wav")
		_, _ = w.Write([]byte("RIFF....fake-wav-bytes"))
	})
	return httptest.NewServer(mux), calls
}

// buildHarness wires a broker Server against upstream. The config has a
// single "mvoice" consumer that handles KindTTS routing and is also the
// audio proxy target for paths under /api/audio/.
//
// It returns the server together with the registry and poller it was built
// with. buildHarness itself does not probe the consumer; tests that need a
// known health state call probeOnce.
func buildHarness(t *testing.T, upstream *httptest.Server) (*Server, *registry.Registry, *gpu.Poller) {
	t.Helper()
	cfg := &config.Config{
		Listen: "127.0.0.1:0",
		GPU:    config.GPU{TotalMiB: 16376, ReservedMiB: 1024, PollIntervalSeconds: 2},
		Routing: map[config.EndpointKind]string{
			config.KindTTS: "mvoice",
		},
		AudioProxy:      "mvoice",
		AudioPathPrefix: "/api/audio/",
		Consumers: map[string]*config.Consumer{
			"mvoice": {
				URL:    upstream.URL,
				Health: config.Route{Method: "GET", Path: "/api/health"},
				Paths: map[config.EndpointKind]config.Route{
					config.KindTTS: {Method: "POST", Path: "/api/synthesize"},
				},
				VRAMResidentMiB: 2800,
				MaxConcurrency:  1,
				Priority:        3,
			},
		},
	}
	reg := registry.New(cfg, silentLogger())
	poller := gpu.NewPoller(time.Second, silentLogger())
	sched := scheduler.NewPassthrough(reg)
	srv := New(cfg, reg, poller, sched, silentLogger())
	return srv, reg, poller
}

func TestProxyTTSForwardsBodyAndReturnsConsumerResponse(t *testing.T) {
	upstream, calls := fakeMVoice(t)
	defer upstream.Close()
	srv, reg, _ := buildHarness(t, upstream)

	// Record a successful health probe against the fake so the health
	// gating in handleEndpoint passes.
	probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")

	ts := httptest.NewServer(srv.Handler())
	defer ts.Close()

	form := url.Values{}
	form.Set("text", "Hallo")
	resp, err := http.Post(ts.URL+"/v1/tts", "application/x-www-form-urlencoded", strings.NewReader(form.Encode()))
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		body, _ := io.ReadAll(resp.Body) // best effort: only used in the failure message
		t.Fatalf("status %d: %s", resp.StatusCode, body)
	}
	if got := calls.Load(); got != 1 {
		t.Errorf("upstream calls = %d, want 1", got)
	}
	var payload struct {
		AudioURL    string `json:"audio_url"`
		PayloadEcho string `json:"payload_echo"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		t.Fatal(err)
	}
	if !strings.Contains(payload.PayloadEcho, "text=Hallo") {
		t.Errorf("upstream did not receive body: %q", payload.PayloadEcho)
	}
	if payload.AudioURL != "/api/audio/test.wav" {
		t.Errorf("AudioURL = %q", payload.AudioURL)
	}
}

func TestAudioProxyForwardsBytes(t *testing.T) {
	upstream, _ := fakeMVoice(t)
	defer upstream.Close()
	srv, _, _ := buildHarness(t, upstream)
	ts := httptest.NewServer(srv.Handler())
	defer ts.Close()

	resp, err := http.Get(ts.URL + "/api/audio/test.wav")
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		t.Fatalf("status %d", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(string(body), "RIFF") {
		// Bound the slice: a short or empty body must report, not panic.
		t.Errorf("audio body did not start with RIFF: %q", body[:min(len(body), 8)])
	}
}

func TestUnhealthyConsumerReturns503(t *testing.T) {
	// Upstream that answers every request (including the health probe)
	// with a 500.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "boom", http.StatusInternalServerError)
	}))
	defer upstream.Close()
	srv, reg, _ := buildHarness(t, upstream)
	probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")

	ts := httptest.NewServer(srv.Handler())
	defer ts.Close()

	resp, err := http.Post(ts.URL+"/v1/tts", "application/json", strings.NewReader(`{}`))
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 503 {
		t.Fatalf("status = %d, want 503", resp.StatusCode)
	}
	var env errorBody
	if err := json.NewDecoder(resp.Body).Decode(&env); err != nil {
		t.Fatal(err)
	}
	if env.Error != "consumer_unreachable" {
		t.Errorf("error code = %q", env.Error)
	}
	if !env.Retryable {
		t.Errorf("retryable should be true for unreachable consumer")
	}
}

func TestStatusReturnsConsumers(t *testing.T) {
	upstream, _ := fakeMVoice(t)
	defer upstream.Close()
	srv, reg, _ := buildHarness(t, upstream)
	probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")

	ts := httptest.NewServer(srv.Handler())
	defer ts.Close()

	resp, err := http.Get(ts.URL + "/v1/status")
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	if resp.StatusCode != 200 {
		t.Fatalf("status %d: %s", resp.StatusCode, body)
	}
	var sr statusResponse
	if err := json.Unmarshal(body, &sr); err != nil {
		t.Fatal(err)
	}
	if len(sr.Consumers) != 1 {
		t.Fatalf("consumers = %d", len(sr.Consumers))
	}
	if sr.Consumers[0].Name != "mvoice" {
		t.Errorf("name = %q", sr.Consumers[0].Name)
	}
	if !sr.Consumers[0].Healthy {
		t.Errorf("expected healthy=true after successful probe")
	}
	if sr.GPU.TotalMiB != 16376 {
		t.Errorf("GPU.TotalMiB = %d", sr.GPU.TotalMiB)
	}
}

// probeOnce drives the registry to record one successful (or failed) health
// probe synchronously so tests don't have to wait for the 5s loop.
//
// It GETs healthURL and records the outcome for consumer name. A transport
// error or an HTTP status >= 400 is recorded as a failure. Up to 8 KiB of
// the response body is passed along to the registry.
func probeOnce(t *testing.T, reg *registry.Registry, healthURL, name string) {
	t.Helper()
	req, err := http.NewRequestWithContext(context.Background(), "GET", healthURL, nil)
	if err != nil {
		t.Fatal(err)
	}
	client := &http.Client{Timeout: 2 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		reg.RecordProbeForTest(name, false, err.Error(), nil)
		return
	}
	defer resp.Body.Close()
	// Best effort: a short read still records the probe with what arrived.
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 8192))
	ok := resp.StatusCode < 400
	errMsg := ""
	if !ok {
		errMsg = "status " + strings.TrimSpace(resp.Status)
	}
	reg.RecordProbeForTest(name, ok, errMsg, body)
}