feat: Schritt 4 — Locked scheduler (global GPU lock, queue, stats)
Replaces the MVP Passthrough with scheduler.Locked: a capacity-1 channel serialises every consumer's GPU work end-to-end. main.go switches to it.

Behavioural contract:
- Jobs that arrive while another job holds the GPU block on the channel until the holder finishes. Context cancellation aborts the wait cleanly (no leaked tokens, queue depth decremented).
- Stats track queue_depth, in_flight, total_jobs, last_wait_ms, last_run_ms, oldest_queued — surfaced through /v1/status.
- One lock for ALL consumers (not per-consumer): the design (§4.3) is explicit that a coarse-grained ("grobgranular") lock beats a GPU-stream-granular one on single-GPU single-user hardware. mvoice + ollama + comfyui never run truly concurrently any more, which is the whole point — concurrent runs are what produced the CUDA-OOM under load.

Tests:
- 5 goroutines hammer the scheduler concurrently → max in-flight = 1.
- Cancellation while parked on the lock returns ctx.Err() and frees the queue slot.
- Stats reflect in-flight + queue-depth transitions correctly.
- Race detector clean.

Schritt 5 will compose this with VRAM-pressure eviction: before acquiring the lock, check whether the target consumer's resident cost fits under the current GPU headroom; if not, unload the LRU non-coexistent consumer first.

