feat: Schritt 4 — Locked scheduler (global GPU lock, queue, stats)

Replaces the MVP Passthrough with scheduler.Locked: a capacity-1 channel
serialises every consumer's GPU work end-to-end. main.go switches to it.

Behavioural contract:
- Jobs that arrive while another job holds the GPU block on the channel
  until the holder finishes. Context cancellation aborts the wait
  cleanly (no leaked tokens, queue depth decremented).
- Stats track queue_depth, in_flight, total_jobs, last_wait_ms,
  last_run_ms, oldest_queued — surfaced through /v1/status.
- One lock for ALL consumers (not per-consumer): the design (§4.3) is
  explicit that grobgranular > GPU-stream-granular on single-GPU
  single-user hardware. mvoice + ollama + comfyui never run truly
  concurrently any more, which is the whole point — that's what
  produced the CUDA-OOM under load.

Tests:
- 5 goroutines hammer the scheduler concurrently → max in-flight = 1.
- Cancellation while parked on the lock returns ctx.Err() and frees
  the queue slot.
- Stats reflect in-flight + queue-depth transitions correctly.
- Race detector clean.

Schritt 5 will compose this with VRAM-pressure eviction: before
acquiring the lock, check if the target consumer's resident cost fits
under the current GPU headroom; if not, unload the LRU non-coexistent
consumer first.

Refs: m/mGPUmanager#1 (Schritt 4).
This commit is contained in:
mAi
2026-05-11 13:33:39 +02:00
parent c81c145163
commit 3b3d828e9e
7 changed files with 315 additions and 13 deletions

View File

