diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f07a089 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +# Build artifacts +/bin/ + +# Worker session noise +.m/ +*.log + +# Go test/coverage +*.out +coverage.html + +# Editor cruft +*.swp +.idea/ +.vscode/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..370aaa3 --- /dev/null +++ b/Makefile @@ -0,0 +1,35 @@ +# mGPUmanager build + deploy targets. +# +# `make build` — compile the Go binary into ./bin/mgpumanager. +# `make test` — go test ./... +# `make run` — run locally against ./config/consumers.yaml. +# `make deploy` — rsync binary + config + systemd unit to mRock, +# reload systemd, restart the service. + +BIN := bin/mgpumanager +PKG := ./cmd/mgpumanager + +GO ?= go +HOST ?= mrock +REMOTE_DIR ?= /home/m/dev/mGPUmanager + +.PHONY: build test run deploy clean + +build: + mkdir -p bin + $(GO) build -trimpath -ldflags="-s -w" -o $(BIN) $(PKG) + +test: + $(GO) test ./... + +run: build + ./$(BIN) --config config/consumers.yaml --log-level debug + +deploy: build + rsync -a --mkpath $(BIN) $(HOST):$(REMOTE_DIR)/$(BIN) + rsync -a --mkpath config/consumers.yaml $(HOST):$(REMOTE_DIR)/config/consumers.yaml + rsync -a --mkpath systemd/mgpumanager.service $(HOST):$(REMOTE_DIR)/systemd/mgpumanager.service + ssh $(HOST) "sudo cp $(REMOTE_DIR)/systemd/mgpumanager.service /etc/systemd/system/mgpumanager.service && sudo systemctl daemon-reload && sudo systemctl enable mgpumanager.service && sudo systemctl restart mgpumanager.service && sudo systemctl status mgpumanager.service --no-pager -l" + +clean: + rm -rf bin diff --git a/README.md b/README.md index 80fb502..461a390 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,73 @@ # mGPUmanager -GPU-Inference-Control-Plane für mRock — Scheduler vor TTS/STT/LLM/Image-Gen mit globalem GPU-Lock + LRU-Eviction + einheitlicher /v1-Fassade. Konsumenten: mVoice, whisper-server, Ollama, ComfyUI/FLUX, später Furbotto. Go. 
\ No newline at end of file +GPU-Inference-Control-Plane für mRock — Scheduler vor TTS/STT/LLM/Image-Gen mit globalem GPU-Lock + LRU-Eviction + einheitlicher `/v1`-Fassade. Konsumenten: mVoice, whisper-server, Ollama, ComfyUI/FLUX, später Furbotto. Go. + +Full design: [`docs/design.md`](docs/design.md) — Bestandsaufnahme, 10-Alternativen-Survey, Eviction-Algorithmus, Migrationspfad. + +## Was es macht + +Auf `mrock:8770` sitzt ein Go-Daemon, der: + +- `/v1/tts`, `/v1/stt`, `/v1/llm`, `/v1/image` als einheitliche Konsumenten-Fassade exponiert, +- jede Anfrage durch einen globalen GPU-Scheduler schleust (seriell, Queue), +- bei VRAM-Druck LRU-Eviction über die deklarierten Coexistenz-Gruppen aus `config/consumers.yaml` fährt, +- in `/v1/status` Live-GPU-Belegung + Consumer-Health + Scheduler-Statistiken zeigt, +- niemals stille Fallbacks zurückgibt — Fehler kommen als strukturiertes `{error,message,consumer,retryable}`. + +## Konsumenten-Registry + +`config/consumers.yaml` deklariert pro Consumer: + +- `url`, `health.{method,path}` für Liveness-Probing +- `paths.<kind>.{method,path}` — wie der Broker zu seinem TTS/STT/LLM/Image-Endpoint kommt +- `vram_resident_mib` — für die Scheduler-Mathe (Schritt 5) +- `unload.{method,path,body}` und optional `load.{method,path}` — wie der Broker den Consumer aus dem VRAM räumt / wieder hochfährt +- `can_coexist_with: [..]` — wer parallel resident sein darf +- `priority` (0=low, 4=urgent), `max_concurrency` + +## Build + Deploy + +```sh +make build # ./bin/mgpumanager +make test # go test ./... +make run # lokal gegen ./config/consumers.yaml +make deploy HOST=mrock # rsync + systemd reload + restart +``` + +Auf mRock läuft der Daemon als System-Unit (`/etc/systemd/system/mgpumanager.service`). 
+ +## Endpoints + +| Verb | Pfad | Verhalten | +|---|---|---| +| POST | `/v1/tts` | Proxy zu `routing.tts`-Consumer (default: mvoice `/api/synthesize`) | +| POST | `/v1/stt` | Proxy zu `routing.stt`-Consumer (default: mvoice `/api/transcribe`) | +| POST | `/v1/llm` | Proxy zu `routing.llm`-Consumer (default: ollama `/api/generate`) | +| POST | `/v1/image` | Proxy zu `routing.image`-Consumer (default: comfyui `/prompt`) | +| GET | `/audio/*` | Proxy zu `audio_proxy`-Consumer (wa.sh fetcht generiertes Audio so) | +| GET | `/v1/status`| Live-Snapshot: GPU + Consumer-Health + Scheduler-Stats | +| GET | `/healthz` | Broker-Liveness (200 OK) | + +## Fehler-Schema + +Jeder Broker-eigene Fehler hat die Form: + +```json +{ + "error": "consumer_unreachable", + "message": "upstream mvoice last probe failed: connection refused", + "consumer": "mvoice", + "retryable": true +} +``` + +Codes: `consumer_unreachable`, `no_consumer`, `scheduler_error`, `bad_consumer_url`, `bad_request`. Pass-through-4xx/5xx vom Consumer landet unverändert beim Client. + +## Phase 1 Status (Issue #1) + +- ✅ Schritt 0 — ComfyUI persistent (`systemd: comfyui.service`) +- ✅ Schritt 1 — `mvoice /api/admin/{load,unload}` (mai/knuth/admin-load-unload @ mVoice) +- ✅ Schritt 2 — Routing-Façade + `/v1/status` (passthrough scheduler) +- ☐ Schritt 3 — wa.sh auf Broker umgestellt +- ☐ Schritt 4 — Queue + globaler GPU-Lock +- ☐ Schritt 5 — Coexistenz-Gruppen + LRU-Eviction diff --git a/cmd/mgpumanager/main.go b/cmd/mgpumanager/main.go new file mode 100644 index 0000000..e5d3478 --- /dev/null +++ b/cmd/mgpumanager/main.go @@ -0,0 +1,103 @@ +// mgpumanager is the GPU-Inference-Control-Plane for mRock. +// +// One Go binary that: +// 1. Loads consumers.yaml. +// 2. Probes every consumer's /health on a 5s cadence. +// 3. Polls nvidia-smi every 2s for live VRAM usage (used by Schritt 5 +// eviction). +// 4. Exposes /v1/{tts,stt,llm,image} as a thin proxy + /v1/status for +// observability. +// 5. 
Funnels every job through the Scheduler (passthrough today, queue + +// eviction in Schritt 4-5). +// +// All client routing happens through this daemon — no consumer is reached +// directly any more. wa.sh, ImaGen, m-CLI and Furbotto-Voice will all speak +// to :8770/v1/*. +package main + +import ( + "context" + "errors" + "flag" + "log/slog" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "mgit.msbls.de/m/mGPUmanager/internal/config" + "mgit.msbls.de/m/mGPUmanager/internal/gpu" + "mgit.msbls.de/m/mGPUmanager/internal/registry" + "mgit.msbls.de/m/mGPUmanager/internal/scheduler" + "mgit.msbls.de/m/mGPUmanager/internal/server" +) + +func main() { + configPath := flag.String("config", "config/consumers.yaml", "path to consumers.yaml") + listenOverride := flag.String("listen", "", "override listen address from config") + logLevel := flag.String("log-level", "info", "log level: debug|info|warn|error") + flag.Parse() + + logger := newLogger(*logLevel) + + cfg, err := config.Load(*configPath) + if err != nil { + logger.Error("config load failed", "err", err, "path", *configPath) + os.Exit(1) + } + if *listenOverride != "" { + cfg.Listen = *listenOverride + } + + logger.Info("starting mGPUmanager", + "listen", cfg.Listen, + "consumers", len(cfg.Consumers), + "poll_interval", cfg.GPU.PollInterval(), + ) + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + reg := registry.New(cfg, logger.With("component", "registry")) + gpuPoller := gpu.NewPoller(cfg.GPU.PollInterval(), logger.With("component", "gpu")) + sched := scheduler.NewPassthrough(reg) + + go reg.Run(ctx) + go gpuPoller.Run(ctx) + + srv := server.New(cfg, reg, gpuPoller, sched, logger.With("component", "server")) + httpSrv := &http.Server{ + Addr: cfg.Listen, + Handler: srv.Handler(), + ReadHeaderTimeout: 10 * time.Second, + } + + go func() { + <-ctx.Done() + shutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer 
cancel() + _ = httpSrv.Shutdown(shutCtx) + }() + + if err := httpSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Error("listen failed", "err", err) + os.Exit(1) + } + logger.Info("shutdown complete") +} + +func newLogger(level string) *slog.Logger { + var lvl slog.Level + switch level { + case "debug": + lvl = slog.LevelDebug + case "warn": + lvl = slog.LevelWarn + case "error": + lvl = slog.LevelError + default: + lvl = slog.LevelInfo + } + return slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: lvl})) +} diff --git a/config/consumers.yaml b/config/consumers.yaml new file mode 100644 index 0000000..731851e --- /dev/null +++ b/config/consumers.yaml @@ -0,0 +1,88 @@ +listen: 127.0.0.1:8770 + +gpu: + total_mib: 16376 # RTX 4070 Ti SUPER + reserved_mib: 1024 # headroom for system/desktop + poll_interval_seconds: 2 + +routing: + tts: mvoice + stt: mvoice # whisper-server is alternative if explicitly requested + llm: ollama + image: comfyui + +# Audio download proxy: GET /audio/* forwards to this consumer. +audio_proxy: mvoice + +consumers: + mvoice: + url: http://localhost:8766 + health: + method: GET + path: /api/health + paths: + tts: + method: POST + path: /api/synthesize + stt: + method: POST + path: /api/transcribe + vram_resident_mib: 2800 + load: + method: POST + path: /api/admin/load + unload: + method: POST + path: /api/admin/unload + can_coexist_with: [whisper-server, ollama] + priority: 3 + max_concurrency: 1 + + whisper-server: + url: http://localhost:8178 + health: + method: GET + path: / + paths: + stt: + method: POST + path: /inference + vram_resident_mib: 2050 + # No HTTP unload; mGPUmanager evicts via systemd restart (Schritt 5). 
+ systemd_unit: whisper-server.service + can_coexist_with: [mvoice, ollama] + priority: 2 + max_concurrency: 1 + + ollama: + url: http://localhost:11434 + health: + method: GET + path: /api/tags + paths: + llm: + method: POST + path: /api/generate + # Ollama runs its own LRU keep_alive; we don't track resident VRAM. + vram_managed: true + can_coexist_with: [mvoice, whisper-server] + priority: 2 + max_concurrency: 1 + + comfyui: + url: http://localhost:8188 + health: + method: GET + path: /system_stats + paths: + image: + method: POST + path: /prompt + vram_resident_mib: 13000 + unload: + method: POST + path: /api/free + body: '{"unload_models":true,"free_memory":true}' + can_coexist_with: [] + priority: 1 + max_concurrency: 1 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6ab526a --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module mgit.msbls.de/m/mGPUmanager + +go 1.25.5 + +require gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a62c313 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..41c55cc --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,165 @@ +// Package config loads the mGPUmanager consumer registry from YAML. +// +// The consumers.yaml file declares every GPU consumer (mvoice, whisper-server, +// ollama, comfyui), how to route the four logical endpoint kinds (tts, stt, +// llm, image) to a consumer, how to probe its health, and how to load/unload +// it from VRAM. 
The scheduler (Schritt 4–5) reads vram_resident_mib + +// can_coexist_with to drive eviction. +package config + +import ( + "fmt" + "net/url" + "os" + "strings" + "time" + + "gopkg.in/yaml.v3" +) + +// EndpointKind enumerates the four logical broker endpoints exposed on /v1/*. +type EndpointKind string + +const ( + KindTTS EndpointKind = "tts" + KindSTT EndpointKind = "stt" + KindLLM EndpointKind = "llm" + KindImage EndpointKind = "image" +) + +// AllKinds is the canonical ordering used by /v1/status and tests. +var AllKinds = []EndpointKind{KindTTS, KindSTT, KindLLM, KindImage} + +// Route describes an HTTP method + path on a consumer. +type Route struct { + Method string `yaml:"method"` + Path string `yaml:"path"` + // Body is an optional fixed request body for admin operations + // (e.g. ComfyUI's /api/free expects {"unload_models":true,"free_memory":true}). + Body string `yaml:"body,omitempty"` +} + +// Consumer describes a single GPU consumer behind the broker. +type Consumer struct { + URL string `yaml:"url"` + Health Route `yaml:"health"` + Paths map[EndpointKind]Route `yaml:"paths"` + VRAMResidentMiB int `yaml:"vram_resident_mib"` + VRAMManaged bool `yaml:"vram_managed"` // self-managed LRU (ollama) + Load *Route `yaml:"load,omitempty"` + Unload *Route `yaml:"unload,omitempty"` + SystemdUnit string `yaml:"systemd_unit,omitempty"` // fallback unload (whisper-server) + CanCoexistWith []string `yaml:"can_coexist_with"` + Priority int `yaml:"priority"` + MaxConcurrency int `yaml:"max_concurrency"` +} + +// GPU describes the host's GPU envelope. +type GPU struct { + TotalMiB int `yaml:"total_mib"` + ReservedMiB int `yaml:"reserved_mib"` + PollIntervalSeconds int `yaml:"poll_interval_seconds"` +} + +// PollInterval returns the GPU polling cadence as a Duration. Defaults to 2s. 
+func (g GPU) PollInterval() time.Duration { + if g.PollIntervalSeconds <= 0 { + return 2 * time.Second + } + return time.Duration(g.PollIntervalSeconds) * time.Second +} + +// AvailableMiB returns total VRAM minus the system-reserved headroom. +func (g GPU) AvailableMiB() int { + if g.TotalMiB <= 0 { + return 0 + } + avail := g.TotalMiB - g.ReservedMiB + if avail < 0 { + return 0 + } + return avail +} + +// Config is the parsed mGPUmanager configuration. +type Config struct { + Listen string `yaml:"listen"` + GPU GPU `yaml:"gpu"` + Routing map[EndpointKind]string `yaml:"routing"` + AudioProxy string `yaml:"audio_proxy"` + Consumers map[string]*Consumer `yaml:"consumers"` +} + +// Load reads and validates a consumers.yaml file from disk. +func Load(path string) (*Config, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read %s: %w", path, err) + } + var cfg Config + if err := yaml.Unmarshal(b, &cfg); err != nil { + return nil, fmt.Errorf("parse %s: %w", path, err) + } + if err := cfg.validate(); err != nil { + return nil, fmt.Errorf("validate %s: %w", path, err) + } + return &cfg, nil +} + +func (c *Config) validate() error { + if c.Listen == "" { + c.Listen = "127.0.0.1:8770" + } + if len(c.Consumers) == 0 { + return fmt.Errorf("no consumers declared") + } + for name, cons := range c.Consumers { + if cons.URL == "" { + return fmt.Errorf("consumer %q: url is required", name) + } + if _, err := url.Parse(cons.URL); err != nil { + return fmt.Errorf("consumer %q: invalid url %q: %w", name, cons.URL, err) + } + if cons.Health.Path == "" { + return fmt.Errorf("consumer %q: health.path is required", name) + } + if cons.Health.Method == "" { + cons.Health.Method = "GET" + } + cons.Health.Method = strings.ToUpper(cons.Health.Method) + for kind, route := range cons.Paths { + if route.Path == "" { + return fmt.Errorf("consumer %q: paths.%s.path is required", name, kind) + } + if route.Method == "" { + route.Method = "POST" + } + route.Method 
= strings.ToUpper(route.Method) + cons.Paths[kind] = route + } + if cons.MaxConcurrency <= 0 { + cons.MaxConcurrency = 1 + } + } + for kind, consName := range c.Routing { + if _, ok := c.Consumers[consName]; !ok { + return fmt.Errorf("routing.%s: unknown consumer %q", kind, consName) + } + } + if c.AudioProxy != "" { + if _, ok := c.Consumers[c.AudioProxy]; !ok { + return fmt.Errorf("audio_proxy: unknown consumer %q", c.AudioProxy) + } + } + return nil +} + +// ConsumerForKind returns the consumer designated to handle a given endpoint +// kind, or nil if routing is unset. +func (c *Config) ConsumerForKind(kind EndpointKind) (string, *Consumer) { + name, ok := c.Routing[kind] + if !ok { + return "", nil + } + return name, c.Consumers[name] +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..7f28a4c --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,130 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadValidConfig(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "consumers.yaml") + body := ` +listen: 127.0.0.1:8770 +gpu: + total_mib: 16376 + reserved_mib: 1024 + poll_interval_seconds: 2 + +routing: + tts: mvoice + llm: ollama + +audio_proxy: mvoice + +consumers: + mvoice: + url: http://localhost:8766 + health: + method: GET + path: /api/health + paths: + tts: + method: POST + path: /api/synthesize + vram_resident_mib: 2800 + unload: + method: POST + path: /api/admin/unload + load: + method: POST + path: /api/admin/load + can_coexist_with: [ollama] + priority: 3 + max_concurrency: 1 + ollama: + url: http://localhost:11434 + health: + method: GET + path: /api/tags + paths: + llm: + method: POST + path: /api/generate + vram_managed: true + can_coexist_with: [mvoice] + priority: 2 +` + if err := os.WriteFile(path, []byte(body), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + 
if cfg.Listen != "127.0.0.1:8770" { + t.Errorf("Listen = %q", cfg.Listen) + } + if cfg.GPU.AvailableMiB() != 15352 { + t.Errorf("AvailableMiB = %d, want 15352", cfg.GPU.AvailableMiB()) + } + if cfg.GPU.PollInterval().Seconds() != 2 { + t.Errorf("PollInterval = %s", cfg.GPU.PollInterval()) + } + + name, cons := cfg.ConsumerForKind(KindTTS) + if name != "mvoice" || cons == nil { + t.Fatalf("ConsumerForKind(tts) = %q, %v", name, cons) + } + if cons.Paths[KindTTS].Method != "POST" { + t.Errorf("default method not preserved") + } + if cons.MaxConcurrency != 1 { + t.Errorf("MaxConcurrency = %d", cons.MaxConcurrency) + } + + if _, ok := cfg.Consumers["ollama"]; !ok { + t.Fatal("ollama not loaded") + } + if cfg.Consumers["ollama"].MaxConcurrency != 1 { + t.Errorf("ollama MaxConcurrency default = %d, want 1", + cfg.Consumers["ollama"].MaxConcurrency) + } +} + +func TestLoadRejectsUnknownRouting(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "consumers.yaml") + body := ` +routing: + tts: nonexistent +consumers: + mvoice: + url: http://localhost:8766 + health: { path: /api/health } + paths: {} +` + if err := os.WriteFile(path, []byte(body), 0o644); err != nil { + t.Fatal(err) + } + if _, err := Load(path); err == nil { + t.Fatal("expected error for unknown routing target, got nil") + } +} + +func TestLoadRejectsMissingURL(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "consumers.yaml") + body := ` +consumers: + broken: + health: { path: /h } +` + if err := os.WriteFile(path, []byte(body), 0o644); err != nil { + t.Fatal(err) + } + if _, err := Load(path); err == nil { + t.Fatal("expected error for missing URL, got nil") + } +} diff --git a/internal/gpu/gpu.go b/internal/gpu/gpu.go new file mode 100644 index 0000000..f8b77ae --- /dev/null +++ b/internal/gpu/gpu.go @@ -0,0 +1,117 @@ +// Package gpu polls nvidia-smi for live VRAM usage. +// +// Schritt 5 uses this to detect VRAM pressure and trigger LRU eviction. 
// Sample is one nvidia-smi reading.
type Sample struct {
	UsedMiB  int
	FreeMiB  int
	TotalMiB int
	At       time.Time // when the reading (or the failed attempt) was recorded
	Err      string    // non-empty when sampling failed; the MiB fields are then zero
}

