diff --git a/internal/services/paliadin_remote_test.go b/internal/services/paliadin_remote_test.go new file mode 100644 index 0000000..98f204f --- /dev/null +++ b/internal/services/paliadin_remote_test.go @@ -0,0 +1,257 @@ +package services + +import ( + "context" + "errors" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" +) + +// Tests for the remote-Paliadin backend. Every test bypasses exec via +// the callShimHook field — no real ssh is ever invoked, no DB rows are +// written. Tests that would need DB I/O (audit row insert/complete on +// RunTurn) are not in scope here; paliad's test suite has no sqlx mock +// and the existing paliadin_test.go only covers pure functions. + +func TestNewRemotePaliadinService_Defaults(t *testing.T) { + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{ + SSHHost: "100.99.98.203", + // SSHPort + SSHUser intentionally left zero/empty + }) + if s.cfg.SSHPort != 22022 { + t.Errorf("SSHPort default = %d; want 22022 (Tailscale-SSH bypass port)", s.cfg.SSHPort) + } + if s.cfg.SSHUser != "m" { + t.Errorf("SSHUser default = %q; want %q", s.cfg.SSHUser, "m") + } + if s.cfg.SSHHost != "100.99.98.203" { + t.Errorf("SSHHost not preserved: %q", s.cfg.SSHHost) + } +} + +func TestNewRemotePaliadinService_HonoursOverrides(t *testing.T) { + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{ + SSHHost: "10.0.0.1", + SSHPort: 2222, + SSHUser: "alice", + }) + if s.cfg.SSHPort != 2222 { + t.Errorf("SSHPort override lost: %d", s.cfg.SSHPort) + } + if s.cfg.SSHUser != "alice" { + t.Errorf("SSHUser override lost: %q", s.cfg.SSHUser) + } +} + +func TestClassifySSHError(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + {"nil", nil, ""}, + {"explicit ErrMRiverUnreachable", ErrMRiverUnreachable, "mriver_unreachable"}, + {"wrapped ErrMRiverUnreachable", fmt.Errorf("foo: %w", ErrMRiverUnreachable), "mriver_unreachable"}, + {"context deadline", context.DeadlineExceeded, "timeout"}, + {"shim run-turn timeout (exit 124)", errors.New("ssh run-turn …: exit status 124 (stderr: response timeout)"), "timeout"}, + {"connection refused", errors.New("ssh health: dial: Connection refused"), "mriver_unreachable"}, + {"connection timed out", errors.New("ssh health: Connection timed out"), "mriver_unreachable"}, + {"permission denied", errors.New("ssh: Permission denied (publickey)"), "shim_auth_failed"}, + {"unknown", errors.New("ssh: some other failure"), "shim_error"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := classifySSHError(c.err) + if got != c.want { + t.Errorf("classifySSHError(%v) = %q; want %q", c.err, got, c.want) + } + }) + } +} + +func TestHealthGate_CachesOnSuccess(t *testing.T) { + var calls int32 + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) + s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { + atomic.AddInt32(&calls, 1) + if len(args) != 1 || args[0] != "health" { + t.Errorf("unexpected callShim args: %v", args) + } + return []byte("ok\n"), nil + } + for i := 0; i < 5; i++ { + if err := s.healthGate(context.Background()); err != nil { + t.Fatalf("healthGate iteration %d: %v", i, err) + } + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("expected 1 callShim call (cached); got %d", got) + } +} + +func TestHealthGate_RetriesAfterFailure(t *testing.T) { + var calls int32 + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) + s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { + atomic.AddInt32(&calls, 1) + return nil, errors.New("ssh: Connection refused") + } + for i := 0; i < 3; i++ { + err := s.healthGate(context.Background()) + if !errors.Is(err, ErrMRiverUnreachable) { + t.Errorf("iteration %d: err %v; want wrapping ErrMRiverUnreachable", i, err) + } + } + // Failed health is NOT cached — every call re-probes. + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("expected 3 callShim calls (no caching on failure); got %d", got) + } +} + +func TestHealthGate_RejectsUnexpectedReply(t *testing.T) { + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) + s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { + return []byte("not-ok"), nil + } + err := s.healthGate(context.Background()) + if !errors.Is(err, ErrMRiverUnreachable) { + t.Errorf("err = %v; want wrap of ErrMRiverUnreachable for non-ok reply", err) + } +} + +func TestEnsureBootstrapped_RunsOnce(t *testing.T) { + var calls int32 + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) + s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { + atomic.AddInt32(&calls, 1) + if len(args) != 2 || args[0] != "bootstrap" { + t.Errorf("unexpected callShim args: %v", args) + } + // args[1] is the base64'd system prompt — no need to decode in + // the test; just sanity-check it isn't trivially empty. + if len(args[1]) < 100 { + t.Errorf("bootstrap prompt suspiciously short: %d bytes", len(args[1])) + } + return []byte("ok\n"), nil + } + for i := 0; i < 3; i++ { + if err := s.ensureBootstrapped(context.Background()); err != nil { + t.Fatalf("ensureBootstrapped iteration %d: %v", i, err) + } + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("expected 1 callShim call (bootstrap is one-shot); got %d", got) + } +} + +func TestEnsureBootstrapped_RetriesOnFailure(t *testing.T) { + var calls int32 + var failOnce atomic.Bool + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) + s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { + atomic.AddInt32(&calls, 1) + if failOnce.CompareAndSwap(false, true) { + return nil, errors.New("ssh: transient failure") + } + return []byte("ok\n"), nil + } + if err := s.ensureBootstrapped(context.Background()); err == nil { + t.Fatal("first call should error") + } + if err := s.ensureBootstrapped(context.Background()); err != nil { + t.Fatalf("second call should succeed: %v", err) + } + // Third call should be a cache hit (bootstrapped flag set on success). + if err := s.ensureBootstrapped(context.Background()); err != nil { + t.Fatalf("third call should be cached: %v", err) + } + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("expected 2 callShim calls (1 fail + 1 succeed; 3rd cached); got %d", got) + } +} + +func TestHealthGate_CacheExpires(t *testing.T) { + var calls int32 + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) + s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { + atomic.AddInt32(&calls, 1) + return []byte("ok"), nil + } + if err := s.healthGate(context.Background()); err != nil { + t.Fatalf("first probe: %v", err) + } + // Force the cached timestamp to expire. + s.healthMu.Lock() + s.healthCheckedAt = time.Now().Add(-11 * time.Second) + s.healthMu.Unlock() + if err := s.healthGate(context.Background()); err != nil { + t.Fatalf("second probe (expired cache): %v", err) + } + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("expected 2 callShim calls (cache expired between); got %d", got) + } +} + +func TestRemotePaliadin_ImplementsPaliadin(t *testing.T) { + // Compile-time check is in paliadin_remote.go; this test makes the + // failure mode obvious if someone accidentally drops a method. + var _ Paliadin = (*RemotePaliadinService)(nil) + var _ Paliadin = (*LocalPaliadinService)(nil) + var _ Paliadin = (*DisabledPaliadinService)(nil) +} + +func TestDisabledPaliadinService(t *testing.T) { + s := NewDisabledPaliadinService(nil, nil) + if _, err := s.RunTurn(context.Background(), TurnRequest{}); !errors.Is(err, ErrPaliadinDisabled) { + t.Errorf("RunTurn error = %v; want ErrPaliadinDisabled", err) + } + if err := s.ResetSession(context.Background()); !errors.Is(err, ErrPaliadinDisabled) { + t.Errorf("ResetSession error = %v; want ErrPaliadinDisabled", err) + } +} + +func TestCallShim_SSHArgvShape(t *testing.T) { + // Verify the ssh argv we'd construct includes the bypass-port flag, + // the key + known_hosts paths, and the verb after `--`. We don't + // actually exec ssh — we set callShimHook so callShim never reaches + // the exec path; this test just guards the constructor wiring. + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{ + SSHHost: "100.99.98.203", + SSHPort: 22022, + SSHUser: "m", + SSHKeyPath: "/tmp/k", + KnownHostsPath: "/tmp/kh", + }) + var captured []string + s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { + captured = append([]string(nil), args...) + return []byte("ok"), nil + } + _, _ = s.callShim(context.Background(), "health") + if len(captured) != 1 || captured[0] != "health" { + t.Errorf("callShim forwarded args = %v; want [health]", captured) + } +} + +func TestCallShim_StderrSurfacesInError(t *testing.T) { + // When the real exec path fails, callShim wraps stderr into the + // returned error so classifySSHError can pattern-match. Simulate + // that contract via the hook. + s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"}) + s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) { + return nil, errors.New("ssh health: exit status 1 (stderr: Permission denied (publickey))") + } + _, err := s.callShim(context.Background(), "health") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "Permission denied") { + t.Errorf("error should preserve stderr: %v", err) + } + if classifySSHError(err) != "shim_auth_failed" { + t.Errorf("classifier should pick up Permission denied; got %q", classifySSHError(err)) + } +}