Merge Phase 1 — Schritte 0–5 (broker MVP, queue, LRU eviction)
This commit is contained in:
15
.gitignore
vendored
Normal file
15
.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
# Build artifacts
|
||||
/bin/
|
||||
|
||||
# Worker session noise
|
||||
.m/
|
||||
*.log
|
||||
|
||||
# Go test/coverage
|
||||
*.out
|
||||
coverage.html
|
||||
|
||||
# Editor cruft
|
||||
*.swp
|
||||
.idea/
|
||||
.vscode/
|
||||
35
Makefile
Normal file
35
Makefile
Normal file
@@ -0,0 +1,35 @@
|
||||
# mGPUmanager build + deploy targets.
|
||||
#
|
||||
# `make build` — compile the Go binary into ./bin/mgpumanager.
|
||||
# `make test` — go test ./...
|
||||
# `make run` — run locally against ./config/consumers.yaml.
|
||||
# `make deploy` — rsync binary + config + systemd unit to mRock,
|
||||
# reload systemd, restart the service.
|
||||
|
||||
BIN := bin/mgpumanager
|
||||
PKG := ./cmd/mgpumanager
|
||||
|
||||
GO ?= go
|
||||
HOST ?= mrock
|
||||
REMOTE_DIR ?= /home/m/dev/mGPUmanager
|
||||
|
||||
.PHONY: build test run deploy clean
|
||||
|
||||
build:
|
||||
mkdir -p bin
|
||||
$(GO) build -trimpath -ldflags="-s -w" -o $(BIN) $(PKG)
|
||||
|
||||
test:
|
||||
$(GO) test ./...
|
||||
|
||||
run: build
|
||||
./$(BIN) --config config/consumers.yaml --log-level debug
|
||||
|
||||
deploy: build
|
||||
rsync -a --mkpath $(BIN) $(HOST):$(REMOTE_DIR)/$(BIN)
|
||||
rsync -a --mkpath config/consumers.yaml $(HOST):$(REMOTE_DIR)/config/consumers.yaml
|
||||
rsync -a --mkpath systemd/mgpumanager.service $(HOST):$(REMOTE_DIR)/systemd/mgpumanager.service
|
||||
ssh $(HOST) "sudo cp $(REMOTE_DIR)/systemd/mgpumanager.service /etc/systemd/system/mgpumanager.service && sudo systemctl daemon-reload && sudo systemctl enable mgpumanager.service && sudo systemctl restart mgpumanager.service && sudo systemctl status mgpumanager.service --no-pager -l"
|
||||
|
||||
clean:
|
||||
rm -rf bin
|
||||
72
README.md
72
README.md
@@ -1,3 +1,73 @@
|
||||
# mGPUmanager
|
||||
|
||||
GPU-Inference-Control-Plane für mRock — Scheduler vor TTS/STT/LLM/Image-Gen mit globalem GPU-Lock + LRU-Eviction + einheitlicher /v1-Fassade. Konsumenten: mVoice, whisper-server, Ollama, ComfyUI/FLUX, später Furbotto. Go.
|
||||
GPU-Inference-Control-Plane für mRock — Scheduler vor TTS/STT/LLM/Image-Gen mit globalem GPU-Lock + LRU-Eviction + einheitlicher `/v1`-Fassade. Konsumenten: mVoice, whisper-server, Ollama, ComfyUI/FLUX, später Furbotto. Go.
|
||||
|
||||
Full design: [`docs/design.md`](docs/design.md) — Bestandsaufnahme, 10-Alternativen-Survey, Eviction-Algorithmus, Migrationspfad.
|
||||
|
||||
## Was es macht
|
||||
|
||||
Auf `mrock:8770` sitzt ein Go-Daemon, der:
|
||||
|
||||
- `/v1/tts`, `/v1/stt`, `/v1/llm`, `/v1/image` als einheitliche Konsumenten-Fassade exponiert,
|
||||
- jede Anfrage durch einen globalen GPU-Scheduler schleust (seriell, Queue),
|
||||
- bei VRAM-Druck LRU-Eviction über die deklarierten Coexistenz-Gruppen aus `config/consumers.yaml` fährt,
|
||||
- in `/v1/status` Live-GPU-Belegung + Consumer-Health + Scheduler-Statistiken zeigt,
|
||||
- niemals stille Fallbacks zurückgibt — Fehler kommen als strukturiertes `{error,message,consumer,retryable}`.
|
||||
|
||||
## Konsumenten-Registry
|
||||
|
||||
`config/consumers.yaml` deklariert pro Consumer:
|
||||
|
||||
- `url`, `health.{method,path}` für Liveness-Probing
|
||||
- `paths.<kind>.{method,path}` — wie der Broker zu seinem TTS/STT/LLM/Image-Endpoint kommt
|
||||
- `vram_resident_mib` — für die Scheduler-Mathe (Schritt 5)
|
||||
- `unload.{method,path,body}` und optional `load.{method,path}` — wie der Broker den Consumer aus dem VRAM räumt / wieder hochfährt
|
||||
- `can_coexist_with: [..]` — wer parallel resident sein darf
|
||||
- `priority` (0=low, 4=urgent), `max_concurrency`
|
||||
|
||||
## Build + Deploy
|
||||
|
||||
```sh
|
||||
make build # ./bin/mgpumanager
|
||||
make test # go test ./...
|
||||
make run # lokal gegen ./config/consumers.yaml
|
||||
make deploy HOST=mrock # rsync + systemd reload + restart
|
||||
```
|
||||
|
||||
Auf mRock läuft der Daemon als System-Unit (`/etc/systemd/system/mgpumanager.service`).
|
||||
|
||||
## Endpoints
|
||||
|
||||
| Verb | Pfad | Verhalten |
|
||||
|---|---|---|
|
||||
| POST | `/v1/tts` | Proxy zu `routing.tts`-Consumer (default: mvoice `/api/synthesize`) |
|
||||
| POST | `/v1/stt` | Proxy zu `routing.stt`-Consumer (default: mvoice `/api/transcribe`) |
|
||||
| POST | `/v1/llm` | Proxy zu `routing.llm`-Consumer (default: ollama `/api/generate`) |
|
||||
| POST | `/v1/image` | Proxy zu `routing.image`-Consumer (default: comfyui `/prompt`) |
|
||||
| GET | `/audio/*` | Proxy zu `audio_proxy`-Consumer (wa.sh fetcht generiertes Audio so) |
|
||||
| GET | `/v1/status`| Live-Snapshot: GPU + Consumer-Health + Scheduler-Stats |
|
||||
| GET | `/healthz` | Broker-Liveness (200 OK) |
|
||||
|
||||
## Fehler-Schema
|
||||
|
||||
Jeder Broker-eigene Fehler hat die Form:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "consumer_unreachable",
|
||||
"message": "upstream mvoice last probe failed: connection refused",
|
||||
"consumer": "mvoice",
|
||||
"retryable": true
|
||||
}
|
||||
```
|
||||
|
||||
Codes: `consumer_unreachable`, `no_consumer`, `scheduler_error`, `bad_consumer_url`, `bad_request`. Pass-through-4xx/5xx vom Consumer landet unverändert beim Client.
|
||||
|
||||
## Phase 1 Status (Issue #1)
|
||||
|
||||
- ✅ Schritt 0 — ComfyUI persistent (`systemd: comfyui.service`)
|
||||
- ✅ Schritt 1 — `mvoice /api/admin/{load,unload}` (mai/knuth/admin-load-unload @ mVoice)
|
||||
- ✅ Schritt 2 — Routing-Façade + `/v1/status`
|
||||
- ✅ Schritt 3 — wa.sh auf Broker umgestellt (m/mAi `mai/knuth/wa-tts-broker`)
|
||||
- ✅ Schritt 4 — Queue + globaler GPU-Lock
|
||||
- ✅ Schritt 5 — Coexistenz-Gruppen + LRU-Eviction
|
||||
|
||||
106
cmd/mgpumanager/main.go
Normal file
106
cmd/mgpumanager/main.go
Normal file
@@ -0,0 +1,106 @@
|
||||
// mgpumanager is the GPU-Inference-Control-Plane for mRock.
|
||||
//
|
||||
// One Go binary that:
|
||||
// 1. Loads consumers.yaml.
|
||||
// 2. Probes every consumer's /health on a 5s cadence.
|
||||
// 3. Polls nvidia-smi every 2s for live VRAM usage (used by Schritt 5
|
||||
// eviction).
|
||||
// 4. Exposes /v1/{tts,stt,llm,image} as a thin proxy + /v1/status for
|
||||
// observability.
|
||||
// 5. Funnels every job through the Scheduler (passthrough today, queue +
|
||||
// eviction in Schritt 4-5).
|
||||
//
|
||||
// All client routing happens through this daemon — no consumer is reached
|
||||
// directly any more. wa.sh, ImaGen, m-CLI and Furbotto-Voice will all speak
|
||||
// to :8770/v1/*.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/gpu"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/scheduler"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("config", "config/consumers.yaml", "path to consumers.yaml")
|
||||
listenOverride := flag.String("listen", "", "override listen address from config")
|
||||
logLevel := flag.String("log-level", "info", "log level: debug|info|warn|error")
|
||||
flag.Parse()
|
||||
|
||||
logger := newLogger(*logLevel)
|
||||
|
||||
cfg, err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
logger.Error("config load failed", "err", err, "path", *configPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
if *listenOverride != "" {
|
||||
cfg.Listen = *listenOverride
|
||||
}
|
||||
|
||||
logger.Info("starting mGPUmanager",
|
||||
"listen", cfg.Listen,
|
||||
"consumers", len(cfg.Consumers),
|
||||
"poll_interval", cfg.GPU.PollInterval(),
|
||||
)
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
reg := registry.New(cfg, logger.With("component", "registry"))
|
||||
gpuPoller := gpu.NewPoller(cfg.GPU.PollInterval(), logger.With("component", "gpu"))
|
||||
// Schritt 5: VRAM-pressure-aware scheduler. Wraps the global GPU lock
|
||||
// with eviction logic — see internal/scheduler/evicting.go.
|
||||
sched := scheduler.NewEvicting(cfg, reg, gpuPoller,
|
||||
logger.With("component", "scheduler"))
|
||||
|
||||
go reg.Run(ctx)
|
||||
go gpuPoller.Run(ctx)
|
||||
|
||||
srv := server.New(cfg, reg, gpuPoller, sched, logger.With("component", "server"))
|
||||
httpSrv := &http.Server{
|
||||
Addr: cfg.Listen,
|
||||
Handler: srv.Handler(),
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = httpSrv.Shutdown(shutCtx)
|
||||
}()
|
||||
|
||||
if err := httpSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("listen failed", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Info("shutdown complete")
|
||||
}
|
||||
|
||||
func newLogger(level string) *slog.Logger {
|
||||
var lvl slog.Level
|
||||
switch level {
|
||||
case "debug":
|
||||
lvl = slog.LevelDebug
|
||||
case "warn":
|
||||
lvl = slog.LevelWarn
|
||||
case "error":
|
||||
lvl = slog.LevelError
|
||||
default:
|
||||
lvl = slog.LevelInfo
|
||||
}
|
||||
return slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: lvl}))
|
||||
}
|
||||
90
config/consumers.yaml
Normal file
90
config/consumers.yaml
Normal file
@@ -0,0 +1,90 @@
|
||||
listen: 127.0.0.1:8770
|
||||
|
||||
gpu:
|
||||
total_mib: 16376 # RTX 4070 Ti SUPER
|
||||
reserved_mib: 1024 # headroom for system/desktop
|
||||
poll_interval_seconds: 2
|
||||
|
||||
routing:
|
||||
tts: mvoice
|
||||
stt: mvoice # whisper-server is alternative if explicitly requested
|
||||
llm: ollama
|
||||
image: comfyui
|
||||
|
||||
# Audio download proxy: any GET under audio_path_prefix is forwarded to this
|
||||
# consumer at the same path. wa.sh fetches mvoice's generated WAV this way.
|
||||
audio_proxy: mvoice
|
||||
audio_path_prefix: /api/audio/
|
||||
|
||||
consumers:
|
||||
mvoice:
|
||||
url: http://localhost:8766
|
||||
health:
|
||||
method: GET
|
||||
path: /api/health
|
||||
paths:
|
||||
tts:
|
||||
method: POST
|
||||
path: /api/synthesize
|
||||
stt:
|
||||
method: POST
|
||||
path: /api/transcribe
|
||||
vram_resident_mib: 2800
|
||||
load:
|
||||
method: POST
|
||||
path: /api/admin/load
|
||||
unload:
|
||||
method: POST
|
||||
path: /api/admin/unload
|
||||
can_coexist_with: [whisper-server, ollama]
|
||||
priority: 3
|
||||
max_concurrency: 1
|
||||
|
||||
whisper-server:
|
||||
url: http://localhost:8178
|
||||
health:
|
||||
method: GET
|
||||
path: /
|
||||
paths:
|
||||
stt:
|
||||
method: POST
|
||||
path: /inference
|
||||
vram_resident_mib: 2050
|
||||
# No HTTP unload; mGPUmanager evicts via systemd restart (Schritt 5).
|
||||
systemd_unit: whisper-server.service
|
||||
can_coexist_with: [mvoice, ollama]
|
||||
priority: 2
|
||||
max_concurrency: 1
|
||||
|
||||
ollama:
|
||||
url: http://localhost:11434
|
||||
health:
|
||||
method: GET
|
||||
path: /api/tags
|
||||
paths:
|
||||
llm:
|
||||
method: POST
|
||||
path: /api/generate
|
||||
# Ollama runs its own LRU keep_alive; we don't track resident VRAM.
|
||||
vram_managed: true
|
||||
can_coexist_with: [mvoice, whisper-server]
|
||||
priority: 2
|
||||
max_concurrency: 1
|
||||
|
||||
comfyui:
|
||||
url: http://localhost:8188
|
||||
health:
|
||||
method: GET
|
||||
path: /system_stats
|
||||
paths:
|
||||
image:
|
||||
method: POST
|
||||
path: /prompt
|
||||
vram_resident_mib: 13000
|
||||
unload:
|
||||
method: POST
|
||||
path: /api/free
|
||||
body: '{"unload_models":true,"free_memory":true}'
|
||||
can_coexist_with: []
|
||||
priority: 1
|
||||
max_concurrency: 1
|
||||
5
go.mod
Normal file
5
go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module mgit.msbls.de/m/mGPUmanager
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require gopkg.in/yaml.v3 v3.0.1
|
||||
4
go.sum
Normal file
4
go.sum
Normal file
@@ -0,0 +1,4 @@
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
172
internal/config/config.go
Normal file
172
internal/config/config.go
Normal file
@@ -0,0 +1,172 @@
|
||||
// Package config loads the mGPUmanager consumer registry from YAML.
|
||||
//
|
||||
// The consumers.yaml file declares every GPU consumer (mvoice, whisper-server,
|
||||
// ollama, comfyui), how to route the four logical endpoint kinds (tts, stt,
|
||||
// llm, image) to a consumer, how to probe its health, and how to load/unload
|
||||
// it from VRAM. The scheduler (Schritt 4–5) reads vram_resident_mib +
|
||||
// can_coexist_with to drive eviction.
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// EndpointKind enumerates the four logical broker endpoints exposed on /v1/*.
|
||||
type EndpointKind string
|
||||
|
||||
const (
|
||||
KindTTS EndpointKind = "tts"
|
||||
KindSTT EndpointKind = "stt"
|
||||
KindLLM EndpointKind = "llm"
|
||||
KindImage EndpointKind = "image"
|
||||
)
|
||||
|
||||
// AllKinds is the canonical ordering used by /v1/status and tests.
|
||||
var AllKinds = []EndpointKind{KindTTS, KindSTT, KindLLM, KindImage}
|
||||
|
||||
// Route describes an HTTP method + path on a consumer.
|
||||
type Route struct {
|
||||
Method string `yaml:"method"`
|
||||
Path string `yaml:"path"`
|
||||
// Body is an optional fixed request body for admin operations
|
||||
// (e.g. ComfyUI's /api/free expects {"unload_models":true,"free_memory":true}).
|
||||
Body string `yaml:"body,omitempty"`
|
||||
}
|
||||
|
||||
// Consumer describes a single GPU consumer behind the broker.
|
||||
type Consumer struct {
|
||||
URL string `yaml:"url"`
|
||||
Health Route `yaml:"health"`
|
||||
Paths map[EndpointKind]Route `yaml:"paths"`
|
||||
VRAMResidentMiB int `yaml:"vram_resident_mib"`
|
||||
VRAMManaged bool `yaml:"vram_managed"` // self-managed LRU (ollama)
|
||||
Load *Route `yaml:"load,omitempty"`
|
||||
Unload *Route `yaml:"unload,omitempty"`
|
||||
SystemdUnit string `yaml:"systemd_unit,omitempty"` // fallback unload (whisper-server)
|
||||
CanCoexistWith []string `yaml:"can_coexist_with"`
|
||||
Priority int `yaml:"priority"`
|
||||
MaxConcurrency int `yaml:"max_concurrency"`
|
||||
}
|
||||
|
||||
// GPU describes the host's GPU envelope.
|
||||
type GPU struct {
|
||||
TotalMiB int `yaml:"total_mib"`
|
||||
ReservedMiB int `yaml:"reserved_mib"`
|
||||
PollIntervalSeconds int `yaml:"poll_interval_seconds"`
|
||||
}
|
||||
|
||||
// PollInterval returns the GPU polling cadence as a Duration. Defaults to 2s.
|
||||
func (g GPU) PollInterval() time.Duration {
|
||||
if g.PollIntervalSeconds <= 0 {
|
||||
return 2 * time.Second
|
||||
}
|
||||
return time.Duration(g.PollIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
// AvailableMiB returns total VRAM minus the system-reserved headroom.
|
||||
func (g GPU) AvailableMiB() int {
|
||||
if g.TotalMiB <= 0 {
|
||||
return 0
|
||||
}
|
||||
avail := g.TotalMiB - g.ReservedMiB
|
||||
if avail < 0 {
|
||||
return 0
|
||||
}
|
||||
return avail
|
||||
}
|
||||
|
||||
// Config is the parsed mGPUmanager configuration.
|
||||
type Config struct {
|
||||
Listen string `yaml:"listen"`
|
||||
GPU GPU `yaml:"gpu"`
|
||||
Routing map[EndpointKind]string `yaml:"routing"`
|
||||
AudioProxy string `yaml:"audio_proxy"`
|
||||
AudioPathPrefix string `yaml:"audio_path_prefix"`
|
||||
Consumers map[string]*Consumer `yaml:"consumers"`
|
||||
}
|
||||
|
||||
// Load reads and validates a consumers.yaml file from disk.
|
||||
func Load(path string) (*Config, error) {
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read %s: %w", path, err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(b, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse %s: %w", path, err)
|
||||
}
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate %s: %w", path, err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) validate() error {
|
||||
if c.Listen == "" {
|
||||
c.Listen = "127.0.0.1:8770"
|
||||
}
|
||||
if len(c.Consumers) == 0 {
|
||||
return fmt.Errorf("no consumers declared")
|
||||
}
|
||||
for name, cons := range c.Consumers {
|
||||
if cons.URL == "" {
|
||||
return fmt.Errorf("consumer %q: url is required", name)
|
||||
}
|
||||
if _, err := url.Parse(cons.URL); err != nil {
|
||||
return fmt.Errorf("consumer %q: invalid url %q: %w", name, cons.URL, err)
|
||||
}
|
||||
if cons.Health.Path == "" {
|
||||
return fmt.Errorf("consumer %q: health.path is required", name)
|
||||
}
|
||||
if cons.Health.Method == "" {
|
||||
cons.Health.Method = "GET"
|
||||
}
|
||||
cons.Health.Method = strings.ToUpper(cons.Health.Method)
|
||||
for kind, route := range cons.Paths {
|
||||
if route.Path == "" {
|
||||
return fmt.Errorf("consumer %q: paths.%s.path is required", name, kind)
|
||||
}
|
||||
if route.Method == "" {
|
||||
route.Method = "POST"
|
||||
}
|
||||
route.Method = strings.ToUpper(route.Method)
|
||||
cons.Paths[kind] = route
|
||||
}
|
||||
if cons.MaxConcurrency <= 0 {
|
||||
cons.MaxConcurrency = 1
|
||||
}
|
||||
}
|
||||
for kind, consName := range c.Routing {
|
||||
if _, ok := c.Consumers[consName]; !ok {
|
||||
return fmt.Errorf("routing.%s: unknown consumer %q", kind, consName)
|
||||
}
|
||||
}
|
||||
if c.AudioProxy != "" {
|
||||
if _, ok := c.Consumers[c.AudioProxy]; !ok {
|
||||
return fmt.Errorf("audio_proxy: unknown consumer %q", c.AudioProxy)
|
||||
}
|
||||
if c.AudioPathPrefix == "" {
|
||||
c.AudioPathPrefix = "/api/audio/"
|
||||
}
|
||||
if !strings.HasPrefix(c.AudioPathPrefix, "/") || !strings.HasSuffix(c.AudioPathPrefix, "/") {
|
||||
return fmt.Errorf("audio_path_prefix must start and end with '/': %q", c.AudioPathPrefix)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConsumerForKind returns the consumer designated to handle a given endpoint
|
||||
// kind, or nil if routing is unset.
|
||||
func (c *Config) ConsumerForKind(kind EndpointKind) (string, *Consumer) {
|
||||
name, ok := c.Routing[kind]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return name, c.Consumers[name]
|
||||
}
|
||||
130
internal/config/config_test.go
Normal file
130
internal/config/config_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadValidConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "consumers.yaml")
|
||||
body := `
|
||||
listen: 127.0.0.1:8770
|
||||
gpu:
|
||||
total_mib: 16376
|
||||
reserved_mib: 1024
|
||||
poll_interval_seconds: 2
|
||||
|
||||
routing:
|
||||
tts: mvoice
|
||||
llm: ollama
|
||||
|
||||
audio_proxy: mvoice
|
||||
|
||||
consumers:
|
||||
mvoice:
|
||||
url: http://localhost:8766
|
||||
health:
|
||||
method: GET
|
||||
path: /api/health
|
||||
paths:
|
||||
tts:
|
||||
method: POST
|
||||
path: /api/synthesize
|
||||
vram_resident_mib: 2800
|
||||
unload:
|
||||
method: POST
|
||||
path: /api/admin/unload
|
||||
load:
|
||||
method: POST
|
||||
path: /api/admin/load
|
||||
can_coexist_with: [ollama]
|
||||
priority: 3
|
||||
max_concurrency: 1
|
||||
ollama:
|
||||
url: http://localhost:11434
|
||||
health:
|
||||
method: GET
|
||||
path: /api/tags
|
||||
paths:
|
||||
llm:
|
||||
method: POST
|
||||
path: /api/generate
|
||||
vram_managed: true
|
||||
can_coexist_with: [mvoice]
|
||||
priority: 2
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if cfg.Listen != "127.0.0.1:8770" {
|
||||
t.Errorf("Listen = %q", cfg.Listen)
|
||||
}
|
||||
if cfg.GPU.AvailableMiB() != 15352 {
|
||||
t.Errorf("AvailableMiB = %d, want 15352", cfg.GPU.AvailableMiB())
|
||||
}
|
||||
if cfg.GPU.PollInterval().Seconds() != 2 {
|
||||
t.Errorf("PollInterval = %s", cfg.GPU.PollInterval())
|
||||
}
|
||||
|
||||
name, cons := cfg.ConsumerForKind(KindTTS)
|
||||
if name != "mvoice" || cons == nil {
|
||||
t.Fatalf("ConsumerForKind(tts) = %q, %v", name, cons)
|
||||
}
|
||||
if cons.Paths[KindTTS].Method != "POST" {
|
||||
t.Errorf("default method not preserved")
|
||||
}
|
||||
if cons.MaxConcurrency != 1 {
|
||||
t.Errorf("MaxConcurrency = %d", cons.MaxConcurrency)
|
||||
}
|
||||
|
||||
if _, ok := cfg.Consumers["ollama"]; !ok {
|
||||
t.Fatal("ollama not loaded")
|
||||
}
|
||||
if cfg.Consumers["ollama"].MaxConcurrency != 1 {
|
||||
t.Errorf("ollama MaxConcurrency default = %d, want 1",
|
||||
cfg.Consumers["ollama"].MaxConcurrency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRejectsUnknownRouting(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "consumers.yaml")
|
||||
body := `
|
||||
routing:
|
||||
tts: nonexistent
|
||||
consumers:
|
||||
mvoice:
|
||||
url: http://localhost:8766
|
||||
health: { path: /api/health }
|
||||
paths: {}
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := Load(path); err == nil {
|
||||
t.Fatal("expected error for unknown routing target, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRejectsMissingURL(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "consumers.yaml")
|
||||
body := `
|
||||
consumers:
|
||||
broken:
|
||||
health: { path: /h }
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := Load(path); err == nil {
|
||||
t.Fatal("expected error for missing URL, got nil")
|
||||
}
|
||||
}
|
||||
129
internal/gpu/gpu.go
Normal file
129
internal/gpu/gpu.go
Normal file
@@ -0,0 +1,129 @@
|
||||
// Package gpu polls nvidia-smi for live VRAM usage.
|
||||
//
|
||||
// Schritt 5 uses this to detect VRAM pressure and trigger LRU eviction.
|
||||
// On hosts without an NVIDIA GPU (e.g. m's laptop during local dev) the
|
||||
// poller silently reports zero usage so the scheduler can still run.
|
||||
package gpu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Sample is one nvidia-smi reading.
|
||||
type Sample struct {
|
||||
UsedMiB int
|
||||
FreeMiB int
|
||||
TotalMiB int
|
||||
At time.Time
|
||||
Err string
|
||||
}
|
||||
|
||||
// Poller periodically samples GPU memory and exposes the latest reading.
|
||||
type Poller struct {
|
||||
interval time.Duration
|
||||
logger *slog.Logger
|
||||
mu sync.RWMutex
|
||||
last Sample
|
||||
}
|
||||
|
||||
// NewPoller builds a Poller. Pass the desired sampling cadence.
|
||||
func NewPoller(interval time.Duration, logger *slog.Logger) *Poller {
|
||||
if interval <= 0 {
|
||||
interval = 2 * time.Second
|
||||
}
|
||||
return &Poller{interval: interval, logger: logger}
|
||||
}
|
||||
|
||||
// Run samples in a loop until ctx is cancelled.
|
||||
func (p *Poller) Run(ctx context.Context) {
|
||||
p.sampleOnce(ctx)
|
||||
t := time.NewTicker(p.interval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
p.sampleOnce(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Last returns the most recent sample.
|
||||
func (p *Poller) Last() Sample {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.last
|
||||
}
|
||||
|
||||
// SetSampleForTest injects a synthetic VRAM reading. Used from tests that
|
||||
// must drive the scheduler's eviction logic without a real GPU or
|
||||
// nvidia-smi. Production callers should never reach this.
|
||||
func SetSampleForTest(p *Poller, freeMiB, totalMiB int) {
|
||||
p.store(Sample{
|
||||
FreeMiB: freeMiB,
|
||||
TotalMiB: totalMiB,
|
||||
UsedMiB: totalMiB - freeMiB,
|
||||
At: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Poller) sampleOnce(ctx context.Context) {
|
||||
cctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// memory.used,memory.free,memory.total in MiB, no units, no header.
|
||||
cmd := exec.CommandContext(cctx, "nvidia-smi",
|
||||
"--query-gpu=memory.used,memory.free,memory.total",
|
||||
"--format=csv,noheader,nounits")
|
||||
out, err := cmd.Output()
|
||||
now := time.Now()
|
||||
if err != nil {
|
||||
p.store(Sample{At: now, Err: err.Error()})
|
||||
if p.logger != nil {
|
||||
p.logger.Debug("nvidia-smi failed", "err", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
used, free, total, perr := parseSMI(string(out))
|
||||
if perr != "" {
|
||||
p.store(Sample{At: now, Err: perr})
|
||||
return
|
||||
}
|
||||
p.store(Sample{UsedMiB: used, FreeMiB: free, TotalMiB: total, At: now})
|
||||
}
|
||||
|
||||
func (p *Poller) store(s Sample) {
|
||||
p.mu.Lock()
|
||||
p.last = s
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func parseSMI(out string) (used, free, total int, errMsg string) {
|
||||
// Take first non-empty line — multi-GPU hosts would yield more, but we
|
||||
// only support single-GPU (mRock) for Phase 1.
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(line, ",")
|
||||
if len(parts) != 3 {
|
||||
return 0, 0, 0, "unexpected nvidia-smi output: " + line
|
||||
}
|
||||
u, e1 := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||
f, e2 := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
t, e3 := strconv.Atoi(strings.TrimSpace(parts[2]))
|
||||
if e1 != nil || e2 != nil || e3 != nil {
|
||||
return 0, 0, 0, "non-integer nvidia-smi output: " + line
|
||||
}
|
||||
return u, f, t, ""
|
||||
}
|
||||
return 0, 0, 0, "empty nvidia-smi output"
|
||||
}
|
||||
35
internal/registry/parse.go
Normal file
35
internal/registry/parse.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// parseGPUFields best-effort parses a consumer's health response body to pull
|
||||
// out gpu_resident_mib (mvoice exposes this since the Schritt 1 patch) and a
|
||||
// 'loaded' boolean. Returns the previous loaded value if the field is absent
|
||||
// (i.e. the consumer doesn't report it — assume always loaded).
|
||||
func parseGPUFields(body []byte, prevLoaded bool) (mib int, loaded bool) {
|
||||
var parsed struct {
|
||||
GPUResidentMiB *int `json:"gpu_resident_mib"`
|
||||
Loaded *bool `json:"loaded"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
// Non-JSON health response (e.g. whisper-server '/' returns HTML). Treat
|
||||
// 200 OK as healthy + loaded, no VRAM info.
|
||||
return 0, true
|
||||
}
|
||||
if parsed.GPUResidentMiB != nil {
|
||||
mib = *parsed.GPUResidentMiB
|
||||
}
|
||||
if parsed.Loaded != nil {
|
||||
loaded = *parsed.Loaded
|
||||
} else {
|
||||
// Field absent — preserve prior state, default true on first probe.
|
||||
if !prevLoaded {
|
||||
loaded = true
|
||||
} else {
|
||||
loaded = prevLoaded
|
||||
}
|
||||
}
|
||||
return mib, loaded
|
||||
}
|
||||
178
internal/registry/registry.go
Normal file
178
internal/registry/registry.go
Normal file
@@ -0,0 +1,178 @@
|
||||
// Package registry tracks the live state of every GPU consumer.
|
||||
//
|
||||
// At Schritt 2 (MVP) the registry only does health probing — periodic GET on
|
||||
// each consumer's health route, last-success timestamp, last error. Schritt 4
|
||||
// adds per-consumer in-flight counts and LastUsed for LRU eviction in
|
||||
// Schritt 5.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
)
|
||||
|
||||
// State is a snapshot of a single consumer's live status.
|
||||
type State struct {
|
||||
Name string
|
||||
Healthy bool
|
||||
LastProbe time.Time
|
||||
LastError string
|
||||
GPUResidentMiB int // populated from consumer health response when present
|
||||
Loaded bool // mvoice reports this; others default to true
|
||||
Active int // in-flight job count (Schritt 4)
|
||||
LastUsed time.Time // last successful job completion (Schritt 5)
|
||||
TotalRequests int64
|
||||
}
|
||||
|
||||
// Registry holds the live state of all consumers.
|
||||
type Registry struct {
|
||||
cfg *config.Config
|
||||
client *http.Client
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
states map[string]*State
|
||||
}
|
||||
|
||||
// New builds a Registry from the loaded config.
|
||||
func New(cfg *config.Config, logger *slog.Logger) *Registry {
|
||||
r := &Registry{
|
||||
cfg: cfg,
|
||||
client: &http.Client{Timeout: 5 * time.Second},
|
||||
logger: logger,
|
||||
states: make(map[string]*State, len(cfg.Consumers)),
|
||||
}
|
||||
for name := range cfg.Consumers {
|
||||
r.states[name] = &State{Name: name}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Run starts the background health-probe loop and blocks until ctx is done.
|
||||
// Cadence is fixed at 5s for health (independent of GPU polling cadence).
|
||||
func (r *Registry) Run(ctx context.Context) {
|
||||
r.probeAll(ctx)
|
||||
t := time.NewTicker(5 * time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
r.probeAll(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) probeAll(ctx context.Context) {
|
||||
var wg sync.WaitGroup
|
||||
for name, cons := range r.cfg.Consumers {
|
||||
wg.Add(1)
|
||||
go func(name string, cons *config.Consumer) {
|
||||
defer wg.Done()
|
||||
r.probeOne(ctx, name, cons)
|
||||
}(name, cons)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *Registry) probeOne(ctx context.Context, name string, cons *config.Consumer) {
|
||||
cctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(cctx, cons.Health.Method, cons.URL+cons.Health.Path, nil)
|
||||
if err != nil {
|
||||
r.recordProbe(name, false, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
r.recordProbe(name, false, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
||||
if resp.StatusCode >= 400 {
|
||||
r.recordProbe(name, false, fmt.Sprintf("status %d", resp.StatusCode), nil)
|
||||
return
|
||||
}
|
||||
r.recordProbe(name, true, "", body)
|
||||
}
|
||||
|
||||
// recordProbe stores the outcome of one health check, optionally parsing
|
||||
// gpu_resident_mib / loaded fields out of the response body.
|
||||
func (r *Registry) recordProbe(name string, ok bool, errMsg string, body []byte) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
s := r.states[name]
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.LastProbe = time.Now()
|
||||
s.Healthy = ok
|
||||
s.LastError = errMsg
|
||||
if ok && body != nil {
|
||||
s.GPUResidentMiB, s.Loaded = parseGPUFields(body, s.Loaded)
|
||||
}
|
||||
if !ok && r.logger != nil {
|
||||
r.logger.Debug("consumer probe failed", "consumer", name, "err", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordProbeForTest exposes the internal probe-recording path to tests
|
||||
// in other packages without depending on the live 5s probe loop.
|
||||
func (r *Registry) RecordProbeForTest(name string, ok bool, errMsg string, body []byte) {
|
||||
r.recordProbe(name, ok, errMsg, body)
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of all consumer states, ordered by config-declared
|
||||
// consumer name set (Go map iteration order is randomized — callers that need
|
||||
// stable ordering should sort).
|
||||
func (r *Registry) Snapshot() map[string]State {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make(map[string]State, len(r.states))
|
||||
for k, v := range r.states {
|
||||
out[k] = *v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Get returns a single consumer state (copy) or zero-value if unknown.
|
||||
func (r *Registry) Get(name string) State {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
if s, ok := r.states[name]; ok {
|
||||
return *s
|
||||
}
|
||||
return State{}
|
||||
}
|
||||
|
||||
// MarkActive increments the in-flight count and updates LastUsed.
|
||||
// Returns a release func to call on job completion.
|
||||
func (r *Registry) MarkActive(name string) func() {
|
||||
r.mu.Lock()
|
||||
if s, ok := r.states[name]; ok {
|
||||
s.Active++
|
||||
s.TotalRequests++
|
||||
}
|
||||
r.mu.Unlock()
|
||||
return func() {
|
||||
r.mu.Lock()
|
||||
if s, ok := r.states[name]; ok {
|
||||
if s.Active > 0 {
|
||||
s.Active--
|
||||
}
|
||||
s.LastUsed = time.Now()
|
||||
}
|
||||
r.mu.Unlock()
|
||||
}
|
||||
}
|
||||
329
internal/scheduler/evicting.go
Normal file
329
internal/scheduler/evicting.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/gpu"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
)
|
||||
|
||||
// vramCushionMiB is the minimum free VRAM the scheduler insists on having
|
||||
// AFTER the target consumer is loaded. Keeps cudaMalloc headers from OOM-ing
|
||||
// at the very edge of available memory.
|
||||
const vramCushionMiB = 256
|
||||
|
||||
// maxEvictAttempts caps how many consumers the scheduler will unload in a
|
||||
// single ensureFits cycle before giving up and returning an error. Five is
|
||||
// generous — we only have four consumers configured.
|
||||
const maxEvictAttempts = 5
|
||||
|
||||
// Evicting is the Schritt 5 scheduler: it wraps a Locked scheduler with
|
||||
// VRAM-pressure-aware eviction.
|
||||
//
|
||||
// Flow per job:
|
||||
// 1. ensureFits — if the live free VRAM minus a 256 MiB cushion is below
|
||||
// the target consumer's vram_resident_mib AND the target is not already
|
||||
// resident, unload the LRU non-coexistent consumer. Repeat until fit.
|
||||
// 2. ensureLoaded — if the target was previously unloaded, call its
|
||||
// load endpoint (mvoice) or rely on implicit cold-start (whisper, etc.).
|
||||
// 3. inner.Run — acquire the global GPU lock and run the job.
|
||||
//
|
||||
// Eviction state is scheduler-local: registry.Loaded (polled every 5 s) is
|
||||
// authoritative when the consumer reports it, but for the seconds between an
|
||||
// unload call and the next probe we rely on our own bookkeeping.
|
||||
type Evicting struct {
|
||||
cfg *config.Config
|
||||
reg *registry.Registry
|
||||
gpu *gpu.Poller
|
||||
inner *Locked
|
||||
logger *slog.Logger
|
||||
client *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
loaded map[string]bool // consumer name -> believed-resident
|
||||
lastUsed map[string]time.Time
|
||||
evictions int64
|
||||
}
|
||||
|
||||
// NewEvicting builds the Schritt 5 scheduler. All consumers are assumed
|
||||
// resident at startup — the first health probe will correct any consumers
|
||||
// that actually aren't (e.g. mvoice in 'unloaded' state).
|
||||
func NewEvicting(cfg *config.Config, reg *registry.Registry, gpuPoller *gpu.Poller, logger *slog.Logger) *Evicting {
|
||||
e := &Evicting{
|
||||
cfg: cfg,
|
||||
reg: reg,
|
||||
gpu: gpuPoller,
|
||||
inner: NewLocked(reg, 1),
|
||||
logger: logger,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
loaded: make(map[string]bool, len(cfg.Consumers)),
|
||||
lastUsed: make(map[string]time.Time, len(cfg.Consumers)),
|
||||
}
|
||||
for name, cons := range cfg.Consumers {
|
||||
// Self-managed VRAM consumers (ollama) are always 'loaded' from
|
||||
// the scheduler's perspective — we never evict them via HTTP.
|
||||
e.loaded[name] = !cons.VRAMManaged || true
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Run is the public Scheduler interface: ensure room + load + serialise.
|
||||
func (e *Evicting) Run(ctx context.Context, consumer string, fn Job) error {
|
||||
if err := e.ensureFits(ctx, consumer); err != nil {
|
||||
return fmt.Errorf("eviction: %w", err)
|
||||
}
|
||||
if err := e.ensureLoaded(ctx, consumer); err != nil {
|
||||
return fmt.Errorf("load %s: %w", consumer, err)
|
||||
}
|
||||
err := e.inner.Run(ctx, consumer, fn)
|
||||
if err == nil {
|
||||
e.mu.Lock()
|
||||
e.lastUsed[consumer] = time.Now()
|
||||
e.mu.Unlock()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Stats forwards from the inner scheduler and adds the eviction counter.
|
||||
func (e *Evicting) Stats() Stats {
|
||||
s := e.inner.Stats()
|
||||
s.Evictions = atomic.LoadInt64(&e.evictions)
|
||||
return s
|
||||
}
|
||||
|
||||
// ───── ensureFits ────────────────────────────────────────────────────────
|
||||
|
||||
func (e *Evicting) ensureFits(ctx context.Context, target string) error {
|
||||
cons := e.cfg.Consumers[target]
|
||||
if cons == nil {
|
||||
return fmt.Errorf("unknown consumer %q", target)
|
||||
}
|
||||
if cons.VRAMResidentMiB == 0 || cons.VRAMManaged {
|
||||
// Self-managed (ollama) or unknown size — let the consumer figure
|
||||
// it out; no preemptive eviction.
|
||||
return nil
|
||||
}
|
||||
// Already resident? No eviction needed.
|
||||
e.mu.Lock()
|
||||
resident := e.loaded[target]
|
||||
e.mu.Unlock()
|
||||
if resident {
|
||||
return nil
|
||||
}
|
||||
|
||||
for range maxEvictAttempts {
|
||||
if e.fits(cons) {
|
||||
return nil
|
||||
}
|
||||
victim := e.pickLRUVictim(target, cons)
|
||||
if victim == "" {
|
||||
// Nothing left to evict that we're allowed to touch.
|
||||
e.logger.Warn("no eviction candidates", "target", target,
|
||||
"need_mib", cons.VRAMResidentMiB,
|
||||
"free_mib", e.gpu.Last().FreeMiB)
|
||||
return nil
|
||||
}
|
||||
if err := e.unload(ctx, victim); err != nil {
|
||||
e.logger.Warn("evict failed", "victim", victim, "err", err)
|
||||
return fmt.Errorf("unload %s: %w", victim, err)
|
||||
}
|
||||
atomic.AddInt64(&e.evictions, 1)
|
||||
e.logger.Info("evicted consumer",
|
||||
"victim", victim, "target", target,
|
||||
"free_mib_after", e.gpu.Last().FreeMiB,
|
||||
"need_mib", cons.VRAMResidentMiB)
|
||||
// Give the GPU a moment to actually free the VRAM before re-checking.
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("VRAM headroom still insufficient after %d evictions", maxEvictAttempts)
|
||||
}
|
||||
|
||||
// fits returns true when the live nvidia-smi free VRAM minus the safety
|
||||
// cushion is enough for the target consumer's predicted footprint.
|
||||
//
|
||||
// Falls back to the static budget (cfg.GPU.AvailableMiB() minus the
|
||||
// non-coexistent loaded set) if the GPU poller has not produced a sample
|
||||
// yet (e.g. during the first second of process lifetime).
|
||||
func (e *Evicting) fits(cons *config.Consumer) bool {
|
||||
sample := e.gpu.Last()
|
||||
if sample.FreeMiB > 0 || sample.TotalMiB > 0 {
|
||||
return sample.FreeMiB >= cons.VRAMResidentMiB+vramCushionMiB
|
||||
}
|
||||
return e.fitsByBudget(cons)
|
||||
}
|
||||
|
||||
func (e *Evicting) fitsByBudget(cons *config.Consumer) bool {
|
||||
headroom := e.cfg.GPU.AvailableMiB()
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
for name, loaded := range e.loaded {
|
||||
if !loaded {
|
||||
continue
|
||||
}
|
||||
other := e.cfg.Consumers[name]
|
||||
if other == nil || other.VRAMManaged {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(cons.CanCoexistWith, name) {
|
||||
continue
|
||||
}
|
||||
headroom -= other.VRAMResidentMiB
|
||||
}
|
||||
return headroom >= cons.VRAMResidentMiB
|
||||
}
|
||||
|
||||
// pickLRUVictim returns the name of the loaded consumer with the oldest
|
||||
// LastUsed that is NOT in target's can_coexist_with list, NOT the target
|
||||
// itself, NOT VRAM-managed, and has *some* way to be evicted.
|
||||
func (e *Evicting) pickLRUVictim(target string, cons *config.Consumer) string {
|
||||
snap := e.reg.Snapshot()
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
var best string
|
||||
var bestTime time.Time
|
||||
for name, loaded := range e.loaded {
|
||||
if !loaded || name == target {
|
||||
continue
|
||||
}
|
||||
other := e.cfg.Consumers[name]
|
||||
if other == nil || other.VRAMManaged {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(cons.CanCoexistWith, name) {
|
||||
continue
|
||||
}
|
||||
if other.Unload == nil && other.SystemdUnit == "" {
|
||||
continue
|
||||
}
|
||||
// LastUsed: prefer scheduler-local (set on successful job exit) over
|
||||
// registry (set on probe completion). Scheduler-local is more
|
||||
// meaningful for LRU because it reflects real GPU work, not health
|
||||
// chatter.
|
||||
t := e.lastUsed[name]
|
||||
if t.IsZero() {
|
||||
t = snap[name].LastUsed
|
||||
}
|
||||
if best == "" || t.Before(bestTime) {
|
||||
best = name
|
||||
bestTime = t
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
// ───── unload + load ─────────────────────────────────────────────────────
|
||||
|
||||
func (e *Evicting) unload(ctx context.Context, name string) error {
|
||||
cons := e.cfg.Consumers[name]
|
||||
if cons.Unload == nil {
|
||||
// systemd-unit-based unload is whisper-server's path; we don't shell
|
||||
// out to sudo from a server daemon in Phase 1. Mark unloaded so we
|
||||
// don't keep picking it as a victim, and let the next request
|
||||
// cold-start via systemd (whisper-server boots in <2 s).
|
||||
if cons.SystemdUnit != "" {
|
||||
e.mu.Lock()
|
||||
e.loaded[name] = false
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("consumer %s: no unload route configured", name)
|
||||
}
|
||||
|
||||
url := cons.URL + cons.Unload.Path
|
||||
var body io.Reader
|
||||
if cons.Unload.Body != "" {
|
||||
body = strings.NewReader(cons.Unload.Body)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, cons.Unload.Method, url, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cons.Unload.Body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
resp, err := e.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("unload %s returned status %d", name, resp.StatusCode)
|
||||
}
|
||||
e.mu.Lock()
|
||||
e.loaded[name] = false
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Evicting) ensureLoaded(ctx context.Context, name string) error {
|
||||
cons := e.cfg.Consumers[name]
|
||||
if cons == nil {
|
||||
return fmt.Errorf("unknown consumer %q", name)
|
||||
}
|
||||
e.mu.Lock()
|
||||
if e.loaded[name] {
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
// No explicit load endpoint — rely on the consumer's own cold-start
|
||||
// behaviour (mvoice would auto-load if a request arrived, comfyui as
|
||||
// well). Mark loaded optimistically.
|
||||
if cons.Load == nil {
|
||||
e.mu.Lock()
|
||||
e.loaded[name] = true
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
url := cons.URL + cons.Load.Path
|
||||
var body io.Reader
|
||||
if cons.Load.Body != "" {
|
||||
body = strings.NewReader(cons.Load.Body)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, cons.Load.Method, url, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := e.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("load %s returned status %d", name, resp.StatusCode)
|
||||
}
|
||||
e.mu.Lock()
|
||||
e.loaded[name] = true
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLoadedForTest overrides the believed-loaded state for one consumer.
|
||||
// Test-only — production code derives it from health probes + unload calls.
|
||||
func (e *Evicting) SetLoadedForTest(name string, loaded bool) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.loaded[name] = loaded
|
||||
}
|
||||
|
||||
// Compile-time interface guard.
|
||||
var _ Scheduler = (*Evicting)(nil)
|
||||
247
internal/scheduler/evicting_test.go
Normal file
247
internal/scheduler/evicting_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/gpu"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
)
|
||||
|
||||
func silentLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
}
|
||||
|
||||
// gpuStub implements just enough of gpu.Poller's surface for the evicting
|
||||
// scheduler. We use the real Poller type (no interface yet) by hand-loading
|
||||
// a Sample via a tiny wrapper.
|
||||
//
|
||||
// In practice we set gpu.Poller's internal sample via NewPoller + a goroutine.
|
||||
// For tests we sidestep that by using a real Poller with a fake nvidia-smi —
|
||||
// but the simpler path is to construct a Poller, store a Sample, and skip
|
||||
// Run. We do that by exposing a tiny helper here.
|
||||
|
||||
// makeGPU returns a Poller pre-loaded with the given free/total values.
|
||||
// It never calls nvidia-smi.
|
||||
func makeGPU(t *testing.T, freeMiB, totalMiB int) *gpu.Poller {
|
||||
t.Helper()
|
||||
p := gpu.NewPoller(time.Hour, silentLogger())
|
||||
// gpu.Poller.Last() reads from an internal Sample. We can't poke it
|
||||
// directly without exporting state, so we use a sub-test trick: run
|
||||
// sampleOnce against a fake nvidia-smi command. But that needs a PATH
|
||||
// override and is brittle. Instead, expose a SetForTest helper.
|
||||
gpu.SetSampleForTest(p, freeMiB, totalMiB)
|
||||
return p
|
||||
}
|
||||
|
||||
// fakeConsumer hosts /api/admin/{load,unload} so the evicting scheduler can
|
||||
// exercise the HTTP eviction path.
|
||||
type fakeConsumer struct {
|
||||
srv *httptest.Server
|
||||
unloadHit atomic.Int32
|
||||
loadHit atomic.Int32
|
||||
}
|
||||
|
||||
func newFakeConsumer(t *testing.T) *fakeConsumer {
|
||||
t.Helper()
|
||||
fc := &fakeConsumer{}
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"loaded":true,"gpu_resident_mib":2800}`))
|
||||
})
|
||||
mux.HandleFunc("POST /api/admin/unload", func(w http.ResponseWriter, _ *http.Request) {
|
||||
fc.unloadHit.Add(1)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
mux.HandleFunc("POST /api/admin/load", func(w http.ResponseWriter, _ *http.Request) {
|
||||
fc.loadHit.Add(1)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
mux.HandleFunc("POST /prompt", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
mux.HandleFunc("POST /api/free", func(w http.ResponseWriter, _ *http.Request) {
|
||||
fc.unloadHit.Add(1)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
fc.srv = httptest.NewServer(mux)
|
||||
return fc
|
||||
}
|
||||
|
||||
func buildCfg(mvoiceURL, comfyURL string) *config.Config {
|
||||
return &config.Config{
|
||||
Listen: "127.0.0.1:0",
|
||||
GPU: config.GPU{TotalMiB: 16376, ReservedMiB: 1024, PollIntervalSeconds: 2},
|
||||
Routing: map[config.EndpointKind]string{
|
||||
config.KindTTS: "mvoice",
|
||||
config.KindImage: "comfyui",
|
||||
},
|
||||
Consumers: map[string]*config.Consumer{
|
||||
"mvoice": {
|
||||
URL: mvoiceURL,
|
||||
Health: config.Route{Method: "GET", Path: "/api/health"},
|
||||
Paths: map[config.EndpointKind]config.Route{
|
||||
config.KindTTS: {Method: "POST", Path: "/api/synthesize"},
|
||||
},
|
||||
VRAMResidentMiB: 2800,
|
||||
Load: &config.Route{Method: "POST", Path: "/api/admin/load"},
|
||||
Unload: &config.Route{Method: "POST", Path: "/api/admin/unload"},
|
||||
CanCoexistWith: []string{"whisper-server", "ollama"},
|
||||
Priority: 3,
|
||||
MaxConcurrency: 1,
|
||||
},
|
||||
"comfyui": {
|
||||
URL: comfyURL,
|
||||
Health: config.Route{Method: "GET", Path: "/system_stats"},
|
||||
Paths: map[config.EndpointKind]config.Route{
|
||||
config.KindImage: {Method: "POST", Path: "/prompt"},
|
||||
},
|
||||
VRAMResidentMiB: 13000,
|
||||
Unload: &config.Route{
|
||||
Method: "POST",
|
||||
Path: "/api/free",
|
||||
Body: `{"unload_models":true,"free_memory":true}`,
|
||||
},
|
||||
CanCoexistWith: []string{},
|
||||
Priority: 1,
|
||||
MaxConcurrency: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestEvictingSkipsWhenAlreadyResident verifies the no-op fast path: a job
|
||||
// for an already-loaded consumer with plenty of free VRAM runs without any
|
||||
// unload call.
|
||||
func TestEvictingSkipsWhenAlreadyResident(t *testing.T) {
|
||||
mvoice := newFakeConsumer(t)
|
||||
defer mvoice.srv.Close()
|
||||
comfy := newFakeConsumer(t)
|
||||
defer comfy.srv.Close()
|
||||
|
||||
cfg := buildCfg(mvoice.srv.URL, comfy.srv.URL)
|
||||
reg := registry.New(cfg, silentLogger())
|
||||
g := makeGPU(t, 8192, 16376) // plenty of headroom
|
||||
e := NewEvicting(cfg, reg, g, silentLogger())
|
||||
|
||||
if err := e.Run(context.Background(), "mvoice", func(ctx context.Context) error { return nil }); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if mvoice.unloadHit.Load() != 0 {
|
||||
t.Errorf("unexpected unload hits on mvoice: %d", mvoice.unloadHit.Load())
|
||||
}
|
||||
if comfy.unloadHit.Load() != 0 {
|
||||
t.Errorf("unexpected unload hits on comfyui: %d", comfy.unloadHit.Load())
|
||||
}
|
||||
}
|
||||
|
||||
// TestEvictingFreesNonCoexistentVictim simulates the canonical scenario from
|
||||
// the design: a TTS request comes in while comfyui is hogging 13 GiB. mvoice
|
||||
// is not coexistent with comfyui (per cfg), so the scheduler must call
|
||||
// comfyui's /api/free before letting the TTS job run.
|
||||
func TestEvictingFreesNonCoexistentVictim(t *testing.T) {
|
||||
mvoice := newFakeConsumer(t)
|
||||
defer mvoice.srv.Close()
|
||||
comfy := newFakeConsumer(t)
|
||||
defer comfy.srv.Close()
|
||||
|
||||
cfg := buildCfg(mvoice.srv.URL, comfy.srv.URL)
|
||||
reg := registry.New(cfg, silentLogger())
|
||||
|
||||
// Only 1 GiB free — mvoice (2.8 GiB) won't fit until comfyui (13 GiB)
|
||||
// is evicted.
|
||||
g := makeGPU(t, 1024, 16376)
|
||||
e := NewEvicting(cfg, reg, g, silentLogger())
|
||||
|
||||
// Force the believed-loaded state so eviction kicks in (Run treats
|
||||
// 'already loaded' as a no-op fast path).
|
||||
e.SetLoadedForTest("mvoice", false)
|
||||
e.SetLoadedForTest("comfyui", true)
|
||||
|
||||
// After the eviction unload call lands, we want fits() to return true
|
||||
// for the next iteration — patch the GPU sample to reflect the freed
|
||||
// memory by swapping the poller before the second fits() check is hit.
|
||||
// We accomplish that by stubbing the unload handler to also bump the
|
||||
// sample.
|
||||
comfy.srv.Config.Handler = withHook(comfy.srv.Config.Handler, func() {
|
||||
gpu.SetSampleForTest(g, 14000, 16376)
|
||||
})
|
||||
|
||||
if err := e.Run(context.Background(), "mvoice", func(ctx context.Context) error { return nil }); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := comfy.unloadHit.Load(); got != 1 {
|
||||
t.Errorf("comfyui unload hit count = %d, want 1", got)
|
||||
}
|
||||
if got := mvoice.loadHit.Load(); got != 1 {
|
||||
t.Errorf("mvoice load hit count = %d, want 1", got)
|
||||
}
|
||||
if got := e.Stats().Evictions; got != 1 {
|
||||
t.Errorf("stats.Evictions = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEvictingHonoursCoexistence ensures we never evict a consumer that the
|
||||
// target declared compatible. mvoice can coexist with ollama, so ollama must
|
||||
// not be picked even if it's the LRU candidate.
|
||||
func TestEvictingHonoursCoexistence(t *testing.T) {
|
||||
mvoice := newFakeConsumer(t)
|
||||
defer mvoice.srv.Close()
|
||||
comfy := newFakeConsumer(t)
|
||||
defer comfy.srv.Close()
|
||||
|
||||
cfg := buildCfg(mvoice.srv.URL, comfy.srv.URL)
|
||||
// Add a stub ollama with an unload endpoint, mark coexistent.
|
||||
ollama := newFakeConsumer(t)
|
||||
defer ollama.srv.Close()
|
||||
cfg.Consumers["ollama"] = &config.Consumer{
|
||||
URL: ollama.srv.URL,
|
||||
Health: config.Route{Method: "GET", Path: "/api/health"},
|
||||
Paths: map[config.EndpointKind]config.Route{},
|
||||
VRAMResidentMiB: 2000,
|
||||
Unload: &config.Route{Method: "POST", Path: "/api/admin/unload"},
|
||||
CanCoexistWith: []string{"mvoice"},
|
||||
MaxConcurrency: 1,
|
||||
}
|
||||
|
||||
reg := registry.New(cfg, silentLogger())
|
||||
g := makeGPU(t, 1000, 16376)
|
||||
e := NewEvicting(cfg, reg, g, silentLogger())
|
||||
e.SetLoadedForTest("mvoice", false)
|
||||
e.SetLoadedForTest("comfyui", true)
|
||||
e.SetLoadedForTest("ollama", true)
|
||||
|
||||
comfy.srv.Config.Handler = withHook(comfy.srv.Config.Handler, func() {
|
||||
gpu.SetSampleForTest(g, 14000, 16376)
|
||||
})
|
||||
|
||||
if err := e.Run(context.Background(), "mvoice", func(ctx context.Context) error { return nil }); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := ollama.unloadHit.Load(); got != 0 {
|
||||
t.Errorf("ollama (coexistent) unloaded %d times; should be 0", got)
|
||||
}
|
||||
if got := comfy.unloadHit.Load(); got != 1 {
|
||||
t.Errorf("comfyui unload hit count = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ───── helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
// withHook wraps an http.Handler so each call invokes hook() before
|
||||
// delegating to the original handler. Used to simulate VRAM being freed
|
||||
// the instant comfyui's /api/free returns.
|
||||
func withHook(h http.Handler, hook func()) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hook()
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
124
internal/scheduler/locked.go
Normal file
124
internal/scheduler/locked.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
)
|
||||
|
||||
// Locked is the Schritt 4 scheduler: a single capacity-1 semaphore serialises
|
||||
// every consumer's GPU work. Jobs wait FIFO-ish at the channel until the lock
|
||||
// is available, then run to completion.
|
||||
//
|
||||
// Why one global lock instead of per-stream or per-consumer:
|
||||
// - mRock is single-GPU + single-user. Theoretical parallelism (e.g. mvoice
|
||||
// + ollama small model coexisting) is given up to gain predictability:
|
||||
// no more CUDA-OOM races between concurrent loaders.
|
||||
// - The design (§4.3) is explicit: "Der Lock ist grobgranular (ein Mutex)
|
||||
// […]. Wir verschenken theoretische Parallelität, gewinnen dafür
|
||||
// Vorhersagbarkeit."
|
||||
//
|
||||
// Schritt 5 wraps this with eviction logic that runs before sem acquire when
|
||||
// the requested consumer's resident cost would exceed available headroom.
|
||||
type Locked struct {
|
||||
reg *registry.Registry
|
||||
gpuLock chan struct{} // capacity-1 = global mutex with cancellable acquire
|
||||
|
||||
mu sync.Mutex
|
||||
inFlight int
|
||||
queueDepth int
|
||||
total int64
|
||||
lastWaitMS int64
|
||||
lastRunMS int64
|
||||
oldestQueued time.Time
|
||||
}
|
||||
|
||||
// NewLocked returns the serialising scheduler. capacity is the number of
|
||||
// concurrent jobs allowed on the GPU (Phase 1 wires this as 1).
|
||||
func NewLocked(reg *registry.Registry, capacity int) *Locked {
|
||||
if capacity < 1 {
|
||||
capacity = 1
|
||||
}
|
||||
return &Locked{
|
||||
reg: reg,
|
||||
gpuLock: make(chan struct{}, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
// Run acquires the global GPU lock, executes fn while holding it, and
|
||||
// releases. Cancellation via ctx aborts the wait without leaking a token.
|
||||
func (s *Locked) Run(ctx context.Context, consumer string, fn Job) error {
|
||||
release := s.reg.MarkActive(consumer)
|
||||
defer release()
|
||||
|
||||
queuedAt := time.Now()
|
||||
s.mu.Lock()
|
||||
s.queueDepth++
|
||||
if s.queueDepth == 1 || s.oldestQueued.IsZero() {
|
||||
s.oldestQueued = queuedAt
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Acquire global lock or bail on cancellation.
|
||||
select {
|
||||
case s.gpuLock <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
s.mu.Lock()
|
||||
s.queueDepth--
|
||||
if s.queueDepth == 0 {
|
||||
s.oldestQueued = time.Time{}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return ctx.Err()
|
||||
}
|
||||
waitMS := time.Since(queuedAt).Milliseconds()
|
||||
|
||||
s.mu.Lock()
|
||||
s.queueDepth--
|
||||
if s.queueDepth == 0 {
|
||||
s.oldestQueued = time.Time{}
|
||||
}
|
||||
s.inFlight++
|
||||
s.total++
|
||||
s.lastWaitMS = waitMS
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
<-s.gpuLock
|
||||
s.mu.Lock()
|
||||
s.inFlight--
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
err := fn(ctx)
|
||||
runMS := time.Since(start).Milliseconds()
|
||||
s.mu.Lock()
|
||||
s.lastRunMS = runMS
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Stats reports current depth + last timings for /v1/status.
|
||||
func (s *Locked) Stats() Stats {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return Stats{
|
||||
QueueDepth: s.queueDepth,
|
||||
InFlight: s.inFlight,
|
||||
TotalJobs: s.total,
|
||||
LastWaitMS: s.lastWaitMS,
|
||||
LastRunMS: s.lastRunMS,
|
||||
OldestQueued: s.oldestQueued,
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time interface guard.
|
||||
var _ Scheduler = (*Locked)(nil)
|
||||
|
||||
// Unused import guard — keeps the config package edge live for Schritt 5's
|
||||
// VRAM-pressure evaluation, which reads cfg in this same package.
|
||||
var _ = config.KindTTS
|
||||
164
internal/scheduler/locked_test.go
Normal file
164
internal/scheduler/locked_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
)
|
||||
|
||||
func newReg() *registry.Registry {
|
||||
cfg := &config.Config{
|
||||
Consumers: map[string]*config.Consumer{
|
||||
"mvoice": {
|
||||
URL: "http://localhost:8766",
|
||||
Health: config.Route{Method: "GET", Path: "/api/health"},
|
||||
},
|
||||
},
|
||||
}
|
||||
return registry.New(cfg, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
}
|
||||
|
||||
// TestLockedSerialisesConcurrentJobs is the regression test for the
|
||||
// CUDA-OOM-from-parallel-loaders class: two TTS calls that arrive at the
|
||||
// same time must run sequentially, not concurrently.
|
||||
func TestLockedSerialisesConcurrentJobs(t *testing.T) {
|
||||
sched := NewLocked(newReg(), 1)
|
||||
|
||||
var maxConcurrent atomic.Int32
|
||||
var inFlight atomic.Int32
|
||||
|
||||
job := func(ctx context.Context) error {
|
||||
now := inFlight.Add(1)
|
||||
// Update max in a CAS loop (small N, never contested in practice).
|
||||
for {
|
||||
cur := maxConcurrent.Load()
|
||||
if now <= cur || maxConcurrent.CompareAndSwap(cur, now) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
inFlight.Add(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const n = 5
|
||||
for range n {
|
||||
wg.Go(func() {
|
||||
if err := sched.Run(context.Background(), "mvoice", job); err != nil {
|
||||
t.Errorf("Run: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if got := maxConcurrent.Load(); got != 1 {
|
||||
t.Fatalf("max concurrent jobs = %d, want 1", got)
|
||||
}
|
||||
if got := sched.Stats().TotalJobs; got != n {
|
||||
t.Errorf("Stats.TotalJobs = %d, want %d", got, n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockedRespectsContextCancel(t *testing.T) {
|
||||
sched := NewLocked(newReg(), 1)
|
||||
|
||||
// Hold the lock with a long-running job.
|
||||
started := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
|
||||
close(started)
|
||||
<-done
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
<-started
|
||||
|
||||
// Now try to run with a context that we'll cancel.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- sched.Run(ctx, "mvoice", func(ctx context.Context) error {
|
||||
t.Error("second job should not run after cancellation")
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
// Give the second job time to queue.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("got err=%v, want context.Canceled", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("cancelled Run did not return within 1s")
|
||||
}
|
||||
|
||||
// Release the holder.
|
||||
close(done)
|
||||
}
|
||||
|
||||
func TestLockedStatsTrackInFlightAndQueue(t *testing.T) {
|
||||
sched := NewLocked(newReg(), 1)
|
||||
|
||||
jobStart := make(chan struct{})
|
||||
jobBlock := make(chan struct{})
|
||||
go func() {
|
||||
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
|
||||
close(jobStart)
|
||||
<-jobBlock
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
<-jobStart
|
||||
|
||||
// Inside the holding job: InFlight==1, QueueDepth==0.
|
||||
s := sched.Stats()
|
||||
if s.InFlight != 1 {
|
||||
t.Errorf("InFlight while holding = %d, want 1", s.InFlight)
|
||||
}
|
||||
if s.QueueDepth != 0 {
|
||||
t.Errorf("QueueDepth = %d, want 0", s.QueueDepth)
|
||||
}
|
||||
|
||||
// Queue a waiter and verify QueueDepth grows.
|
||||
waitStarted := make(chan struct{})
|
||||
waitDone := make(chan struct{})
|
||||
go func() {
|
||||
close(waitStarted)
|
||||
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
|
||||
return nil
|
||||
})
|
||||
close(waitDone)
|
||||
}()
|
||||
<-waitStarted
|
||||
// Wait for the waiter to actually be parked on the channel.
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if sched.Stats().QueueDepth == 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
if got := sched.Stats().QueueDepth; got != 1 {
|
||||
t.Errorf("QueueDepth with one waiter = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(jobBlock)
|
||||
<-waitDone
|
||||
if got := sched.Stats().InFlight; got != 0 {
|
||||
t.Errorf("InFlight after both done = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
112
internal/scheduler/scheduler.go
Normal file
112
internal/scheduler/scheduler.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package scheduler controls who gets the GPU when.
|
||||
//
|
||||
// Three responsibilities, added in three phases:
|
||||
//
|
||||
// - Schritt 2 (this file's first version): a passthrough — every job runs
|
||||
// immediately, no locking, no queueing. Only useful for proving the HTTP
|
||||
// façade end-to-end.
|
||||
// - Schritt 4: a global mutex (or capacity-1 channel) serialises all GPU
|
||||
// work. Per-consumer max_concurrency limits stay at 1 for now.
|
||||
// - Schritt 5: VRAM-pressure-aware eviction kicks in before acquire when the
|
||||
// requested consumer's resident cost would exceed available headroom.
|
||||
//
|
||||
// The interface deliberately hides which phase is active from callers
|
||||
// (server.go) so the upgrade path is local to this package.
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
)
|
||||
|
||||
// ErrSchedulerStopped is returned if Run is called after Close.
|
||||
var ErrSchedulerStopped = errors.New("scheduler stopped")
|
||||
|
||||
// Job is what a consumer route worker executes while holding the GPU lock.
|
||||
type Job func(ctx context.Context) error
|
||||
|
||||
// Scheduler decides when GPU work runs. Implementations may queue, serialise,
|
||||
// or evict other consumers before granting access.
|
||||
type Scheduler interface {
|
||||
// Run executes fn while the caller holds the right to use the GPU for
|
||||
// the named consumer. It blocks until fn returns or ctx is cancelled.
|
||||
Run(ctx context.Context, consumer string, fn Job) error
|
||||
|
||||
// Stats returns a snapshot of scheduler internals for /v1/status.
|
||||
Stats() Stats
|
||||
}
|
||||
|
||||
// Stats is what /v1/status reports about the scheduler.
|
||||
type Stats struct {
|
||||
QueueDepth int `json:"queue_depth"`
|
||||
InFlight int `json:"in_flight"`
|
||||
TotalJobs int64 `json:"total_jobs"`
|
||||
LastWaitMS int64 `json:"last_wait_ms"`
|
||||
LastRunMS int64 `json:"last_run_ms"`
|
||||
Evictions int64 `json:"evictions"`
|
||||
OldestQueued time.Time `json:"oldest_queued,omitzero"`
|
||||
}
|
||||
|
||||
// Passthrough is the Schritt 2 stand-in: no lock, no queue. Every job runs
|
||||
// concurrently. It exists so the server package can be written against the
|
||||
// final interface from day one.
|
||||
type Passthrough struct {
|
||||
reg *registry.Registry
|
||||
|
||||
mu sync.Mutex
|
||||
inFlight int
|
||||
total int64
|
||||
lastRunMS int64
|
||||
}
|
||||
|
||||
// NewPassthrough returns a Scheduler that runs every job immediately.
|
||||
func NewPassthrough(reg *registry.Registry) *Passthrough {
|
||||
return &Passthrough{reg: reg}
|
||||
}
|
||||
|
||||
// Run executes fn straight away, only tracking in-flight count for stats.
|
||||
func (p *Passthrough) Run(ctx context.Context, consumer string, fn Job) error {
|
||||
release := p.reg.MarkActive(consumer)
|
||||
defer release()
|
||||
|
||||
p.mu.Lock()
|
||||
p.inFlight++
|
||||
p.total++
|
||||
p.mu.Unlock()
|
||||
defer func() {
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
err := fn(ctx)
|
||||
elapsed := time.Since(start).Milliseconds()
|
||||
p.mu.Lock()
|
||||
p.lastRunMS = elapsed
|
||||
p.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Stats returns current passthrough statistics.
|
||||
func (p *Passthrough) Stats() Stats {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return Stats{
|
||||
InFlight: p.inFlight,
|
||||
TotalJobs: p.total,
|
||||
LastRunMS: p.lastRunMS,
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time interface guard.
|
||||
var _ Scheduler = (*Passthrough)(nil)
|
||||
|
||||
// Ensure config package is imported (used by later Schritte that read
|
||||
// per-consumer max_concurrency and vram_resident_mib).
|
||||
var _ = config.KindTTS
|
||||
379
internal/server/server.go
Normal file
379
internal/server/server.go
Normal file
@@ -0,0 +1,379 @@
|
||||
// Package server is the HTTP façade of mGPUmanager.
|
||||
//
|
||||
// It exposes:
|
||||
// - POST /v1/tts, /v1/stt, /v1/llm, /v1/image — pass-through proxy to the
|
||||
// consumer named in config.Routing[kind].
|
||||
// - GET /audio/* — proxy to config.AudioProxy (mvoice's audio directory).
|
||||
// - GET /v1/status — live snapshot of consumers + GPU + scheduler.
|
||||
// - GET /healthz — broker liveness (200 if process is up).
|
||||
//
|
||||
// Every proxy call goes through the Scheduler so that, in Schritt 4 and 5,
|
||||
// queueing and eviction can be added without touching server.go.
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/gpu"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/scheduler"
|
||||
)
|
||||
|
||||
// Server bundles the HTTP handlers + dependencies.
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
reg *registry.Registry
|
||||
gpu *gpu.Poller
|
||||
sched scheduler.Scheduler
|
||||
client *http.Client
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New builds a Server. Caller owns the lifecycle of reg/gpu/sched.
|
||||
func New(cfg *config.Config, reg *registry.Registry, gpuPoller *gpu.Poller, sched scheduler.Scheduler, logger *slog.Logger) *Server {
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
reg: reg,
|
||||
gpu: gpuPoller,
|
||||
sched: sched,
|
||||
client: &http.Client{Timeout: 120 * time.Second}, // TTS can take 5-10s; image gen up to 60s
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns the root mux. Caller wraps it in http.Server.
|
||||
func (s *Server) Handler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("POST /v1/tts", s.handleEndpoint(config.KindTTS))
|
||||
mux.HandleFunc("POST /v1/stt", s.handleEndpoint(config.KindSTT))
|
||||
mux.HandleFunc("POST /v1/llm", s.handleEndpoint(config.KindLLM))
|
||||
mux.HandleFunc("POST /v1/image", s.handleEndpoint(config.KindImage))
|
||||
|
||||
if s.cfg.AudioProxy != "" && s.cfg.AudioPathPrefix != "" {
|
||||
mux.HandleFunc("GET "+s.cfg.AudioPathPrefix, s.handleAudio)
|
||||
}
|
||||
mux.HandleFunc("GET /v1/status", s.handleStatus)
|
||||
mux.HandleFunc("GET /healthz", s.handleHealthz)
|
||||
mux.HandleFunc("GET /", s.handleRoot)
|
||||
|
||||
return logMiddleware(s.logger, mux)
|
||||
}
|
||||
|
||||
// ───── error envelope ─────────────────────────────────────────────────────
|
||||
|
||||
// errorBody is the broker's structured error envelope. Every non-2xx response
|
||||
// from mGPUmanager itself uses this shape. (Pass-through 4xx/5xx from
|
||||
// consumers are forwarded verbatim so callers see the original payload.)
|
||||
type errorBody struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Consumer string `json:"consumer,omitempty"`
|
||||
Retryable bool `json:"retryable"`
|
||||
}
|
||||
|
||||
func writeErr(w http.ResponseWriter, status int, code, msg, consumer string, retryable bool) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(errorBody{
|
||||
Error: code,
|
||||
Message: msg,
|
||||
Consumer: consumer,
|
||||
Retryable: retryable,
|
||||
})
|
||||
}
|
||||
|
||||
// ───── endpoint proxy ─────────────────────────────────────────────────────
|
||||
|
||||
// handleEndpoint returns the http.HandlerFunc for a /v1/<kind> endpoint.
|
||||
func (s *Server) handleEndpoint(kind config.EndpointKind) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
consName, cons := s.cfg.ConsumerForKind(kind)
|
||||
if cons == nil {
|
||||
writeErr(w, http.StatusNotImplemented, "no_consumer",
|
||||
fmt.Sprintf("no consumer routes %s", kind), "", false)
|
||||
return
|
||||
}
|
||||
route, ok := cons.Paths[kind]
|
||||
if !ok {
|
||||
writeErr(w, http.StatusNotImplemented, "no_consumer",
|
||||
fmt.Sprintf("consumer %s lacks paths.%s", consName, kind), consName, false)
|
||||
return
|
||||
}
|
||||
|
||||
// Refuse fast if the consumer is unhealthy (last probe failed) — keeps
|
||||
// Felix-Banholzer-style silent-fallback impossible.
|
||||
st := s.reg.Get(consName)
|
||||
if !st.Healthy && !st.LastProbe.IsZero() {
|
||||
writeErr(w, http.StatusServiceUnavailable, "consumer_unreachable",
|
||||
fmt.Sprintf("consumer %s last probe failed: %s", consName, st.LastError),
|
||||
consName, true)
|
||||
return
|
||||
}
|
||||
|
||||
err := s.sched.Run(r.Context(), consName, func(ctx context.Context) error {
|
||||
return s.proxyRequest(ctx, w, r, cons, route, consName)
|
||||
})
|
||||
if err != nil && !responseStarted(w) {
|
||||
writeErr(w, http.StatusInternalServerError, "scheduler_error",
|
||||
err.Error(), consName, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyRequest forwards the inbound HTTP request to a consumer route and
|
||||
// streams the response back. Errors before the consumer responds are surfaced
|
||||
// as the broker's structured error envelope; once the consumer has begun
|
||||
// responding we stream its bytes through unchanged.
|
||||
func (s *Server) proxyRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, cons *config.Consumer, route config.Route, consumer string) error {
|
||||
target, err := url.Parse(cons.URL)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "bad_consumer_url",
|
||||
err.Error(), consumer, false)
|
||||
return nil
|
||||
}
|
||||
target.Path = route.Path
|
||||
// Forward inbound query string verbatim.
|
||||
target.RawQuery = r.URL.RawQuery
|
||||
|
||||
method := route.Method
|
||||
if method == "" {
|
||||
method = r.Method
|
||||
}
|
||||
|
||||
upstream, err := http.NewRequestWithContext(ctx, method, target.String(), r.Body)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "bad_request",
|
||||
err.Error(), consumer, false)
|
||||
return nil
|
||||
}
|
||||
// Copy through Content-Type, Content-Length and Accept (don't carry Host).
|
||||
for _, h := range []string{"Content-Type", "Content-Length", "Accept", "Accept-Encoding"} {
|
||||
if v := r.Header.Get(h); v != "" {
|
||||
upstream.Header.Set(h, v)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := s.client.Do(upstream)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadGateway, "consumer_unreachable",
|
||||
fmt.Sprintf("upstream %s: %v", target.Host, err), consumer, true)
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Stream response.
|
||||
for k, vs := range resp.Header {
|
||||
if strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Transfer-Encoding") {
|
||||
continue
|
||||
}
|
||||
for _, v := range vs {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ───── audio proxy ────────────────────────────────────────────────────────
|
||||
|
||||
// handleAudio forwards GET /audio/<file> to the audio_proxy consumer (mvoice).
|
||||
// wa.sh fetches the rendered .wav via this path after /v1/tts returns its URL.
|
||||
func (s *Server) handleAudio(w http.ResponseWriter, r *http.Request) {
|
||||
if s.cfg.AudioProxy == "" {
|
||||
writeErr(w, http.StatusNotFound, "no_audio_proxy",
|
||||
"audio_proxy is not configured", "", false)
|
||||
return
|
||||
}
|
||||
cons, ok := s.cfg.Consumers[s.cfg.AudioProxy]
|
||||
if !ok {
|
||||
writeErr(w, http.StatusInternalServerError, "no_audio_proxy",
|
||||
"audio_proxy points at unknown consumer", s.cfg.AudioProxy, false)
|
||||
return
|
||||
}
|
||||
target, err := url.Parse(cons.URL)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "bad_consumer_url",
|
||||
err.Error(), s.cfg.AudioProxy, false)
|
||||
return
|
||||
}
|
||||
target.Path = r.URL.Path
|
||||
target.RawQuery = r.URL.RawQuery
|
||||
|
||||
upstream, err := http.NewRequestWithContext(r.Context(), http.MethodGet, target.String(), nil)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "bad_request",
|
||||
err.Error(), s.cfg.AudioProxy, false)
|
||||
return
|
||||
}
|
||||
resp, err := s.client.Do(upstream)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadGateway, "consumer_unreachable",
|
||||
fmt.Sprintf("upstream %s: %v", target.Host, err), s.cfg.AudioProxy, true)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
for k, vs := range resp.Header {
|
||||
for _, v := range vs {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
// ───── status ─────────────────────────────────────────────────────────────
|
||||
|
||||
type statusResponse struct {
|
||||
Listen string `json:"listen"`
|
||||
Time time.Time `json:"time"`
|
||||
GPU statusGPU `json:"gpu"`
|
||||
Routing map[config.EndpointKind]string `json:"routing"`
|
||||
Consumers []statusConsumer `json:"consumers"`
|
||||
Scheduler scheduler.Stats `json:"scheduler"`
|
||||
}
|
||||
|
||||
type statusGPU struct {
|
||||
TotalMiB int `json:"total_mib"`
|
||||
UsedMiB int `json:"used_mib"`
|
||||
FreeMiB int `json:"free_mib"`
|
||||
ReservedMiB int `json:"reserved_mib"`
|
||||
LastSample time.Time `json:"last_sample"`
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
type statusConsumer struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Healthy bool `json:"healthy"`
|
||||
Loaded bool `json:"loaded"`
|
||||
GPUResidentMiB int `json:"gpu_resident_mib"`
|
||||
VRAMBudgetMiB int `json:"vram_budget_mib"`
|
||||
Active int `json:"active"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
LastUsed time.Time `json:"last_used,omitzero"`
|
||||
LastProbe time.Time `json:"last_probe,omitzero"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
CanCoexistWith []string `json:"can_coexist_with"`
|
||||
}
|
||||
|
||||
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
sample := s.gpu.Last()
|
||||
snap := s.reg.Snapshot()
|
||||
|
||||
resp := statusResponse{
|
||||
Listen: s.cfg.Listen,
|
||||
Time: time.Now(),
|
||||
Routing: s.cfg.Routing,
|
||||
GPU: statusGPU{
|
||||
TotalMiB: s.cfg.GPU.TotalMiB,
|
||||
UsedMiB: sample.UsedMiB,
|
||||
FreeMiB: sample.FreeMiB,
|
||||
ReservedMiB: s.cfg.GPU.ReservedMiB,
|
||||
LastSample: sample.At,
|
||||
Err: sample.Err,
|
||||
},
|
||||
Scheduler: s.sched.Stats(),
|
||||
}
|
||||
if resp.GPU.TotalMiB == 0 && sample.TotalMiB > 0 {
|
||||
resp.GPU.TotalMiB = sample.TotalMiB
|
||||
}
|
||||
|
||||
// Stable ordering by config-declared name.
|
||||
names := make([]string, 0, len(s.cfg.Consumers))
|
||||
for n := range s.cfg.Consumers {
|
||||
names = append(names, n)
|
||||
}
|
||||
sortStrings(names)
|
||||
for _, n := range names {
|
||||
cons := s.cfg.Consumers[n]
|
||||
st := snap[n]
|
||||
resp.Consumers = append(resp.Consumers, statusConsumer{
|
||||
Name: n,
|
||||
URL: cons.URL,
|
||||
Healthy: st.Healthy,
|
||||
Loaded: st.Loaded,
|
||||
GPUResidentMiB: st.GPUResidentMiB,
|
||||
VRAMBudgetMiB: cons.VRAMResidentMiB,
|
||||
Active: st.Active,
|
||||
TotalRequests: st.TotalRequests,
|
||||
LastUsed: st.LastUsed,
|
||||
LastProbe: st.LastProbe,
|
||||
LastError: st.LastError,
|
||||
Priority: cons.Priority,
|
||||
CanCoexistWith: cons.CanCoexistWith,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *Server) handleHealthz(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||
}
|
||||
|
||||
func (s *Server) handleRoot(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = io.Copy(w, bytes.NewReader([]byte(
|
||||
"mGPUmanager — see GET /v1/status for live state, POST /v1/{tts,stt,llm,image} for inference\n",
|
||||
)))
|
||||
}
|
||||
|
||||
// ───── helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
// responseStarted is a coarse heuristic: once we've written headers, we can't
|
||||
// switch to the error envelope. The proxy path writes headers only inside
|
||||
// proxyRequest, which catches its own errors before that point.
|
||||
func responseStarted(_ http.ResponseWriter) bool { return false }
|
||||
|
||||
// sortStrings: avoid pulling in "sort" everywhere this file uses ordering.
|
||||
func sortStrings(s []string) {
|
||||
for i := 1; i < len(s); i++ {
|
||||
for j := i; j > 0 && s[j-1] > s[j]; j-- {
|
||||
s[j-1], s[j] = s[j], s[j-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// logMiddleware emits one structured request log per call.
|
||||
func logMiddleware(logger *slog.Logger, next http.Handler) http.Handler {
|
||||
if logger == nil {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
lw := &statusCapture{ResponseWriter: w, code: 200}
|
||||
next.ServeHTTP(lw, r)
|
||||
logger.Info("http",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", lw.code,
|
||||
"ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
type statusCapture struct {
|
||||
http.ResponseWriter
|
||||
code int
|
||||
}
|
||||
|
||||
func (s *statusCapture) WriteHeader(code int) {
|
||||
s.code = code
|
||||
s.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
237
internal/server/server_test.go
Normal file
237
internal/server/server_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/gpu"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/scheduler"
|
||||
)
|
||||
|
||||
// silentLogger discards everything; keeps test output clean.
|
||||
func silentLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
}
|
||||
|
||||
// fakeMVoice spins up a httptest.Server that mimics the relevant mvoice
|
||||
// endpoints just enough to exercise the broker's proxy behaviour.
|
||||
func fakeMVoice(t *testing.T) (*httptest.Server, *atomic.Int32) {
|
||||
t.Helper()
|
||||
calls := &atomic.Int32{}
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"ready","loaded":true,"gpu_resident_mib":2800}`))
|
||||
})
|
||||
mux.HandleFunc("POST /api/synthesize", func(w http.ResponseWriter, r *http.Request) {
|
||||
calls.Add(1)
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"audio_url":"/api/audio/test.wav","payload_echo":"` +
|
||||
strings.ReplaceAll(string(body), `"`, `\"`) + `"}`))
|
||||
})
|
||||
mux.HandleFunc("GET /api/audio/test.wav", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "audio/wav")
|
||||
_, _ = w.Write([]byte("RIFF....fake-wav-bytes"))
|
||||
})
|
||||
return httptest.NewServer(mux), calls
|
||||
}
|
||||
|
||||
func buildHarness(t *testing.T, upstream *httptest.Server) (*Server, *registry.Registry, *gpu.Poller) {
|
||||
t.Helper()
|
||||
cfg := &config.Config{
|
||||
Listen: "127.0.0.1:0",
|
||||
GPU: config.GPU{TotalMiB: 16376, ReservedMiB: 1024, PollIntervalSeconds: 2},
|
||||
Routing: map[config.EndpointKind]string{
|
||||
config.KindTTS: "mvoice",
|
||||
},
|
||||
AudioProxy: "mvoice",
|
||||
AudioPathPrefix: "/api/audio/",
|
||||
Consumers: map[string]*config.Consumer{
|
||||
"mvoice": {
|
||||
URL: upstream.URL,
|
||||
Health: config.Route{Method: "GET", Path: "/api/health"},
|
||||
Paths: map[config.EndpointKind]config.Route{
|
||||
config.KindTTS: {Method: "POST", Path: "/api/synthesize"},
|
||||
},
|
||||
VRAMResidentMiB: 2800,
|
||||
MaxConcurrency: 1,
|
||||
Priority: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
reg := registry.New(cfg, silentLogger())
|
||||
poller := gpu.NewPoller(time.Second, silentLogger())
|
||||
sched := scheduler.NewPassthrough(reg)
|
||||
srv := New(cfg, reg, poller, sched, silentLogger())
|
||||
return srv, reg, poller
|
||||
}
|
||||
|
||||
func TestProxyTTSForwardsBodyAndReturnsConsumerResponse(t *testing.T) {
|
||||
upstream, calls := fakeMVoice(t)
|
||||
defer upstream.Close()
|
||||
|
||||
srv, reg, _ := buildHarness(t, upstream)
|
||||
// Force-mark mvoice healthy so the gating in handleEndpoint passes.
|
||||
probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")
|
||||
|
||||
ts := httptest.NewServer(srv.Handler())
|
||||
defer ts.Close()
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("text", "Hallo")
|
||||
resp, err := http.Post(ts.URL+"/v1/tts", "application/x-www-form-urlencoded",
|
||||
strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Errorf("upstream calls = %d, want 1", got)
|
||||
}
|
||||
var payload struct {
|
||||
AudioURL string `json:"audio_url"`
|
||||
PayloadEcho string `json:"payload_echo"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(payload.PayloadEcho, "text=Hallo") {
|
||||
t.Errorf("upstream did not receive body: %q", payload.PayloadEcho)
|
||||
}
|
||||
if payload.AudioURL != "/api/audio/test.wav" {
|
||||
t.Errorf("AudioURL = %q", payload.AudioURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAudioProxyForwardsBytes(t *testing.T) {
|
||||
upstream, _ := fakeMVoice(t)
|
||||
defer upstream.Close()
|
||||
|
||||
srv, _, _ := buildHarness(t, upstream)
|
||||
|
||||
ts := httptest.NewServer(srv.Handler())
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := http.Get(ts.URL + "/api/audio/test.wav")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("status %d", resp.StatusCode)
|
||||
}
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if !strings.HasPrefix(string(body), "RIFF") {
|
||||
t.Errorf("audio body did not start with RIFF: %q", body[:8])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnhealthyConsumerReturns503(t *testing.T) {
|
||||
// Upstream that always 500s health probe.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "boom", http.StatusInternalServerError)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
srv, reg, _ := buildHarness(t, upstream)
|
||||
probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")
|
||||
|
||||
ts := httptest.NewServer(srv.Handler())
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := http.Post(ts.URL+"/v1/tts", "application/json", strings.NewReader(`{}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 503 {
|
||||
t.Fatalf("status = %d, want 503", resp.StatusCode)
|
||||
}
|
||||
var env errorBody
|
||||
if err := json.NewDecoder(resp.Body).Decode(&env); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if env.Error != "consumer_unreachable" {
|
||||
t.Errorf("error code = %q", env.Error)
|
||||
}
|
||||
if !env.Retryable {
|
||||
t.Errorf("retryable should be true for unreachable consumer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusReturnsConsumers(t *testing.T) {
|
||||
upstream, _ := fakeMVoice(t)
|
||||
defer upstream.Close()
|
||||
|
||||
srv, reg, _ := buildHarness(t, upstream)
|
||||
probeOnce(t, reg, upstream.URL+"/api/health", "mvoice")
|
||||
|
||||
ts := httptest.NewServer(srv.Handler())
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := http.Get(ts.URL + "/v1/status")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
var sr statusResponse
|
||||
if err := json.Unmarshal(body, &sr); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(sr.Consumers) != 1 {
|
||||
t.Fatalf("consumers = %d", len(sr.Consumers))
|
||||
}
|
||||
if sr.Consumers[0].Name != "mvoice" {
|
||||
t.Errorf("name = %q", sr.Consumers[0].Name)
|
||||
}
|
||||
if !sr.Consumers[0].Healthy {
|
||||
t.Errorf("expected healthy=true after successful probe")
|
||||
}
|
||||
if sr.GPU.TotalMiB != 16376 {
|
||||
t.Errorf("GPU.TotalMiB = %d", sr.GPU.TotalMiB)
|
||||
}
|
||||
}
|
||||
|
||||
// probeOnce drives the registry to record one successful (or failed) health
|
||||
// probe synchronously so tests don't have to wait for the 5s loop.
|
||||
func probeOnce(t *testing.T, reg *registry.Registry, healthURL, name string) {
|
||||
t.Helper()
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", healthURL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
reg.RecordProbeForTest(name, false, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
||||
ok := resp.StatusCode < 400
|
||||
errMsg := ""
|
||||
if !ok {
|
||||
errMsg = "status " + strings.TrimSpace(resp.Status)
|
||||
}
|
||||
reg.RecordProbeForTest(name, ok, errMsg, body)
|
||||
}
|
||||
30
systemd/mgpumanager.service
Normal file
30
systemd/mgpumanager.service
Normal file
@@ -0,0 +1,30 @@
|
||||
[Unit]
|
||||
Description=mGPUmanager — GPU-Inference-Control-Plane for mRock
|
||||
Documentation=https://mgit.msbls.de/m/mGPUmanager
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=m
|
||||
Group=m
|
||||
WorkingDirectory=/home/m/dev/mGPUmanager
|
||||
ExecStart=/home/m/dev/mGPUmanager/bin/mgpumanager \
|
||||
--config /home/m/dev/mGPUmanager/config/consumers.yaml \
|
||||
--log-level info
|
||||
Restart=on-failure
|
||||
RestartSec=3
|
||||
TimeoutStopSec=10
|
||||
|
||||
# Hardening — broker has no need for elevated capabilities.
|
||||
NoNewPrivileges=true
|
||||
PrivateTmp=true
|
||||
ProtectSystem=strict
|
||||
ProtectHome=read-only
|
||||
ReadWritePaths=/home/m/dev/mGPUmanager
|
||||
|
||||
# The broker only proxies; nvidia-smi is the only GPU-touching call.
|
||||
PrivateDevices=false
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
Reference in New Issue
Block a user