Files
paliad/internal/services/paliadin_remote_test.go
m 68c56ea920 test(t-paliad-151): paliadin_remote_test.go — RemotePaliadinService unit tests
13 test functions covering:
- NewRemotePaliadinService default values (SSHPort=22022, SSHUser="m")
- NewRemotePaliadinService honours overrides
- classifySSHError mapping (nil / explicit + wrapped ErrMRiverUnreachable
  / context.DeadlineExceeded / shim exit-124 timeout / Connection
  refused/timed out / Permission denied / unknown fallback)
- healthGate caches OK results for 10 s
- healthGate does NOT cache failures (every call re-probes)
- healthGate rejects unexpected shim replies (returns an error that
  wraps ErrMRiverUnreachable)
- healthGate cache expires after 10 s wall clock
- ensureBootstrapped runs exactly once on success (idempotent)
- ensureBootstrapped retries after failure, then caches the success
- DisabledPaliadinService returns ErrPaliadinDisabled from RunTurn +
  ResetSession
- compile-time Paliadin interface conformance for all three impls
- callShim forwards args verbatim through the test hook
- callShim error-wrapping path preserves stderr (so classifySSHError
  can pattern-match Permission denied / Connection refused etc.)

All tests bypass exec via the callShimHook field — no real ssh, no
real DB. RunTurn audit-row tests are out of scope (paliad has no
sqlx mock; existing paliadin_test.go also stays on pure functions).

Refs m/paliad#12
2026-05-08 02:18:08 +02:00

258 lines
9.1 KiB
Go