// Poller periodically samples GPU memory and exposes the latest reading.
type Poller struct {
	interval time.Duration
	logger   *slog.Logger
	mu       sync.RWMutex // guards last
	last     Sample
}

// NewPoller builds a Poller with the desired sampling cadence. A non-positive
// interval falls back to 2s. logger may be nil, in which case nothing is
// logged.
func NewPoller(interval time.Duration, logger *slog.Logger) *Poller {
	if interval <= 0 {
		interval = 2 * time.Second
	}
	return &Poller{interval: interval, logger: logger}
}

// Run takes one sample immediately and then one per interval until ctx is
// cancelled.
func (p *Poller) Run(ctx context.Context) {
	p.sampleOnce(ctx)
	t := time.NewTicker(p.interval)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			p.sampleOnce(ctx)
		}
	}
}

// Last returns the most recent sample (the zero Sample before the first one).
func (p *Poller) Last() Sample {
	p.mu.RLock()
	defer p.mu.RUnlock()
	return p.last
}

// sampleOnce runs nvidia-smi with a 2s timeout and stores the outcome. Both
// failure modes — the command failing and its output not parsing — store a
// Sample carrying only At and Err and are logged at debug level.
func (p *Poller) sampleOnce(ctx context.Context) {
	cctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	// memory.used,memory.free,memory.total in MiB, no units, no header.
	cmd := exec.CommandContext(cctx, "nvidia-smi",
		"--query-gpu=memory.used,memory.free,memory.total",
		"--format=csv,noheader,nounits")
	out, err := cmd.Output()
	now := time.Now()
	if err != nil {
		p.store(Sample{At: now, Err: err.Error()})
		if p.logger != nil {
			p.logger.Debug("nvidia-smi failed", "err", err)
		}
		return
	}
	used, free, total, perr := parseSMI(string(out))
	if perr != "" {
		p.store(Sample{At: now, Err: perr})
		if p.logger != nil {
			p.logger.Debug("nvidia-smi output not understood", "err", perr)
		}
		return
	}
	p.store(Sample{UsedMiB: used, FreeMiB: free, TotalMiB: total, At: now})
}