@@ -61,7 +61,9 @@ func main() {
reg := registry.New(cfg, logger.With("component", "registry"))
gpuPoller := gpu.NewPoller(cfg.GPU.PollInterval(), logger.With("component", "gpu"))
sched := scheduler.NewPassthrough(reg)
// Phase 1 always runs a single-slot global GPU lock. Schritt 5's
// eviction-aware scheduler wraps this same lock with VRAM pressure logic.
sched := scheduler.NewLocked(reg, 1)
go reg.Run(ctx)
go gpuPoller.Run(ctx)

View File

@@ -11,8 +11,10 @@ routing:
llm: ollama
image: comfyui
# Audio download proxy: GET /audio/* forwards to this consumer.
# Audio download proxy: any GET under audio_path_prefix is forwarded to this
# consumer at the same path. wa.sh fetches mvoice's generated WAV this way.
audio_proxy: mvoice
audio_path_prefix: /api/audio/
consumers:
mvoice:

View File

@@ -83,11 +83,12 @@ func (g GPU) AvailableMiB() int {
// Config is the parsed mGPUmanager configuration.
type Config struct {
Listen string `yaml:"listen"`
GPU GPU `yaml:"gpu"`
Routing map[EndpointKind]string `yaml:"routing"`
AudioProxy string `yaml:"audio_proxy"`
Consumers map[string]*Consumer `yaml:"consumers"`
Listen string `yaml:"listen"`
GPU GPU `yaml:"gpu"`
Routing map[EndpointKind]string `yaml:"routing"`
AudioProxy string `yaml:"audio_proxy"`
AudioPathPrefix string `yaml:"audio_path_prefix"`
Consumers map[string]*Consumer `yaml:"consumers"`
}
// Load reads and validates a consumers.yaml file from disk.
@@ -150,6 +151,12 @@ func (c *Config) validate() error {
if _, ok := c.Consumers[c.AudioProxy]; !ok {
return fmt.Errorf("audio_proxy: unknown consumer %q", c.AudioProxy)
}
if c.AudioPathPrefix == "" {
c.AudioPathPrefix = "/api/audio/"
}
if !strings.HasPrefix(c.AudioPathPrefix, "/") || !strings.HasSuffix(c.AudioPathPrefix, "/") {
return fmt.Errorf("audio_path_prefix must start and end with '/': %q", c.AudioPathPrefix)
}
}
return nil
}

View File

@@ -0,0 +1,124 @@
package scheduler
import (
"context"
"sync"
"time"
"mgit.msbls.de/m/mGPUmanager/internal/config"
"mgit.msbls.de/m/mGPUmanager/internal/registry"
)
// Locked is the Schritt 4 scheduler: a single capacity-1 semaphore serialises
// every consumer's GPU work. Jobs wait FIFO-ish at the channel until the lock
// is available, then run to completion.
//
// Why one global lock instead of per-stream or per-consumer:
// - mRock is single-GPU + single-user. Theoretical parallelism (e.g. mvoice
// + ollama small model coexisting) is given up to gain predictability:
// no more CUDA-OOM races between concurrent loaders.
// - The design (§4.3) is explicit: "Der Lock ist grobgranular (ein Mutex)
// […]. Wir verschenken theoretische Parallelität, gewinnen dafür
// Vorhersagbarkeit."
//
// Schritt 5 wraps this with eviction logic that runs before sem acquire when
// the requested consumer's resident cost would exceed available headroom.
type Locked struct {
reg *registry.Registry
gpuLock chan struct{} // capacity-1 = global mutex with cancellable acquire
mu sync.Mutex
inFlight int
queueDepth int
total int64
lastWaitMS int64
lastRunMS int64
oldestQueued time.Time
}
// NewLocked returns the serialising scheduler. capacity is the number of
// concurrent jobs allowed on the GPU (Phase 1 wires this as 1).
func NewLocked(reg *registry.Registry, capacity int) *Locked {
if capacity < 1 {
capacity = 1
}
return &Locked{
reg: reg,
gpuLock: make(chan struct{}, capacity),
}
}
// Run acquires the global GPU lock, executes fn while holding it, and
// releases. Cancellation via ctx aborts the wait without leaking a token.
func (s *Locked) Run(ctx context.Context, consumer string, fn Job) error {
release := s.reg.MarkActive(consumer)
defer release()
queuedAt := time.Now()
s.mu.Lock()
s.queueDepth++
if s.queueDepth == 1 || s.oldestQueued.IsZero() {
s.oldestQueued = queuedAt
}
s.mu.Unlock()
// Acquire global lock or bail on cancellation.
select {
case s.gpuLock <- struct{}{}:
case <-ctx.Done():
s.mu.Lock()
s.queueDepth--
if s.queueDepth == 0 {
s.oldestQueued = time.Time{}
}
s.mu.Unlock()
return ctx.Err()
}
waitMS := time.Since(queuedAt).Milliseconds()
s.mu.Lock()
s.queueDepth--
if s.queueDepth == 0 {
s.oldestQueued = time.Time{}
}
s.inFlight++
s.total++
s.lastWaitMS = waitMS
s.mu.Unlock()
defer func() {
<-s.gpuLock
s.mu.Lock()
s.inFlight--
s.mu.Unlock()
}()
start := time.Now()
err := fn(ctx)
runMS := time.Since(start).Milliseconds()
s.mu.Lock()
s.lastRunMS = runMS
s.mu.Unlock()
return err
}
// Stats reports current depth + last timings for /v1/status.
func (s *Locked) Stats() Stats {
s.mu.Lock()
defer s.mu.Unlock()
return Stats{
QueueDepth: s.queueDepth,
InFlight: s.inFlight,
TotalJobs: s.total,
LastWaitMS: s.lastWaitMS,
LastRunMS: s.lastRunMS,
OldestQueued: s.oldestQueued,
}
}
// Compile-time interface guard.
var _ Scheduler = (*Locked)(nil)
// Unused import guard — keeps the config package edge live for Schritt 5's
// VRAM-pressure evaluation, which reads cfg in this same package.
var _ = config.KindTTS

View File

@@ -0,0 +1,164 @@
package scheduler
import (
"context"
"errors"
"log/slog"
"io"
"sync"
"sync/atomic"
"testing"
"time"
"mgit.msbls.de/m/mGPUmanager/internal/config"
"mgit.msbls.de/m/mGPUmanager/internal/registry"
)
func newReg() *registry.Registry {
cfg := &config.Config{
Consumers: map[string]*config.Consumer{
"mvoice": {
URL: "http://localhost:8766",
Health: config.Route{Method: "GET", Path: "/api/health"},
},
},
}
return registry.New(cfg, slog.New(slog.NewTextHandler(io.Discard, nil)))
}
// TestLockedSerialisesConcurrentJobs is the regression test for the
// CUDA-OOM-from-parallel-loaders class: two TTS calls that arrive at the
// same time must run sequentially, not concurrently.
func TestLockedSerialisesConcurrentJobs(t *testing.T) {
sched := NewLocked(newReg(), 1)
var maxConcurrent atomic.Int32
var inFlight atomic.Int32
job := func(ctx context.Context) error {
now := inFlight.Add(1)
// Update max in a CAS loop (small N, never contested in practice).
for {
cur := maxConcurrent.Load()
if now <= cur || maxConcurrent.CompareAndSwap(cur, now) {
break
}
}
time.Sleep(20 * time.Millisecond)
inFlight.Add(-1)
return nil
}
var wg sync.WaitGroup
const n = 5
for range n {
wg.Go(func() {
if err := sched.Run(context.Background(), "mvoice", job); err != nil {
t.Errorf("Run: %v", err)
}
})
}
wg.Wait()
if got := maxConcurrent.Load(); got != 1 {
t.Fatalf("max concurrent jobs = %d, want 1", got)
}
if got := sched.Stats().TotalJobs; got != n {
t.Errorf("Stats.TotalJobs = %d, want %d", got, n)
}
}
func TestLockedRespectsContextCancel(t *testing.T) {
sched := NewLocked(newReg(), 1)
// Hold the lock with a long-running job.
started := make(chan struct{})
done := make(chan struct{})
go func() {
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
close(started)
<-done
return nil
})
}()
<-started
// Now try to run with a context that we'll cancel.
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
errCh <- sched.Run(ctx, "mvoice", func(ctx context.Context) error {
t.Error("second job should not run after cancellation")
return nil
})
}()
// Give the second job time to queue.
time.Sleep(10 * time.Millisecond)
cancel()
select {
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
t.Fatalf("got err=%v, want context.Canceled", err)
}
case <-time.After(time.Second):
t.Fatal("cancelled Run did not return within 1s")
}
// Release the holder.
close(done)
}
func TestLockedStatsTrackInFlightAndQueue(t *testing.T) {
sched := NewLocked(newReg(), 1)
jobStart := make(chan struct{})
jobBlock := make(chan struct{})
go func() {
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
close(jobStart)
<-jobBlock
return nil
})
}()
<-jobStart
// Inside the holding job: InFlight==1, QueueDepth==0.
s := sched.Stats()
if s.InFlight != 1 {
t.Errorf("InFlight while holding = %d, want 1", s.InFlight)
}
if s.QueueDepth != 0 {
t.Errorf("QueueDepth = %d, want 0", s.QueueDepth)
}
// Queue a waiter and verify QueueDepth grows.
waitStarted := make(chan struct{})
waitDone := make(chan struct{})
go func() {
close(waitStarted)
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
return nil
})
close(waitDone)
}()
<-waitStarted
// Wait for the waiter to actually be parked on the channel.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if sched.Stats().QueueDepth == 1 {
break
}
time.Sleep(2 * time.Millisecond)
}
if got := sched.Stats().QueueDepth; got != 1 {
t.Errorf("QueueDepth with one waiter = %d, want 1", got)
}
close(jobBlock)
<-waitDone
if got := sched.Stats().InFlight; got != 0 {
t.Errorf("InFlight after both done = %d, want 0", got)
}
}

