package services import ( "context" "errors" "fmt" "strings" "sync/atomic" "testing" "time" "github.com/google/uuid" ) // testSession is the per-user session name we pass into healthGate / // callShim from tests. The shape mirrors what RunTurn would derive for // a real user. const testSession = "paliad-paliadin-deadbeef" // Tests for the remote-Paliadin backend. Every test bypasses exec via // the callShimHook field — no real ssh is ever invoked, no DB rows are // written. Tests that would need DB I/O (audit row insert/complete on // RunTurn) are not in scope here; paliad's test suite has no sqlx mock // and the existing paliadin_test.go only covers pure functions. func TestNewRemotePaliadinService_Defaults(t *testing.T) { s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{ SSHHost: "100.99.98.203", // SSHPort + SSHUser intentionally left zero/empty }) if s.cfg.SSHPort != 22022 { t.Errorf("SSHPort default = %d; want 22022 (Tailscale-SSH bypass port)", s.cfg.SSHPort) } if s.cfg.SSHUser != "m" { t.Errorf("SSHUser default = %q; want %q", s.cfg.SSHUser, "m") } if s.cfg.SSHHost != "100.99.98.203" { t.Errorf("SSHHost not preserved: %q", s.cfg.SSHHost) } } func TestNewRemotePaliadinService_HonoursOverrides(t *testing.T) { s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{ SSHHost: "10.0.0.1", SSHPort: 2222, SSHUser: "alice", }) if s.cfg.SSHPort != 2222 { t.Errorf("SSHPort override lost: %d", s.cfg.SSHPort) } if s.cfg.SSHUser != "alice" { t.Errorf("SSHUser override lost: %q", s.cfg.SSHUser) } } func TestClassifySSHError(t *testing.T) { cases := []struct { name string err error want string }{ {"nil", nil, ""}, {"explicit ErrMRiverUnreachable", ErrMRiverUnreachable, "mriver_unreachable"}, {"wrapped ErrMRiverUnreachable", fmt.Errorf("foo: %w", ErrMRiverUnreachable), "mriver_unreachable"}, {"context deadline", context.DeadlineExceeded, "timeout"}, {"shim run-turn timeout (exit 124)", errors.New("ssh run-turn …: exit status 124 (stderr: response timeout)"), "timeout"}, {"connection refused", errors.New("ssh health: dial: Connection refused"), "mriver_unreachable"}, {"connection timed out", errors.New("ssh health: Connection timed out"), "mriver_unreachable"}, {"permission denied", errors.New("ssh: Permission denied (publickey)"), "shim_auth_failed"}, {"unknown", errors.New("ssh: some other failure"), "shim_error"}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { got := classifySSHError(c.err) if got != c.want { t.Errorf("classifySSHError(%v) = %q; want %q", c.err, got, c.want) } }) } } func TestHealthGate_CachesOnSuccess(t *testing.T) { var calls int32 s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { atomic.AddInt32(&calls, 1) if len(args) != 2 || args[0] != "health" || args[1] != testSession { t.Errorf("unexpected callShim args: %v", args) } return []byte("ok\n"), nil } for i := 0; i < 5; i++ { if err := s.healthGate(context.Background(), testSession); err != nil { t.Fatalf("healthGate iteration %d: %v", i, err) } } if got := atomic.LoadInt32(&calls); got != 1 { t.Errorf("expected 1 callShim call (cached); got %d", got) } } func TestHealthGate_RetriesAfterFailure(t *testing.T) { var calls int32 s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { atomic.AddInt32(&calls, 1) return nil, errors.New("ssh: Connection refused") } for i := 0; i < 3; i++ { err := s.healthGate(context.Background(), testSession) if !errors.Is(err, ErrMRiverUnreachable) { t.Errorf("iteration %d: err %v; want wrapping ErrMRiverUnreachable", i, err) } } // Failed health is NOT cached — every call re-probes. if got := atomic.LoadInt32(&calls); got != 3 { t.Errorf("expected 3 callShim calls (no caching on failure); got %d", got) } } func TestHealthGate_RejectsUnexpectedReply(t *testing.T) { s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { return []byte("not-ok"), nil } err := s.healthGate(context.Background(), testSession) if !errors.Is(err, ErrMRiverUnreachable) { t.Errorf("err = %v; want wrap of ErrMRiverUnreachable for non-ok reply", err) } } func TestHealthGate_PerSessionCache(t *testing.T) { // Two sessions must each get their own probe — caching is per-key, // not global. var calls int32 s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { atomic.AddInt32(&calls, 1) return []byte("ok"), nil } if err := s.healthGate(context.Background(), "paliad-paliadin-aaaaaaaa"); err != nil { t.Fatalf("session A first probe: %v", err) } if err := s.healthGate(context.Background(), "paliad-paliadin-bbbbbbbb"); err != nil { t.Fatalf("session B first probe: %v", err) } if err := s.healthGate(context.Background(), "paliad-paliadin-aaaaaaaa"); err != nil { t.Fatalf("session A second probe: %v", err) } if got := atomic.LoadInt32(&calls); got != 2 { t.Errorf("expected 2 callShim calls (1 per session, A reuses cache on 3rd); got %d", got) } } func TestHealthGate_CacheExpires(t *testing.T) { var calls int32 s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { atomic.AddInt32(&calls, 1) return []byte("ok"), nil } if err := s.healthGate(context.Background(), testSession); err != nil { t.Fatalf("first probe: %v", err) } // Force the cached timestamp to expire. s.healthMu.Lock() s.health[testSession] = healthCacheEntry{ok: true, checkedAt: time.Now().Add(-11 * time.Second)} s.healthMu.Unlock() if err := s.healthGate(context.Background(), testSession); err != nil { t.Fatalf("second probe (expired cache): %v", err) } if got := atomic.LoadInt32(&calls); got != 2 { t.Errorf("expected 2 callShim calls (cache expired between); got %d", got) } } func TestSessionNameFor_PerUser(t *testing.T) { s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) a := uuid.MustParse("aaaaaaaa-1111-2222-3333-444444444444") b := uuid.MustParse("bbbbbbbb-1111-2222-3333-444444444444") if got := s.sessionNameFor(a); got != "paliad-paliadin-aaaaaaaa" { t.Errorf("session A = %q; want paliad-paliadin-aaaaaaaa", got) } if got := s.sessionNameFor(b); got != "paliad-paliadin-bbbbbbbb" { t.Errorf("session B = %q; want paliad-paliadin-bbbbbbbb", got) } if s.sessionNameFor(a) == s.sessionNameFor(b) { t.Error("distinct user IDs collapsed to the same session") } } func TestSessionNameFor_HonoursPrefix(t *testing.T) { s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{ SSHHost: "x", SessionPrefix: "custom", }) a := uuid.MustParse("12345678-1111-2222-3333-444444444444") if got := s.sessionNameFor(a); got != "custom-12345678" { t.Errorf("session = %q; want custom-12345678", got) } } func TestResetSession_KillsPerUserSession(t *testing.T) { var captured []string s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { captured = append([]string(nil), args...) return []byte("ok"), nil } uid := uuid.MustParse("aaaaaaaa-1111-2222-3333-444444444444") if err := s.ResetSession(context.Background(), uid); err != nil { t.Fatalf("ResetSession: %v", err) } want := []string{"reset", "paliad-paliadin-aaaaaaaa"} if len(captured) != 2 || captured[0] != want[0] || captured[1] != want[1] { t.Errorf("callShim args = %v; want %v", captured, want) } } func TestResetSession_DropsHealthCache(t *testing.T) { s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { return []byte("ok"), nil } uid := uuid.MustParse("aaaaaaaa-1111-2222-3333-444444444444") session := s.sessionNameFor(uid) // Warm the cache. if err := s.healthGate(context.Background(), session); err != nil { t.Fatalf("warm: %v", err) } if _, ok := s.health[session]; !ok { t.Fatal("cache should be warm") } if err := s.ResetSession(context.Background(), uid); err != nil { t.Fatalf("ResetSession: %v", err) } if _, ok := s.health[session]; ok { t.Error("ResetSession must drop the per-session health cache") } } func TestRemotePaliadin_ImplementsPaliadin(t *testing.T) { // Compile-time check is in paliadin_remote.go; this test makes the // failure mode obvious if someone accidentally drops a method. var _ Paliadin = (*RemotePaliadinService)(nil) var _ Paliadin = (*LocalPaliadinService)(nil) var _ Paliadin = (*DisabledPaliadinService)(nil) } func TestDisabledPaliadinService(t *testing.T) { s := NewDisabledPaliadinService(nil, nil) if _, err := s.RunTurn(context.Background(), TurnRequest{}); !errors.Is(err, ErrPaliadinDisabled) { t.Errorf("RunTurn error = %v; want ErrPaliadinDisabled", err) } if err := s.ResetSession(context.Background(), uuid.Nil); !errors.Is(err, ErrPaliadinDisabled) { t.Errorf("ResetSession error = %v; want ErrPaliadinDisabled", err) } } func TestCallShim_SSHArgvShape(t *testing.T) { // Verify the ssh argv we'd construct includes the bypass-port flag, // the key + known_hosts paths, and the verb after `--`. We don't // actually exec ssh — we set callShimHook so callShim never reaches // the exec path; this test just guards the constructor wiring. s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{ SSHHost: "100.99.98.203", SSHPort: 22022, SSHUser: "m", SSHKeyPath: "/tmp/k", KnownHostsPath: "/tmp/kh", }) var captured []string s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { captured = append([]string(nil), args...) return []byte("ok"), nil } _, _ = s.callShim(context.Background(), "health") if len(captured) != 1 || captured[0] != "health" { t.Errorf("callShim forwarded args = %v; want [health]", captured) } } func TestCallShim_StderrSurfacesInError(t *testing.T) { // When the real exec path fails, callShim wraps stderr into the // returned error so classifySSHError can pattern-match. Simulate // that contract via the hook. s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { return nil, errors.New("ssh health: exit status 1 (stderr: Permission denied (publickey))") } _, err := s.callShim(context.Background(), "health") if err == nil { t.Fatal("expected error") } if !strings.Contains(err.Error(), "Permission denied") { t.Errorf("error should preserve stderr: %v", err) } if classifySSHError(err) != "shim_auth_failed" { t.Errorf("classifier should pick up Permission denied; got %q", classifySSHError(err)) } }