// store publishes s as the latest sample.
func (p *Poller) store(s Sample) {
	p.mu.Lock()
	p.last = s
	p.mu.Unlock()
}

// parseSMI extracts "used, free, total" (MiB) from nvidia-smi CSV output.
// Only the first non-blank line is considered — multi-GPU hosts would yield
// more, but only single-GPU (mRock) is supported for Phase 1. On failure the
// three values are zero and errMsg describes the problem.
func parseSMI(out string) (used, free, total int, errMsg string) {
	// Trimming the whole output first drops leading blank lines, so the text
	// up to the next newline is the first non-blank line.
	first, _, _ := strings.Cut(strings.TrimSpace(out), "\n")
	line := strings.TrimSpace(first)
	if line == "" {
		return 0, 0, 0, "empty nvidia-smi output"
	}
	parts := strings.Split(line, ",")
	if len(parts) != 3 {
		return 0, 0, 0, "unexpected nvidia-smi output: " + line
	}
	u, e1 := strconv.Atoi(strings.TrimSpace(parts[0]))
	f, e2 := strconv.Atoi(strings.TrimSpace(parts[1]))
	t, e3 := strconv.Atoi(strings.TrimSpace(parts[2]))
	if e1 != nil || e2 != nil || e3 != nil {
		return 0, 0, 0, "non-integer nvidia-smi output: " + line
	}
	return u, f, t, ""
}
// parseGPUFields best-effort parses a consumer's health response body to pull
// out gpu_resident_mib (mvoice exposes this since the Schritt 1 patch) and a
// 'loaded' boolean.
//
// Results:
//   - body is not valid JSON for this shape (e.g. whisper-server '/' returns
//     HTML): (0, true) — a 200 OK is treated as healthy + loaded, no VRAM info.
//   - gpu_resident_mib absent: mib is 0.
//   - loaded absent: loaded is true, i.e. a consumer that does not report the
//     field is assumed to be always loaded.
//
// prevLoaded is accepted for call-site compatibility but does not influence
// the result: a plain bool cannot tell "never probed" apart from "previously
// reported unloaded", so an absent field always yields true.
func parseGPUFields(body []byte, prevLoaded bool) (mib int, loaded bool) {
	var parsed struct {
		GPUResidentMiB *int  `json:"gpu_resident_mib"`
		Loaded         *bool `json:"loaded"`
	}
	if err := json.Unmarshal(body, &parsed); err != nil {
		return 0, true
	}
	if parsed.GPUResidentMiB != nil {
		mib = *parsed.GPUResidentMiB
	}
	// Pointers distinguish "field missing" from an explicit false.
	loaded = true
	if parsed.Loaded != nil {
		loaded = *parsed.Loaded
	}
	return mib, loaded
}