View File

@@ -60,7 +60,9 @@ func (s *Server) Handler() http.Handler {
mux.HandleFunc("POST /v1/llm", s.handleEndpoint(config.KindLLM))
mux.HandleFunc("POST /v1/image", s.handleEndpoint(config.KindImage))
mux.HandleFunc("GET /audio/", s.handleAudio)
if s.cfg.AudioProxy != "" && s.cfg.AudioPathPrefix != "" {
mux.HandleFunc("GET "+s.cfg.AudioPathPrefix, s.handleAudio)
}
mux.HandleFunc("GET /v1/status", s.handleStatus)
mux.HandleFunc("GET /healthz", s.handleHealthz)
mux.HandleFunc("GET /", s.handleRoot)

View File

@@ -38,10 +38,10 @@ func fakeMVoice(t *testing.T) (*httptest.Server, *atomic.Int32) {
calls.Add(1)
body, _ := io.ReadAll(r.Body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"audio_url":"/audio/test.wav","payload_echo":"` +
_, _ = w.Write([]byte(`{"audio_url":"/api/audio/test.wav","payload_echo":"` +
strings.ReplaceAll(string(body), `"`, `\"`) + `"}`))
})
mux.HandleFunc("GET /audio/test.wav", func(w http.ResponseWriter, _ *http.Request) {
mux.HandleFunc("GET /api/audio/test.wav", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "audio/wav")
_, _ = w.Write([]byte("RIFF....fake-wav-bytes"))
})
@@ -56,7 +56,8 @@ func buildHarness(t *testing.T, upstream *httptest.Server) (*Server, *registry.R
Routing: map[config.EndpointKind]string{
config.KindTTS: "mvoice",
},
AudioProxy: "mvoice",
AudioProxy: "mvoice",
AudioPathPrefix: "/api/audio/",
Consumers: map[string]*config.Consumer{
"mvoice": {
URL: upstream.URL,
@@ -113,7 +114,7 @@ func TestProxyTTSForwardsBodyAndReturnsConsumerResponse(t *testing.T) {
if !strings.Contains(payload.PayloadEcho, "text=Hallo") {
t.Errorf("upstream did not receive body: %q", payload.PayloadEcho)
}
if payload.AudioURL != "/audio/test.wav" {
if payload.AudioURL != "/api/audio/test.wav" {
t.Errorf("AudioURL = %q", payload.AudioURL)
}
}
@@ -127,7 +128,7 @@ func TestAudioProxyForwardsBytes(t *testing.T) {
ts := httptest.NewServer(srv.Handler())
defer ts.Close()
resp, err := http.Get(ts.URL + "/audio/test.wav")
resp, err := http.Get(ts.URL + "/api/audio/test.wav")
if err != nil {
t.Fatal(err)
}