Refs: m/mGPUmanager#1 (Schritt 4).
This commit is contained in:
@@ -61,7 +61,9 @@ func main() {
|
||||
|
||||
reg := registry.New(cfg, logger.With("component", "registry"))
|
||||
gpuPoller := gpu.NewPoller(cfg.GPU.PollInterval(), logger.With("component", "gpu"))
|
||||
sched := scheduler.NewPassthrough(reg)
|
||||
// Phase 1 always runs a single-slot global GPU lock. Schritt 5's
|
||||
// eviction-aware scheduler wraps this same lock with VRAM pressure logic.
|
||||
sched := scheduler.NewLocked(reg, 1)
|
||||
|
||||
go reg.Run(ctx)
|
||||
go gpuPoller.Run(ctx)
|
||||
|
||||
@@ -11,8 +11,10 @@ routing:
|
||||
llm: ollama
|
||||
image: comfyui
|
||||
|
||||
# Audio download proxy: GET /audio/* forwards to this consumer.
|
||||
# Audio download proxy: any GET under audio_path_prefix is forwarded to this
|
||||
# consumer at the same path. wa.sh fetches mvoice's generated WAV this way.
|
||||
audio_proxy: mvoice
|
||||
audio_path_prefix: /api/audio/
|
||||
|
||||
consumers:
|
||||
mvoice:
|
||||
|
||||
@@ -83,11 +83,12 @@ func (g GPU) AvailableMiB() int {
|
||||
|
||||
// Config is the parsed mGPUmanager configuration.
|
||||
type Config struct {
|
||||
Listen string `yaml:"listen"`
|
||||
GPU GPU `yaml:"gpu"`
|
||||
Routing map[EndpointKind]string `yaml:"routing"`
|
||||
AudioProxy string `yaml:"audio_proxy"`
|
||||
Consumers map[string]*Consumer `yaml:"consumers"`
|
||||
Listen string `yaml:"listen"`
|
||||
GPU GPU `yaml:"gpu"`
|
||||
Routing map[EndpointKind]string `yaml:"routing"`
|
||||
AudioProxy string `yaml:"audio_proxy"`
|
||||
AudioPathPrefix string `yaml:"audio_path_prefix"`
|
||||
Consumers map[string]*Consumer `yaml:"consumers"`
|
||||
}
|
||||
|
||||
// Load reads and validates a consumers.yaml file from disk.
|
||||
@@ -150,6 +151,12 @@ func (c *Config) validate() error {
|
||||
if _, ok := c.Consumers[c.AudioProxy]; !ok {
|
||||
return fmt.Errorf("audio_proxy: unknown consumer %q", c.AudioProxy)
|
||||
}
|
||||
if c.AudioPathPrefix == "" {
|
||||
c.AudioPathPrefix = "/api/audio/"
|
||||
}
|
||||
if !strings.HasPrefix(c.AudioPathPrefix, "/") || !strings.HasSuffix(c.AudioPathPrefix, "/") {
|
||||
return fmt.Errorf("audio_path_prefix must start and end with '/': %q", c.AudioPathPrefix)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
124
internal/scheduler/locked.go
Normal file
124
internal/scheduler/locked.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
)
|
||||
|
||||
// Locked is the Schritt 4 scheduler: a single capacity-1 semaphore serialises
|
||||
// every consumer's GPU work. Jobs wait FIFO-ish at the channel until the lock
|
||||
// is available, then run to completion.
|
||||
//
|
||||
// Why one global lock instead of per-stream or per-consumer:
|
||||
// - mRock is single-GPU + single-user. Theoretical parallelism (e.g. mvoice
|
||||
// + ollama small model coexisting) is given up to gain predictability:
|
||||
// no more CUDA-OOM races between concurrent loaders.
|
||||
// - The design (§4.3) is explicit: "Der Lock ist grobgranular (ein Mutex)
|
||||
// […]. Wir verschenken theoretische Parallelität, gewinnen dafür
|
||||
// Vorhersagbarkeit."
|
||||
//
|
||||
// Schritt 5 wraps this with eviction logic that runs before sem acquire when
|
||||
// the requested consumer's resident cost would exceed available headroom.
|
||||
type Locked struct {
|
||||
reg *registry.Registry
|
||||
gpuLock chan struct{} // capacity-1 = global mutex with cancellable acquire
|
||||
|
||||
mu sync.Mutex
|
||||
inFlight int
|
||||
queueDepth int
|
||||
total int64
|
||||
lastWaitMS int64
|
||||
lastRunMS int64
|
||||
oldestQueued time.Time
|
||||
}
|
||||
|
||||
// NewLocked returns the serialising scheduler. capacity is the number of
|
||||
// concurrent jobs allowed on the GPU (Phase 1 wires this as 1).
|
||||
func NewLocked(reg *registry.Registry, capacity int) *Locked {
|
||||
if capacity < 1 {
|
||||
capacity = 1
|
||||
}
|
||||
return &Locked{
|
||||
reg: reg,
|
||||
gpuLock: make(chan struct{}, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
// Run acquires the global GPU lock, executes fn while holding it, and
|
||||
// releases. Cancellation via ctx aborts the wait without leaking a token.
|
||||
func (s *Locked) Run(ctx context.Context, consumer string, fn Job) error {
|
||||
release := s.reg.MarkActive(consumer)
|
||||
defer release()
|
||||
|
||||
queuedAt := time.Now()
|
||||
s.mu.Lock()
|
||||
s.queueDepth++
|
||||
if s.queueDepth == 1 || s.oldestQueued.IsZero() {
|
||||
s.oldestQueued = queuedAt
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Acquire global lock or bail on cancellation.
|
||||
select {
|
||||
case s.gpuLock <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
s.mu.Lock()
|
||||
s.queueDepth--
|
||||
if s.queueDepth == 0 {
|
||||
s.oldestQueued = time.Time{}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return ctx.Err()
|
||||
}
|
||||
waitMS := time.Since(queuedAt).Milliseconds()
|
||||
|
||||
s.mu.Lock()
|
||||
s.queueDepth--
|
||||
if s.queueDepth == 0 {
|
||||
s.oldestQueued = time.Time{}
|
||||
}
|
||||
s.inFlight++
|
||||
s.total++
|
||||
s.lastWaitMS = waitMS
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
<-s.gpuLock
|
||||
s.mu.Lock()
|
||||
s.inFlight--
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
err := fn(ctx)
|
||||
runMS := time.Since(start).Milliseconds()
|
||||
s.mu.Lock()
|
||||
s.lastRunMS = runMS
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Stats reports current depth + last timings for /v1/status.
|
||||
func (s *Locked) Stats() Stats {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return Stats{
|
||||
QueueDepth: s.queueDepth,
|
||||
InFlight: s.inFlight,
|
||||
TotalJobs: s.total,
|
||||
LastWaitMS: s.lastWaitMS,
|
||||
LastRunMS: s.lastRunMS,
|
||||
OldestQueued: s.oldestQueued,
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time interface guard.
|
||||
var _ Scheduler = (*Locked)(nil)
|
||||
|
||||
// Unused import guard — keeps the config package edge live for Schritt 5's
|
||||
// VRAM-pressure evaluation, which reads cfg in this same package.
|
||||
var _ = config.KindTTS
|
||||
164
internal/scheduler/locked_test.go
Normal file
164
internal/scheduler/locked_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/config"
|
||||
"mgit.msbls.de/m/mGPUmanager/internal/registry"
|
||||
)
|
||||
|
||||
func newReg() *registry.Registry {
|
||||
cfg := &config.Config{
|
||||
Consumers: map[string]*config.Consumer{
|
||||
"mvoice": {
|
||||
URL: "http://localhost:8766",
|
||||
Health: config.Route{Method: "GET", Path: "/api/health"},
|
||||
},
|
||||
},
|
||||
}
|
||||
return registry.New(cfg, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
}
|
||||
|
||||
// TestLockedSerialisesConcurrentJobs is the regression test for the
|
||||
// CUDA-OOM-from-parallel-loaders class: two TTS calls that arrive at the
|
||||
// same time must run sequentially, not concurrently.
|
||||
func TestLockedSerialisesConcurrentJobs(t *testing.T) {
|
||||
sched := NewLocked(newReg(), 1)
|
||||
|
||||
var maxConcurrent atomic.Int32
|
||||
var inFlight atomic.Int32
|
||||
|
||||
job := func(ctx context.Context) error {
|
||||
now := inFlight.Add(1)
|
||||
// Update max in a CAS loop (small N, never contested in practice).
|
||||
for {
|
||||
cur := maxConcurrent.Load()
|
||||
if now <= cur || maxConcurrent.CompareAndSwap(cur, now) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
inFlight.Add(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const n = 5
|
||||
for range n {
|
||||
wg.Go(func() {
|
||||
if err := sched.Run(context.Background(), "mvoice", job); err != nil {
|
||||
t.Errorf("Run: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if got := maxConcurrent.Load(); got != 1 {
|
||||
t.Fatalf("max concurrent jobs = %d, want 1", got)
|
||||
}
|
||||
if got := sched.Stats().TotalJobs; got != n {
|
||||
t.Errorf("Stats.TotalJobs = %d, want %d", got, n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockedRespectsContextCancel(t *testing.T) {
|
||||
sched := NewLocked(newReg(), 1)
|
||||
|
||||
// Hold the lock with a long-running job.
|
||||
started := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
|
||||
close(started)
|
||||
<-done
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
<-started
|
||||
|
||||
// Now try to run with a context that we'll cancel.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- sched.Run(ctx, "mvoice", func(ctx context.Context) error {
|
||||
t.Error("second job should not run after cancellation")
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
// Give the second job time to queue.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("got err=%v, want context.Canceled", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("cancelled Run did not return within 1s")
|
||||
}
|
||||
|
||||
// Release the holder.
|
||||
close(done)
|
||||
}
|
||||
|
||||
func TestLockedStatsTrackInFlightAndQueue(t *testing.T) {
|
||||
sched := NewLocked(newReg(), 1)
|
||||
|
||||
jobStart := make(chan struct{})
|
||||
jobBlock := make(chan struct{})
|
||||
go func() {
|
||||
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
|
||||
close(jobStart)
|
||||
<-jobBlock
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
<-jobStart
|
||||
|
||||
// Inside the holding job: InFlight==1, QueueDepth==0.
|
||||
s := sched.Stats()
|
||||
if s.InFlight != 1 {
|
||||
t.Errorf("InFlight while holding = %d, want 1", s.InFlight)
|
||||
}
|
||||
if s.QueueDepth != 0 {
|
||||
t.Errorf("QueueDepth = %d, want 0", s.QueueDepth)
|
||||
}
|
||||
|
||||
// Queue a waiter and verify QueueDepth grows.
|
||||
waitStarted := make(chan struct{})
|
||||
waitDone := make(chan struct{})
|
||||
go func() {
|
||||
close(waitStarted)
|
||||
_ = sched.Run(context.Background(), "mvoice", func(ctx context.Context) error {
|
||||
return nil
|
||||
})
|
||||
close(waitDone)
|
||||
}()
|
||||
<-waitStarted
|
||||
// Wait for the waiter to actually be parked on the channel.
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if sched.Stats().QueueDepth == 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
if got := sched.Stats().QueueDepth; got != 1 {
|
||||
t.Errorf("QueueDepth with one waiter = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(jobBlock)
|
||||
<-waitDone
|
||||
if got := sched.Stats().InFlight; got != 0 {
|
||||
t.Errorf("InFlight after both done = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
@@ -60,7 +60,9 @@ func (s *Server) Handler() http.Handler {
|
||||
mux.HandleFunc("POST /v1/llm", s.handleEndpoint(config.KindLLM))
|
||||
mux.HandleFunc("POST /v1/image", s.handleEndpoint(config.KindImage))
|
||||
|
||||
mux.HandleFunc("GET /audio/", s.handleAudio)
|
||||
if s.cfg.AudioProxy != "" && s.cfg.AudioPathPrefix != "" {
|
||||
mux.HandleFunc("GET "+s.cfg.AudioPathPrefix, s.handleAudio)
|
||||
}
|
||||
mux.HandleFunc("GET /v1/status", s.handleStatus)
|
||||
mux.HandleFunc("GET /healthz", s.handleHealthz)
|
||||
mux.HandleFunc("GET /", s.handleRoot)
|
||||
|
||||
@@ -38,10 +38,10 @@ func fakeMVoice(t *testing.T) (*httptest.Server, *atomic.Int32) {
|
||||
calls.Add(1)
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"audio_url":"/audio/test.wav","payload_echo":"` +
|
||||
_, _ = w.Write([]byte(`{"audio_url":"/api/audio/test.wav","payload_echo":"` +
|
||||
strings.ReplaceAll(string(body), `"`, `\"`) + `"}`))
|
||||
})
|
||||
mux.HandleFunc("GET /audio/test.wav", func(w http.ResponseWriter, _ *http.Request) {
|
||||
mux.HandleFunc("GET /api/audio/test.wav", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "audio/wav")
|
||||
_, _ = w.Write([]byte("RIFF....fake-wav-bytes"))
|
||||
})
|
||||
@@ -56,7 +56,8 @@ func buildHarness(t *testing.T, upstream *httptest.Server) (*Server, *registry.R
|
||||
Routing: map[config.EndpointKind]string{
|
||||
config.KindTTS: "mvoice",
|
||||
},
|
||||
AudioProxy: "mvoice",
|
||||
AudioProxy: "mvoice",
|
||||
AudioPathPrefix: "/api/audio/",
|
||||
Consumers: map[string]*config.Consumer{
|
||||
"mvoice": {
|
||||
URL: upstream.URL,
|
||||
@@ -113,7 +114,7 @@ func TestProxyTTSForwardsBodyAndReturnsConsumerResponse(t *testing.T) {
|
||||
if !strings.Contains(payload.PayloadEcho, "text=Hallo") {
|
||||
t.Errorf("upstream did not receive body: %q", payload.PayloadEcho)
|
||||
}
|
||||
if payload.AudioURL != "/audio/test.wav" {
|
||||
if payload.AudioURL != "/api/audio/test.wav" {
|
||||
t.Errorf("AudioURL = %q", payload.AudioURL)
|
||||
}
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func TestAudioProxyForwardsBytes(t *testing.T) {
|
||||
ts := httptest.NewServer(srv.Handler())
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := http.Get(ts.URL + "/audio/test.wav")
|
||||
resp, err := http.Get(ts.URL + "/api/audio/test.wav")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user