// State is a snapshot of a single consumer's live status.
type State struct {
	Name           string
	Healthy        bool
	LastProbe      time.Time
	LastError      string
	GPUResidentMiB int       // populated from consumer health response when present
	Loaded         bool      // mvoice reports this; others default to true
	Active         int       // in-flight job count (Schritt 4)
	LastUsed       time.Time // set when a job's release func runs (Schritt 5)
	TotalRequests  int64
}
+type Registry struct { + cfg *config.Config + client *http.Client + logger *slog.Logger + + mu sync.RWMutex + states map[string]*State +} + +// New builds a Registry from the loaded config. +func New(cfg *config.Config, logger *slog.Logger) *Registry { + r := &Registry{ + cfg: cfg, + client: &http.Client{Timeout: 5 * time.Second}, + logger: logger, + states: make(map[string]*State, len(cfg.Consumers)), + } + for name := range cfg.Consumers { + r.states[name] = &State{Name: name} + } + return r +} + +// Run starts the background health-probe loop and blocks until ctx is done. +// Cadence is fixed at 5s for health (independent of GPU polling cadence). +func (r *Registry) Run(ctx context.Context) { + r.probeAll(ctx) + t := time.NewTicker(5 * time.Second) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + r.probeAll(ctx) + } + } +} + +func (r *Registry) probeAll(ctx context.Context) { + var wg sync.WaitGroup + for name, cons := range r.cfg.Consumers { + wg.Add(1) + go func(name string, cons *config.Consumer) { + defer wg.Done() + r.probeOne(ctx, name, cons) + }(name, cons) + } + wg.Wait() +} + +func (r *Registry) probeOne(ctx context.Context, name string, cons *config.Consumer) { + cctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(cctx, cons.Health.Method, cons.URL+cons.Health.Path, nil) + if err != nil { + r.recordProbe(name, false, err.Error(), nil) + return + } + resp, err := r.client.Do(req) + if err != nil { + r.recordProbe(name, false, err.Error(), nil) + return + } + defer resp.Body.Close() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 8192)) + if resp.StatusCode >= 400 { + r.recordProbe(name, false, fmt.Sprintf("status %d", resp.StatusCode), nil) + return + } + r.recordProbe(name, true, "", body) +} + +// recordProbe stores the outcome of one health check, optionally parsing +// gpu_resident_mib / loaded fields out of the response body. 
+func (r *Registry) recordProbe(name string, ok bool, errMsg string, body []byte) { + r.mu.Lock() + defer r.mu.Unlock() + s := r.states[name] + if s == nil { + return + } + s.LastProbe = time.Now() + s.Healthy = ok + s.LastError = errMsg + if ok && body != nil { + s.GPUResidentMiB, s.Loaded = parseGPUFields(body, s.Loaded) + } + if !ok && r.logger != nil { + r.logger.Debug("consumer probe failed", "consumer", name, "err", errMsg) + } +} + +// RecordProbeForTest exposes the internal probe-recording path to tests +// in other packages without depending on the live 5s probe loop. +func (r *Registry) RecordProbeForTest(name string, ok bool, errMsg string, body []byte) { + r.recordProbe(name, ok, errMsg, body) +} + +// Snapshot returns a copy of all consumer states, ordered by config-declared +// consumer name set (Go map iteration order is randomized — callers that need +// stable ordering should sort). +func (r *Registry) Snapshot() map[string]State { + r.mu.RLock() + defer r.mu.RUnlock() + out := make(map[string]State, len(r.states)) + for k, v := range r.states { + out[k] = *v + } + return out +} + +// Get returns a single consumer state (copy) or zero-value if unknown. +func (r *Registry) Get(name string) State { + r.mu.RLock() + defer r.mu.RUnlock() + if s, ok := r.states[name]; ok { + return *s + } + return State{} +} + +// MarkActive increments the in-flight count and updates LastUsed. +// Returns a release func to call on job completion. 
+func (r *Registry) MarkActive(name string) func() { + r.mu.Lock() + if s, ok := r.states[name]; ok { + s.Active++ + s.TotalRequests++ + } + r.mu.Unlock() + return func() { + r.mu.Lock() + if s, ok := r.states[name]; ok { + if s.Active > 0 { + s.Active-- + } + s.LastUsed = time.Now() + } + r.mu.Unlock() + } +} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go new file mode 100644 index 0000000..396f29c --- /dev/null +++ b/internal/scheduler/scheduler.go @@ -0,0 +1,112 @@ +// Package scheduler controls who gets the GPU when. +// +// Three responsibilities, added in three phases: +// +// - Schritt 2 (this file's first version): a passthrough — every job runs +// immediately, no locking, no queueing. Only useful for proving the HTTP +// façade end-to-end. +// - Schritt 4: a global mutex (or capacity-1 channel) serialises all GPU +// work. Per-consumer max_concurrency limits stay at 1 for now. +// - Schritt 5: VRAM-pressure-aware eviction kicks in before acquire when the +// requested consumer's resident cost would exceed available headroom. +// +// The interface deliberately hides which phase is active from callers +// (server.go) so the upgrade path is local to this package. +package scheduler + +import ( + "context" + "errors" + "sync" + "time" + + "mgit.msbls.de/m/mGPUmanager/internal/config" + "mgit.msbls.de/m/mGPUmanager/internal/registry" +) + +// ErrSchedulerStopped is returned if Run is called after Close. +var ErrSchedulerStopped = errors.New("scheduler stopped") + +// Job is what a consumer route worker executes while holding the GPU lock. +type Job func(ctx context.Context) error + +// Scheduler decides when GPU work runs. Implementations may queue, serialise, +// or evict other consumers before granting access. +type Scheduler interface { + // Run executes fn while the caller holds the right to use the GPU for + // the named consumer. It blocks until fn returns or ctx is cancelled. 
+ Run(ctx context.Context, consumer string, fn Job) error + + // Stats returns a snapshot of scheduler internals for /v1/status. + Stats() Stats +} + +// Stats is what /v1/status reports about the scheduler. +type Stats struct { + QueueDepth int `json:"queue_depth"` + InFlight int `json:"in_flight"` + TotalJobs int64 `json:"total_jobs"` + LastWaitMS int64 `json:"last_wait_ms"` + LastRunMS int64 `json:"last_run_ms"` + Evictions int64 `json:"evictions"` + OldestQueued time.Time `json:"oldest_queued,omitzero"` +} + +// Passthrough is the Schritt 2 stand-in: no lock, no queue. Every job runs +// concurrently. It exists so the server package can be written against the +// final interface from day one. +type Passthrough struct { + reg *registry.Registry + + mu sync.Mutex + inFlight int + total int64 + lastRunMS int64 +} + +// NewPassthrough returns a Scheduler that runs every job immediately. +func NewPassthrough(reg *registry.Registry) *Passthrough { + return &Passthrough{reg: reg} +} + +// Run executes fn straight away, only tracking in-flight count for stats. +func (p *Passthrough) Run(ctx context.Context, consumer string, fn Job) error { + release := p.reg.MarkActive(consumer) + defer release() + + p.mu.Lock() + p.inFlight++ + p.total++ + p.mu.Unlock() + defer func() { + p.mu.Lock() + p.inFlight-- + p.mu.Unlock() + }() + + start := time.Now() + err := fn(ctx) + elapsed := time.Since(start).Milliseconds() + p.mu.Lock() + p.lastRunMS = elapsed + p.mu.Unlock() + return err +} + +// Stats returns current passthrough statistics. +func (p *Passthrough) Stats() Stats { + p.mu.Lock() + defer p.mu.Unlock() + return Stats{ + InFlight: p.inFlight, + TotalJobs: p.total, + LastRunMS: p.lastRunMS, + } +} + +// Compile-time interface guard. +var _ Scheduler = (*Passthrough)(nil) + +// Ensure config package is imported (used by later Schritte that read +// per-consumer max_concurrency and vram_resident_mib). 
+var _ = config.KindTTS diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..08e4d5c --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,377 @@ +// Package server is the HTTP façade of mGPUmanager. +// +// It exposes: +// - POST /v1/tts, /v1/stt, /v1/llm, /v1/image — pass-through proxy to the +// consumer named in config.Routing[kind]. +// - GET /audio/* — proxy to config.AudioProxy (mvoice's audio directory). +// - GET /v1/status — live snapshot of consumers + GPU + scheduler. +// - GET /healthz — broker liveness (200 if process is up). +// +// Every proxy call goes through the Scheduler so that, in Schritt 4 and 5, +// queueing and eviction can be added without touching server.go. +package server + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "mgit.msbls.de/m/mGPUmanager/internal/config" + "mgit.msbls.de/m/mGPUmanager/internal/gpu" + "mgit.msbls.de/m/mGPUmanager/internal/registry" + "mgit.msbls.de/m/mGPUmanager/internal/scheduler" +) + +// Server bundles the HTTP handlers + dependencies. +type Server struct { + cfg *config.Config + reg *registry.Registry + gpu *gpu.Poller + sched scheduler.Scheduler + client *http.Client + logger *slog.Logger +} + +// New builds a Server. Caller owns the lifecycle of reg/gpu/sched. +func New(cfg *config.Config, reg *registry.Registry, gpuPoller *gpu.Poller, sched scheduler.Scheduler, logger *slog.Logger) *Server { + return &Server{ + cfg: cfg, + reg: reg, + gpu: gpuPoller, + sched: sched, + client: &http.Client{Timeout: 120 * time.Second}, // TTS can take 5-10s; image gen up to 60s + logger: logger, + } +} + +// Handler returns the root mux. Caller wraps it in http.Server. 
+func (s *Server) Handler() http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("POST /v1/tts", s.handleEndpoint(config.KindTTS)) + mux.HandleFunc("POST /v1/stt", s.handleEndpoint(config.KindSTT)) + mux.HandleFunc("POST /v1/llm", s.handleEndpoint(config.KindLLM)) + mux.HandleFunc("POST /v1/image", s.handleEndpoint(config.KindImage)) + + mux.HandleFunc("GET /audio/", s.handleAudio) + mux.HandleFunc("GET /v1/status", s.handleStatus) + mux.HandleFunc("GET /healthz", s.handleHealthz) + mux.HandleFunc("GET /", s.handleRoot) + + return logMiddleware(s.logger, mux) +} + +// ───── error envelope ───────────────────────────────────────────────────── + +// errorBody is the broker's structured error envelope. Every non-2xx response +// from mGPUmanager itself uses this shape. (Pass-through 4xx/5xx from +// consumers are forwarded verbatim so callers see the original payload.) +type errorBody struct { + Error string `json:"error"` + Message string `json:"message"` + Consumer string `json:"consumer,omitempty"` + Retryable bool `json:"retryable"` +} + +func writeErr(w http.ResponseWriter, status int, code, msg, consumer string, retryable bool) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(errorBody{ + Error: code, + Message: msg, + Consumer: consumer, + Retryable: retryable, + }) +} + +// ───── endpoint proxy ───────────────────────────────────────────────────── + +// handleEndpoint returns the http.HandlerFunc for a /v1/ endpoint. 
+func (s *Server) handleEndpoint(kind config.EndpointKind) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		// Resolve kind -> consumer -> upstream route; both lookups fail with
+		// 501 because the broker is simply not configured for this kind.
+		consName, cons := s.cfg.ConsumerForKind(kind)
+		if cons == nil {
+			writeErr(w, http.StatusNotImplemented, "no_consumer",
+				fmt.Sprintf("no consumer routes %s", kind), "", false)
+			return
+		}
+		route, ok := cons.Paths[kind]
+		if !ok {
+			writeErr(w, http.StatusNotImplemented, "no_consumer",
+				fmt.Sprintf("consumer %s lacks paths.%s", consName, kind), consName, false)
+			return
+		}
+
+		// Refuse fast if the consumer is unhealthy (last probe failed), so the
+		// caller never gets a silent fallback. A consumer that has not been
+		// probed yet (zero LastProbe) is given the benefit of the doubt.
+		st := s.reg.Get(consName)
+		if !st.Healthy && !st.LastProbe.IsZero() {
+			writeErr(w, http.StatusServiceUnavailable, "consumer_unreachable",
+				fmt.Sprintf("consumer %s last probe failed: %s", consName, st.LastError),
+				consName, true)
+			return
+		}
+
+		// proxyRequest writes its own response (success or error envelope) and
+		// returns nil, so an error here comes from the scheduler itself.
+		err := s.sched.Run(r.Context(), consName, func(ctx context.Context) error {
+			return s.proxyRequest(ctx, w, r, cons, route, consName)
+		})
+		if err != nil && !responseStarted(w) {
+			writeErr(w, http.StatusInternalServerError, "scheduler_error",
+				err.Error(), consName, true)
+		}
+	}
+}
+
+// proxyRequest forwards the inbound HTTP request to a consumer route and
+// streams the response back. Errors before the consumer responds are surfaced
+// as the broker's structured error envelope; once the consumer has begun
+// responding we stream its bytes through unchanged.
+//
+// The upstream URL is cons.URL with its path replaced by route.Path and the
+// inbound query string appended. route.Method overrides the inbound method
+// when set. The function always returns nil because every failure has already
+// been written to w.
+func (s *Server) proxyRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, cons *config.Consumer, route config.Route, consumer string) error {
+	target, err := url.Parse(cons.URL)
+	if err != nil {
+		writeErr(w, http.StatusInternalServerError, "bad_consumer_url",
+			err.Error(), consumer, false)
+		return nil
+	}
+	target.Path = route.Path
+	// Forward inbound query string verbatim.
+	target.RawQuery = r.URL.RawQuery
+
+	method := route.Method
+	if method == "" {
+		method = r.Method
+	}
+
+	upstream, err := http.NewRequestWithContext(ctx, method, target.String(), r.Body)
+	if err != nil {
+		writeErr(w, http.StatusInternalServerError, "bad_request",
+			err.Error(), consumer, false)
+		return nil
+	}
+	// net/http never sends a Content-Length taken from the Header map; the
+	// length must be set on the request itself. Without this, a body of
+	// unknown concrete type is always sent with chunked transfer encoding.
+	// A negative inbound length (unknown) keeps the chunked behaviour.
+	upstream.ContentLength = r.ContentLength
+	// Copy through the content negotiation headers (don't carry Host).
+	for _, h := range []string{"Content-Type", "Accept", "Accept-Encoding"} {
+		if v := r.Header.Get(h); v != "" {
+			upstream.Header.Set(h, v)
+		}
+	}
+
+	resp, err := s.client.Do(upstream)
+	if err != nil {
+		writeErr(w, http.StatusBadGateway, "consumer_unreachable",
+			fmt.Sprintf("upstream %s: %v", target.Host, err), consumer, true)
+		return nil
+	}
+	defer resp.Body.Close()
+
+	// Mirror the upstream headers, minus the hop-by-hop ones that describe
+	// the upstream connection rather than the payload.
+	for k, vs := range resp.Header {
+		if strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Transfer-Encoding") {
+			continue
+		}
+		for _, v := range vs {
+			w.Header().Add(k, v)
+		}
+	}
+	w.WriteHeader(resp.StatusCode)
+	streamBody(w, resp.Body)
+	return nil
+}
+
+// streamBody copies body to w and flushes after every chunk, so responses the
+// consumer produces incrementally reach the caller as they arrive instead of
+// waiting in the server's write buffer. Flushing is best effort: a writer
+// that does not support it still receives all bytes. Copying stops at the
+// first read or write error; the status line is already sent at that point,
+// so the error cannot be reported to the caller.
+func streamBody(w http.ResponseWriter, body io.Reader) {
+	rc := http.NewResponseController(w)
+	buf := make([]byte, 32*1024)
+	for {
+		n, readErr := body.Read(buf)
+		if n > 0 {
+			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
+				return
+			}
+			_ = rc.Flush() // unsupported writers return an error; ignore it
+		}
+		if readErr != nil {
+			return
+		}
+	}
+}
+
+// ───── audio proxy ────────────────────────────────────────────────────────
+
+// handleAudio forwards GET /audio/<file> to the audio_proxy consumer (mvoice).
+// wa.sh fetches the rendered .wav via this path after /v1/tts returns its URL.
+func (s *Server) handleAudio(w http.ResponseWriter, r *http.Request) { + if s.cfg.AudioProxy == "" { + writeErr(w, http.StatusNotFound, "no_audio_proxy", + "audio_proxy is not configured", "", false) + return + } + cons, ok := s.cfg.Consumers[s.cfg.AudioProxy] + if !ok { + writeErr(w, http.StatusInternalServerError, "no_audio_proxy", + "audio_proxy points at unknown consumer", s.cfg.AudioProxy, false) + return + } + target, err := url.Parse(cons.URL) + if err != nil { + writeErr(w, http.StatusInternalServerError, "bad_consumer_url", + err.Error(), s.cfg.AudioProxy, false) + return + } + target.Path = r.URL.Path + target.RawQuery = r.URL.RawQuery + + upstream, err := http.NewRequestWithContext(r.Context(), http.MethodGet, target.String(), nil) + if err != nil { + writeErr(w, http.StatusInternalServerError, "bad_request", + err.Error(), s.cfg.AudioProxy, false) + return + } + resp, err := s.client.Do(upstream) + if err != nil { + writeErr(w, http.StatusBadGateway, "consumer_unreachable", + fmt.Sprintf("upstream %s: %v", target.Host, err), s.cfg.AudioProxy, true) + return + } + defer resp.Body.Close() + + for k, vs := range resp.Header { + for _, v := range vs { + w.Header().Add(k, v) + } + } + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) +} + +// ───── status ───────────────────────────────────────────────────────────── + +type statusResponse struct { + Listen string `json:"listen"` + Time time.Time `json:"time"` + GPU statusGPU `json:"gpu"` + Routing map[config.EndpointKind]string `json:"routing"` + Consumers []statusConsumer `json:"consumers"` + Scheduler scheduler.Stats `json:"scheduler"` +} + +type statusGPU struct { + TotalMiB int `json:"total_mib"` + UsedMiB int `json:"used_mib"` + FreeMiB int `json:"free_mib"` + ReservedMiB int `json:"reserved_mib"` + LastSample time.Time `json:"last_sample"` + Err string `json:"err,omitempty"` +} + +type statusConsumer struct { + Name string `json:"name"` + URL string `json:"url"` + Healthy bool 
`json:"healthy"` + Loaded bool `json:"loaded"` + GPUResidentMiB int `json:"gpu_resident_mib"` + VRAMBudgetMiB int `json:"vram_budget_mib"` + Active int `json:"active"` + TotalRequests int64 `json:"total_requests"` + LastUsed time.Time `json:"last_used,omitzero"` + LastProbe time.Time `json:"last_probe,omitzero"` + LastError string `json:"last_error,omitempty"` + Priority int `json:"priority"` + CanCoexistWith []string `json:"can_coexist_with"` +} + +func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { + sample := s.gpu.Last() + snap := s.reg.Snapshot() + + resp := statusResponse{ + Listen: s.cfg.Listen, + Time: time.Now(), + Routing: s.cfg.Routing, + GPU: statusGPU{ + TotalMiB: s.cfg.GPU.TotalMiB, + UsedMiB: sample.UsedMiB, + FreeMiB: sample.FreeMiB, + ReservedMiB: s.cfg.GPU.ReservedMiB, + LastSample: sample.At, + Err: sample.Err, + }, + Scheduler: s.sched.Stats(), + } + if resp.GPU.TotalMiB == 0 && sample.TotalMiB > 0 { + resp.GPU.TotalMiB = sample.TotalMiB + } + + // Stable ordering by config-declared name. 
+ names := make([]string, 0, len(s.cfg.Consumers)) + for n := range s.cfg.Consumers { + names = append(names, n) + } + sortStrings(names) + for _, n := range names { + cons := s.cfg.Consumers[n] + st := snap[n] + resp.Consumers = append(resp.Consumers, statusConsumer{ + Name: n, + URL: cons.URL, + Healthy: st.Healthy, + Loaded: st.Loaded, + GPUResidentMiB: st.GPUResidentMiB, + VRAMBudgetMiB: cons.VRAMResidentMiB, + Active: st.Active, + TotalRequests: st.TotalRequests, + LastUsed: st.LastUsed, + LastProbe: st.LastProbe, + LastError: st.LastError, + Priority: cons.Priority, + CanCoexistWith: cons.CanCoexistWith, + }) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) +} + +func (s *Server) handleHealthz(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"ok"}`)) +} + +func (s *Server) handleRoot(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + _, _ = io.Copy(w, bytes.NewReader([]byte( + "mGPUmanager — see GET /v1/status for live state, POST /v1/{tts,stt,llm,image} for inference\n", + ))) +} + +// ───── helpers ──────────────────────────────────────────────────────────── + +// responseStarted is a coarse heuristic: once we've written headers, we can't +// switch to the error envelope. The proxy path writes headers only inside +// proxyRequest, which catches its own errors before that point. +func responseStarted(_ http.ResponseWriter) bool { return false } + +// sortStrings: avoid pulling in "sort" everywhere this file uses ordering. +func sortStrings(s []string) { + for i := 1; i < len(s); i++ { + for j := i; j > 0 && s[j-1] > s[j]; j-- { + s[j-1], s[j] = s[j], s[j-1] + } + } +} + +// logMiddleware emits one structured request log per call. 
+func logMiddleware(logger *slog.Logger, next http.Handler) http.Handler {
+	// Without a logger there is nothing to emit, so skip the wrapper.
+	if logger == nil {
+		return next
+	}
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		start := time.Now()
+		// Default to 200: a handler that writes a body without calling
+		// WriteHeader gets an implicit 200 from net/http.
+		lw := &statusCapture{ResponseWriter: w, code: 200}
+		next.ServeHTTP(lw, r)
+		logger.Info("http",
+			"method", r.Method,
+			"path", r.URL.Path,
+			"status", lw.code,
+			"ms", time.Since(start).Milliseconds(),
+		)
+	})
+}
+
+// statusCapture wraps a ResponseWriter and remembers the status code passed
+// to WriteHeader so logMiddleware can log it.
+type statusCapture struct {
+	http.ResponseWriter
+	code int
+}
+
+// WriteHeader records code and forwards it to the wrapped writer.
+func (s *statusCapture) WriteHeader(code int) {
+	s.code = code
+	s.ResponseWriter.WriteHeader(code)
+}
+
+// Unwrap returns the wrapped writer. Embedding only promotes the three
+// http.ResponseWriter methods, so without Unwrap the wrapper hides optional
+// capabilities of the real writer (flushing, deadlines) from every handler
+// behind the middleware. http.ResponseController follows Unwrap to reach
+// them.
+func (s *statusCapture) Unwrap() http.ResponseWriter {
+	return s.ResponseWriter
+}
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
new file mode 100644
index 0000000..4691bc1
--- /dev/null
+++ b/internal/server/server_test.go
@@ -0,0 +1,236 @@
+package server
+
+import (
+	"context"
+	"encoding/json"
+	"io"
+	"log/slog"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strings"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"mgit.msbls.de/m/mGPUmanager/internal/config"
+	"mgit.msbls.de/m/mGPUmanager/internal/gpu"
+	"mgit.msbls.de/m/mGPUmanager/internal/registry"
+	"mgit.msbls.de/m/mGPUmanager/internal/scheduler"
+)
+
+// silentLogger discards everything; keeps test output clean.
+func silentLogger() *slog.Logger {
+	return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
+}
+
+// fakeMVoice spins up a httptest.Server that mimics the relevant mvoice
+// endpoints just enough to exercise the broker's proxy behaviour.
+func fakeMVoice(t *testing.T) (*httptest.Server, *atomic.Int32) {
+	t.Helper()
+	// calls counts synthesize requests so tests can assert the broker
+	// forwarded exactly once.
+	calls := &atomic.Int32{}
+	mux := http.NewServeMux()
+	mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{"status":"ready","loaded":true,"gpu_resident_mib":2800}`))
+	})
+	mux.HandleFunc("POST /api/synthesize", func(w http.ResponseWriter, r *http.Request) {
+		calls.Add(1)
+		// Echo the received body back so the test can prove it arrived.
+		body, _ := io.ReadAll(r.Body)
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{"audio_url":"/audio/test.wav","payload_echo":"` +
+			strings.ReplaceAll(string(body), `"`, `\"`) + `"}`))
+	})
+	mux.HandleFunc("GET /audio/test.wav", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "audio/wav")
+		_, _ = w.Write([]byte("RIFF....fake-wav-bytes"))
+	})
+	return httptest.NewServer(mux), calls
+}
+
+// buildHarness wires a Server around a single "mvoice" consumer that points
+// at upstream, using the passthrough scheduler and silent loggers. The
+// registry and poller are returned so tests can drive them directly.
+func buildHarness(t *testing.T, upstream *httptest.Server) (*Server, *registry.Registry, *gpu.Poller) {
+	t.Helper()
+	cfg := &config.Config{
+		Listen: "127.0.0.1:0",
+		GPU:    config.GPU{TotalMiB: 16376, ReservedMiB: 1024, PollIntervalSeconds: 2},
+		Routing: map[config.EndpointKind]string{
+			config.KindTTS: "mvoice",
+		},
+		AudioProxy: "mvoice",
+		Consumers: map[string]*config.Consumer{
+			"mvoice": {
+				URL:    upstream.URL,
+				Health: config.Route{Method: "GET", Path: "/api/health"},
+				Paths: map[config.EndpointKind]config.Route{
+					config.KindTTS: {Method: "POST", Path: "/api/synthesize"},
+				},
+				VRAMResidentMiB: 2800,
+				MaxConcurrency:  1,
+				Priority:        3,
+			},
+		},
+	}
+	reg := registry.New(cfg, silentLogger())
+	poller := gpu.NewPoller(time.Second, silentLogger())
+	sched := scheduler.NewPassthrough(reg)
+	srv := New(cfg, reg, poller, sched, silentLogger())
+	return srv, reg, poller
+}
+
+// TestProxyTTSForwardsBodyAndReturnsConsumerResponse checks the happy path of
+// POST /v1/tts: the body reaches the consumer once and the consumer's JSON
+// comes back unchanged.
+func TestProxyTTSForwardsBodyAndReturnsConsumerResponse(t *testing.T) {
+	upstream, calls := fakeMVoice(t)
+	defer upstream.Close()
+
+	srv, reg, _ := buildHarness(t, upstream)
+	// Force-mark mvoice healthy so the gating in handleEndpoint passes.
+	probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")
+
+	ts := httptest.NewServer(srv.Handler())
+	defer ts.Close()
+
+	form := url.Values{}
+	form.Set("text", "Hallo")
+	resp, err := http.Post(ts.URL+"/v1/tts", "application/x-www-form-urlencoded",
+		strings.NewReader(form.Encode()))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 200 {
+		body, _ := io.ReadAll(resp.Body)
+		t.Fatalf("status %d: %s", resp.StatusCode, body)
+	}
+	if got := calls.Load(); got != 1 {
+		t.Errorf("upstream calls = %d, want 1", got)
+	}
+	var payload struct {
+		AudioURL    string `json:"audio_url"`
+		PayloadEcho string `json:"payload_echo"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+		t.Fatal(err)
+	}
+	if !strings.Contains(payload.PayloadEcho, "text=Hallo") {
+		t.Errorf("upstream did not receive body: %q", payload.PayloadEcho)
+	}
+	if payload.AudioURL != "/audio/test.wav" {
+		t.Errorf("AudioURL = %q", payload.AudioURL)
+	}
+}
+
+// TestAudioProxyForwardsBytes checks that GET /audio/<file> returns the
+// consumer's bytes.
+func TestAudioProxyForwardsBytes(t *testing.T) {
+	upstream, _ := fakeMVoice(t)
+	defer upstream.Close()
+
+	srv, _, _ := buildHarness(t, upstream)
+
+	ts := httptest.NewServer(srv.Handler())
+	defer ts.Close()
+
+	resp, err := http.Get(ts.URL + "/audio/test.wav")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 200 {
+		t.Fatalf("status %d", resp.StatusCode)
+	}
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !strings.HasPrefix(string(body), "RIFF") {
+		// Clamp the preview: a short or empty body is exactly the failing
+		// case, and it must fail the test rather than panic on the slice.
+		t.Errorf("audio body did not start with RIFF: %q", body[:min(len(body), 8)])
+	}
+}
+
+// TestUnhealthyConsumerReturns503 checks that a consumer whose last probe
+// failed is refused with the structured, retryable 503 envelope.
+func TestUnhealthyConsumerReturns503(t *testing.T) {
+	// Upstream that always 500s health probe.
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, "boom", http.StatusInternalServerError)
+	}))
+	defer upstream.Close()
+
+	srv, reg, _ := buildHarness(t, upstream)
+	probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")
+
+	ts := httptest.NewServer(srv.Handler())
+	defer ts.Close()
+
+	resp, err := http.Post(ts.URL+"/v1/tts", "application/json", strings.NewReader(`{}`))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 503 {
+		t.Fatalf("status = %d, want 503", resp.StatusCode)
+	}
+	var env errorBody
+	if err := json.NewDecoder(resp.Body).Decode(&env); err != nil {
+		t.Fatal(err)
+	}
+	if env.Error != "consumer_unreachable" {
+		t.Errorf("error code = %q", env.Error)
+	}
+	if !env.Retryable {
+		t.Errorf("retryable should be true for unreachable consumer")
+	}
+}
+
+// TestStatusReturnsConsumers checks that /v1/status lists the configured
+// consumer with its probed health and the configured GPU total.
+func TestStatusReturnsConsumers(t *testing.T) {
+	upstream, _ := fakeMVoice(t)
+	defer upstream.Close()
+
+	srv, reg, _ := buildHarness(t, upstream)
+	probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")
+
+	ts := httptest.NewServer(srv.Handler())
+	defer ts.Close()
+
+	resp, err := http.Get(ts.URL + "/v1/status")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	body, _ := io.ReadAll(resp.Body)
+	if resp.StatusCode != 200 {
+		t.Fatalf("status %d: %s", resp.StatusCode, body)
+	}
+	var sr statusResponse
+	if err := json.Unmarshal(body, &sr); err != nil {
+		t.Fatal(err)
+	}
+	if len(sr.Consumers) != 1 {
+		t.Fatalf("consumers = %d", len(sr.Consumers))
+	}
+	if sr.Consumers[0].Name != "mvoice" {
+		t.Errorf("name = %q", sr.Consumers[0].Name)
+	}
+	if !sr.Consumers[0].Healthy {
+		t.Errorf("expected healthy=true after successful probe")
+	}
+	if sr.GPU.TotalMiB != 16376 {
+		t.Errorf("GPU.TotalMiB = %d", sr.GPU.TotalMiB)
+	}
+}
+
+// probeOnce drives the registry to record one successful (or failed) health
+// probe synchronously so tests don't have to wait for the 5s loop.
+func probeOnce(t *testing.T, reg *registry.Registry, healthURL, name string) {
+	t.Helper()
+	req, err := http.NewRequestWithContext(context.Background(), "GET", healthURL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	client := &http.Client{Timeout: 2 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		// Transport failure: record a failed probe carrying the error text
+		// and no body.
+		reg.RecordProbeForTest(name, false, err.Error(), nil)
+		return
+	}
+	defer resp.Body.Close()
+	// Read at most 8 KiB of the health body; a read error is ignored and
+	// whatever was read so far is recorded.
+	body, _ := io.ReadAll(io.LimitReader(resp.Body, 8192))
+	// Any status below 400 counts as a successful probe.
+	ok := resp.StatusCode < 400
+	errMsg := ""
+	if !ok {
+		errMsg = "status " + strings.TrimSpace(resp.Status)
+	}
+	// NOTE(review): presumably the registry parses body for loaded /
+	// gpu_resident_mib — confirm against RecordProbeForTest.
+	reg.RecordProbeForTest(name, ok, errMsg, body)
+}
diff --git a/systemd/mgpumanager.service b/systemd/mgpumanager.service
new file mode 100644
index 0000000..2a07b53
--- /dev/null
+++ b/systemd/mgpumanager.service
@@ -0,0 +1,30 @@
+[Unit]
+Description=mGPUmanager — GPU-Inference-Control-Plane for mRock
+Documentation=https://mgit.msbls.de/m/mGPUmanager
+After=network-online.target
+Wants=network-online.target
+
+[Service]
+Type=simple
+User=m
+Group=m
+WorkingDirectory=/home/m/dev/mGPUmanager
+ExecStart=/home/m/dev/mGPUmanager/bin/mgpumanager \
+    --config /home/m/dev/mGPUmanager/config/consumers.yaml \
+    --log-level info
+Restart=on-failure
+RestartSec=3
+TimeoutStopSec=10
+
+# Hardening — broker has no need for elevated capabilities.
+NoNewPrivileges=true
+PrivateTmp=true
+ProtectSystem=strict
+ProtectHome=read-only
+ReadWritePaths=/home/m/dev/mGPUmanager
+
+# The broker only proxies; nvidia-smi is the only GPU-touching call.
+PrivateDevices=false
+
+[Install]
+WantedBy=multi-user.target