package services
import (
"context"
"errors"
"fmt"
"strings"
"sync/atomic"
"testing"
"time"
)
// Tests for the remote-Paliadin backend. Every test bypasses exec via
// the callShimHook field — no real ssh is ever invoked, no DB rows are
// written. Tests that would need DB I/O (audit row insert/complete on
// RunTurn) are not in scope here; paliad's test suite has no sqlx mock
// and the existing paliadin_test.go only covers pure functions.
func TestNewRemotePaliadinService_Defaults(t *testing.T) {
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{
SSHHost: "100.99.98.203",
// SSHPort + SSHUser intentionally left zero/empty
})
if s.cfg.SSHPort != 22022 {
t.Errorf("SSHPort default = %d; want 22022 (Tailscale-SSH bypass port)", s.cfg.SSHPort)
}
if s.cfg.SSHUser != "m" {
t.Errorf("SSHUser default = %q; want %q", s.cfg.SSHUser, "m")
}
if s.cfg.SSHHost != "100.99.98.203" {
t.Errorf("SSHHost not preserved: %q", s.cfg.SSHHost)
}
}
func TestNewRemotePaliadinService_HonoursOverrides(t *testing.T) {
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{
SSHHost: "10.0.0.1",
SSHPort: 2222,
SSHUser: "alice",
})
if s.cfg.SSHPort != 2222 {
t.Errorf("SSHPort override lost: %d", s.cfg.SSHPort)
}
if s.cfg.SSHUser != "alice" {
t.Errorf("SSHUser override lost: %q", s.cfg.SSHUser)
}
}
// TestClassifySSHError feeds classifySSHError a table of errors and checks
// the classification string returned for each, one subtest per row.
func TestClassifySSHError(t *testing.T) {
	type testCase struct {
		name string
		err  error
		want string
	}

	table := []testCase{
		{name: "nil", err: nil, want: ""},
		{name: "explicit ErrMRiverUnreachable", err: ErrMRiverUnreachable, want: "mriver_unreachable"},
		{name: "wrapped ErrMRiverUnreachable", err: fmt.Errorf("foo: %w", ErrMRiverUnreachable), want: "mriver_unreachable"},
		{name: "context deadline", err: context.DeadlineExceeded, want: "timeout"},
		{name: "shim run-turn timeout (exit 124)", err: errors.New("ssh run-turn …: exit status 124 (stderr: response timeout)"), want: "timeout"},
		{name: "connection refused", err: errors.New("ssh health: dial: Connection refused"), want: "mriver_unreachable"},
		{name: "connection timed out", err: errors.New("ssh health: Connection timed out"), want: "mriver_unreachable"},
		{name: "permission denied", err: errors.New("ssh: Permission denied (publickey)"), want: "shim_auth_failed"},
		{name: "unknown", err: errors.New("ssh: some other failure"), want: "shim_error"},
	}

	for i := range table {
		tc := table[i]
		t.Run(tc.name, func(t *testing.T) {
			if got := classifySSHError(tc.err); got != tc.want {
				t.Errorf("classifySSHError(%v) = %q; want %q", tc.err, got, tc.want)
			}
		})
	}
}
// TestHealthGate_CachesOnSuccess calls healthGate five times in a row with
// a hook that always answers "ok" and expects the hook to be hit only once,
// i.e. the later calls are served from the cached result.
func TestHealthGate_CachesOnSuccess(t *testing.T) {
	var probes int32

	svc := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
	svc.callShimHook = func(_ context.Context, args ...string) ([]byte, error) {
		atomic.AddInt32(&probes, 1)
		// The health probe is expected to send the single verb "health".
		if isHealth := len(args) == 1 && args[0] == "health"; !isHealth {
			t.Errorf("unexpected callShim args: %v", args)
		}
		return []byte("ok\n"), nil
	}

	ctx := context.Background()
	for attempt := 0; attempt < 5; attempt++ {
		err := svc.healthGate(ctx)
		if err != nil {
			t.Fatalf("healthGate iteration %d: %v", attempt, err)
		}
	}

	if n := atomic.LoadInt32(&probes); n != 1 {
		t.Errorf("expected 1 callShim call (cached); got %d", n)
	}
}
// TestHealthGate_RetriesAfterFailure makes every probe fail and expects
// each healthGate call to return an error wrapping ErrMRiverUnreachable and
// to hit the hook again, i.e. failures are never cached.
func TestHealthGate_RetriesAfterFailure(t *testing.T) {
	const attempts = 3
	var probes int32

	svc := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
	svc.callShimHook = func(_ context.Context, _ ...string) ([]byte, error) {
		atomic.AddInt32(&probes, 1)
		return nil, errors.New("ssh: Connection refused")
	}

	for n := 0; n < attempts; n++ {
		if err := svc.healthGate(context.Background()); !errors.Is(err, ErrMRiverUnreachable) {
			t.Errorf("iteration %d: err %v; want wrapping ErrMRiverUnreachable", n, err)
		}
	}

	// One hook invocation per healthGate call: nothing was served from cache.
	if got := atomic.LoadInt32(&probes); got != attempts {
		t.Errorf("expected 3 callShim calls (no caching on failure); got %d", got)
	}
}
// TestHealthGate_RejectsUnexpectedReply has the hook succeed with a body
// other than "ok" and expects healthGate to report an error wrapping
// ErrMRiverUnreachable.
func TestHealthGate_RejectsUnexpectedReply(t *testing.T) {
	svc := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
	svc.callShimHook = func(_ context.Context, _ ...string) ([]byte, error) {
		// No transport error, but the payload is not the expected reply.
		return []byte("not-ok"), nil
	}

	got := svc.healthGate(context.Background())
	if errors.Is(got, ErrMRiverUnreachable) {
		return
	}
	t.Errorf("err = %v; want wrap of ErrMRiverUnreachable for non-ok reply", got)
}
// TestEnsureBootstrapped_RunsOnce calls ensureBootstrapped three times with
// a hook that always succeeds and expects exactly one hook invocation: once
// bootstrap has succeeded, later calls must not reach the shim again.
func TestEnsureBootstrapped_RunsOnce(t *testing.T) {
	var calls int32
	s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
	s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
		atomic.AddInt32(&calls, 1)
		if len(args) != 2 || args[0] != "bootstrap" {
			t.Errorf("unexpected callShim args: %v", args)
		}
		// args[1] is the base64'd system prompt — no need to decode in
		// the test; just sanity-check it isn't trivially empty. The length
		// guard keeps a malformed call (already reported above) from
		// panicking with an index-out-of-range instead of failing cleanly.
		if len(args) > 1 && len(args[1]) < 100 {
			t.Errorf("bootstrap prompt suspiciously short: %d bytes", len(args[1]))
		}
		return []byte("ok\n"), nil
	}
	for i := 0; i < 3; i++ {
		if err := s.ensureBootstrapped(context.Background()); err != nil {
			t.Fatalf("ensureBootstrapped iteration %d: %v", i, err)
		}
	}
	if got := atomic.LoadInt32(&calls); got != 1 {
		t.Errorf("expected 1 callShim call (bootstrap is one-shot); got %d", got)
	}
}
// TestEnsureBootstrapped_RetriesOnFailure fails the first bootstrap attempt
// and lets every later one succeed. It expects the first call to error, the
// second to retry and succeed, and the third to be answered without another
// hook invocation.
func TestEnsureBootstrapped_RetriesOnFailure(t *testing.T) {
	var attempts int32

	svc := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
	svc.callShimHook = func(_ context.Context, _ ...string) ([]byte, error) {
		// Only the very first invocation fails; all later ones succeed.
		if atomic.AddInt32(&attempts, 1) == 1 {
			return nil, errors.New("ssh: transient failure")
		}
		return []byte("ok\n"), nil
	}

	ctx := context.Background()

	first := svc.ensureBootstrapped(ctx)
	if first == nil {
		t.Fatal("first call should error")
	}

	second := svc.ensureBootstrapped(ctx)
	if second != nil {
		t.Fatalf("second call should succeed: %v", second)
	}

	// Third call should be a cache hit (bootstrapped flag set on success).
	third := svc.ensureBootstrapped(ctx)
	if third != nil {
		t.Fatalf("third call should be cached: %v", third)
	}

	if n := atomic.LoadInt32(&attempts); n != 2 {
		t.Errorf("expected 2 callShim calls (1 fail + 1 succeed; 3rd cached); got %d", n)
	}
}
// TestHealthGate_CacheExpires probes once, rewinds the service's recorded
// health-check time by 11 seconds, probes again, and expects the hook to
// have been invoked for both probes.
func TestHealthGate_CacheExpires(t *testing.T) {
	var probes int32

	svc := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
	svc.callShimHook = func(_ context.Context, _ ...string) ([]byte, error) {
		atomic.AddInt32(&probes, 1)
		return []byte("ok"), nil
	}

	ctx := context.Background()
	if err := svc.healthGate(ctx); err != nil {
		t.Fatalf("first probe: %v", err)
	}

	// Backdate the cached timestamp under the same lock the service uses,
	// so the next call sees the cached result as stale.
	stale := time.Now().Add(-11 * time.Second)
	svc.healthMu.Lock()
	svc.healthCheckedAt = stale
	svc.healthMu.Unlock()

	if err := svc.healthGate(ctx); err != nil {
		t.Fatalf("second probe (expired cache): %v", err)
	}

	if n := atomic.LoadInt32(&probes); n != 2 {
		t.Errorf("expected 2 callShim calls (cache expired between); got %d", n)
	}
}
// TestRemotePaliadin_ImplementsPaliadin asserts, at compile time, that the
// remote, local and disabled services all satisfy the Paliadin interface.
// It has no runtime assertions: if a method is dropped, the package simply
// stops compiling.
func TestRemotePaliadin_ImplementsPaliadin(t *testing.T) {
	// paliadin_remote.go carries its own compile-time check; repeating it
	// here makes the failure mode obvious if someone accidentally drops a
	// method.
	var (
		_ Paliadin = (*RemotePaliadinService)(nil)
		_ Paliadin = (*LocalPaliadinService)(nil)
		_ Paliadin = (*DisabledPaliadinService)(nil)
	)
}
// TestDisabledPaliadinService checks that both RunTurn and ResetSession on
// the disabled service fail with an error matching ErrPaliadinDisabled.
func TestDisabledPaliadinService(t *testing.T) {
	svc := NewDisabledPaliadinService(nil, nil)
	ctx := context.Background()

	_, runErr := svc.RunTurn(ctx, TurnRequest{})
	if !errors.Is(runErr, ErrPaliadinDisabled) {
		t.Errorf("RunTurn error = %v; want ErrPaliadinDisabled", runErr)
	}

	resetErr := svc.ResetSession(ctx)
	if !errors.Is(resetErr, ErrPaliadinDisabled) {
		t.Errorf("ResetSession error = %v; want ErrPaliadinDisabled", resetErr)
	}
}
// TestCallShim_SSHArgvShape checks that callShim hands its arguments to
// callShimHook unchanged when the hook is installed.
//
// Despite the name, this does NOT inspect the real ssh argv (bypass-port
// flag, key + known_hosts paths, the verb after `--`): the hook is set so
// callShim never reaches the exec path, and that argv is therefore never
// observable here. The fully-populated config only exercises the
// constructor wiring.
func TestCallShim_SSHArgvShape(t *testing.T) {
	s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{
		SSHHost:        "100.99.98.203",
		SSHPort:        22022,
		SSHUser:        "m",
		SSHKeyPath:     "/tmp/k",
		KnownHostsPath: "/tmp/kh",
	})
	var captured []string
	s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
		// Copy so the assertion below does not alias callShim's slice.
		captured = append([]string(nil), args...)
		return []byte("ok"), nil
	}
	// The hook returns a nil error, so any error here comes from callShim
	// itself and must not be silently discarded.
	if _, err := s.callShim(context.Background(), "health"); err != nil {
		t.Fatalf("callShim with hook installed: %v", err)
	}
	if len(captured) != 1 || captured[0] != "health" {
		t.Errorf("callShim forwarded args = %v; want [health]", captured)
	}
}
// TestCallShim_StderrSurfacesInError has the hook return an error whose
// text embeds ssh stderr, then checks that the error coming back from
// callShim still contains that text and that classifySSHError maps it to
// "shim_auth_failed".
func TestCallShim_StderrSurfacesInError(t *testing.T) {
	svc := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})

	// The real exec path is expected to fold stderr into the error it
	// returns; the hook stands in for that contract here.
	svc.callShimHook = func(_ context.Context, _ ...string) ([]byte, error) {
		return nil, errors.New("ssh health: exit status 1 (stderr: Permission denied (publickey))")
	}

	_, err := svc.callShim(context.Background(), "health")
	if err == nil {
		t.Fatal("expected error")
	}

	if msg := err.Error(); !strings.Contains(msg, "Permission denied") {
		t.Errorf("error should preserve stderr: %v", err)
	}

	if class := classifySSHError(err); class != "shim_auth_failed" {
		t.Errorf("classifier should pick up Permission denied; got %q", class)
	}
}