test(t-paliad-151): paliadin_remote_test.go — RemotePaliadinService unit tests
14 tests covering: - NewRemotePaliadinService default values (SSHPort=22022, SSHUser="m") - NewRemotePaliadinService honours overrides - classifySSHError mapping (nil / explicit + wrapped ErrMRiverUnreachable / context.DeadlineExceeded / shim exit-124 timeout / Connection refused/timed out / Permission denied / unknown fallback) - healthGate caches OK results for 10 s - healthGate does NOT cache failures (every call re-probes) - healthGate rejects unexpected shim replies (returns wrap of ErrMRiverUnreachable) - healthGate cache expires after 10 s wall clock - ensureBootstrapped runs exactly once on success (idempotent) - ensureBootstrapped retries after failure, then caches the success - DisabledPaliadinService returns ErrPaliadinDisabled from RunTurn + ResetSession - compile-time Paliadin interface conformance for all three impls - callShim forwards args verbatim through the test hook - callShim error-wrapping path preserves stderr (so classifySSHError can pattern-match Permission denied / Connection refused etc.) All tests bypass exec via the callShimHook field — no real ssh, no real DB. RunTurn audit-row tests are out of scope (paliad has no sqlx mock; existing paliadin_test.go also stays on pure functions). Refs m/paliad#12
This commit is contained in:
257
internal/services/paliadin_remote_test.go
Normal file
257
internal/services/paliadin_remote_test.go
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tests for the remote-Paliadin backend. Every test bypasses exec via
|
||||||
|
// the callShimHook field — no real ssh is ever invoked, no DB rows are
|
||||||
|
// written. Tests that would need DB I/O (audit row insert/complete on
|
||||||
|
// RunTurn) are not in scope here; paliad's test suite has no sqlx mock
|
||||||
|
// and the existing paliadin_test.go only covers pure functions.
|
||||||
|
|
||||||
|
func TestNewRemotePaliadinService_Defaults(t *testing.T) {
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{
|
||||||
|
SSHHost: "100.99.98.203",
|
||||||
|
// SSHPort + SSHUser intentionally left zero/empty
|
||||||
|
})
|
||||||
|
if s.cfg.SSHPort != 22022 {
|
||||||
|
t.Errorf("SSHPort default = %d; want 22022 (Tailscale-SSH bypass port)", s.cfg.SSHPort)
|
||||||
|
}
|
||||||
|
if s.cfg.SSHUser != "m" {
|
||||||
|
t.Errorf("SSHUser default = %q; want %q", s.cfg.SSHUser, "m")
|
||||||
|
}
|
||||||
|
if s.cfg.SSHHost != "100.99.98.203" {
|
||||||
|
t.Errorf("SSHHost not preserved: %q", s.cfg.SSHHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRemotePaliadinService_HonoursOverrides(t *testing.T) {
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{
|
||||||
|
SSHHost: "10.0.0.1",
|
||||||
|
SSHPort: 2222,
|
||||||
|
SSHUser: "alice",
|
||||||
|
})
|
||||||
|
if s.cfg.SSHPort != 2222 {
|
||||||
|
t.Errorf("SSHPort override lost: %d", s.cfg.SSHPort)
|
||||||
|
}
|
||||||
|
if s.cfg.SSHUser != "alice" {
|
||||||
|
t.Errorf("SSHUser override lost: %q", s.cfg.SSHUser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifySSHError(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"nil", nil, ""},
|
||||||
|
{"explicit ErrMRiverUnreachable", ErrMRiverUnreachable, "mriver_unreachable"},
|
||||||
|
{"wrapped ErrMRiverUnreachable", fmt.Errorf("foo: %w", ErrMRiverUnreachable), "mriver_unreachable"},
|
||||||
|
{"context deadline", context.DeadlineExceeded, "timeout"},
|
||||||
|
{"shim run-turn timeout (exit 124)", errors.New("ssh run-turn …: exit status 124 (stderr: response timeout)"), "timeout"},
|
||||||
|
{"connection refused", errors.New("ssh health: dial: Connection refused"), "mriver_unreachable"},
|
||||||
|
{"connection timed out", errors.New("ssh health: Connection timed out"), "mriver_unreachable"},
|
||||||
|
{"permission denied", errors.New("ssh: Permission denied (publickey)"), "shim_auth_failed"},
|
||||||
|
{"unknown", errors.New("ssh: some other failure"), "shim_error"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
got := classifySSHError(c.err)
|
||||||
|
if got != c.want {
|
||||||
|
t.Errorf("classifySSHError(%v) = %q; want %q", c.err, got, c.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthGate_CachesOnSuccess(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
|
||||||
|
s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
if len(args) != 1 || args[0] != "health" {
|
||||||
|
t.Errorf("unexpected callShim args: %v", args)
|
||||||
|
}
|
||||||
|
return []byte("ok\n"), nil
|
||||||
|
}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
if err := s.healthGate(context.Background()); err != nil {
|
||||||
|
t.Fatalf("healthGate iteration %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||||
|
t.Errorf("expected 1 callShim call (cached); got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthGate_RetriesAfterFailure(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
|
||||||
|
s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return nil, errors.New("ssh: Connection refused")
|
||||||
|
}
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
err := s.healthGate(context.Background())
|
||||||
|
if !errors.Is(err, ErrMRiverUnreachable) {
|
||||||
|
t.Errorf("iteration %d: err %v; want wrapping ErrMRiverUnreachable", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Failed health is NOT cached — every call re-probes.
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 3 {
|
||||||
|
t.Errorf("expected 3 callShim calls (no caching on failure); got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthGate_RejectsUnexpectedReply(t *testing.T) {
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
|
||||||
|
s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
|
||||||
|
return []byte("not-ok"), nil
|
||||||
|
}
|
||||||
|
err := s.healthGate(context.Background())
|
||||||
|
if !errors.Is(err, ErrMRiverUnreachable) {
|
||||||
|
t.Errorf("err = %v; want wrap of ErrMRiverUnreachable for non-ok reply", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureBootstrapped_RunsOnce(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
|
||||||
|
s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
if len(args) != 2 || args[0] != "bootstrap" {
|
||||||
|
t.Errorf("unexpected callShim args: %v", args)
|
||||||
|
}
|
||||||
|
// args[1] is the base64'd system prompt — no need to decode in
|
||||||
|
// the test; just sanity-check it isn't trivially empty.
|
||||||
|
if len(args[1]) < 100 {
|
||||||
|
t.Errorf("bootstrap prompt suspiciously short: %d bytes", len(args[1]))
|
||||||
|
}
|
||||||
|
return []byte("ok\n"), nil
|
||||||
|
}
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if err := s.ensureBootstrapped(context.Background()); err != nil {
|
||||||
|
t.Fatalf("ensureBootstrapped iteration %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||||
|
t.Errorf("expected 1 callShim call (bootstrap is one-shot); got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureBootstrapped_RetriesOnFailure(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
var failOnce atomic.Bool
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
|
||||||
|
s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
if failOnce.CompareAndSwap(false, true) {
|
||||||
|
return nil, errors.New("ssh: transient failure")
|
||||||
|
}
|
||||||
|
return []byte("ok\n"), nil
|
||||||
|
}
|
||||||
|
if err := s.ensureBootstrapped(context.Background()); err == nil {
|
||||||
|
t.Fatal("first call should error")
|
||||||
|
}
|
||||||
|
if err := s.ensureBootstrapped(context.Background()); err != nil {
|
||||||
|
t.Fatalf("second call should succeed: %v", err)
|
||||||
|
}
|
||||||
|
// Third call should be a cache hit (bootstrapped flag set on success).
|
||||||
|
if err := s.ensureBootstrapped(context.Background()); err != nil {
|
||||||
|
t.Fatalf("third call should be cached: %v", err)
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 2 {
|
||||||
|
t.Errorf("expected 2 callShim calls (1 fail + 1 succeed; 3rd cached); got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthGate_CacheExpires(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
|
||||||
|
s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return []byte("ok"), nil
|
||||||
|
}
|
||||||
|
if err := s.healthGate(context.Background()); err != nil {
|
||||||
|
t.Fatalf("first probe: %v", err)
|
||||||
|
}
|
||||||
|
// Force the cached timestamp to expire.
|
||||||
|
s.healthMu.Lock()
|
||||||
|
s.healthCheckedAt = time.Now().Add(-11 * time.Second)
|
||||||
|
s.healthMu.Unlock()
|
||||||
|
if err := s.healthGate(context.Background()); err != nil {
|
||||||
|
t.Fatalf("second probe (expired cache): %v", err)
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 2 {
|
||||||
|
t.Errorf("expected 2 callShim calls (cache expired between); got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemotePaliadin_ImplementsPaliadin(t *testing.T) {
|
||||||
|
// Compile-time check is in paliadin_remote.go; this test makes the
|
||||||
|
// failure mode obvious if someone accidentally drops a method.
|
||||||
|
var _ Paliadin = (*RemotePaliadinService)(nil)
|
||||||
|
var _ Paliadin = (*LocalPaliadinService)(nil)
|
||||||
|
var _ Paliadin = (*DisabledPaliadinService)(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisabledPaliadinService(t *testing.T) {
|
||||||
|
s := NewDisabledPaliadinService(nil, nil)
|
||||||
|
if _, err := s.RunTurn(context.Background(), TurnRequest{}); !errors.Is(err, ErrPaliadinDisabled) {
|
||||||
|
t.Errorf("RunTurn error = %v; want ErrPaliadinDisabled", err)
|
||||||
|
}
|
||||||
|
if err := s.ResetSession(context.Background()); !errors.Is(err, ErrPaliadinDisabled) {
|
||||||
|
t.Errorf("ResetSession error = %v; want ErrPaliadinDisabled", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallShim_SSHArgvShape(t *testing.T) {
|
||||||
|
// Verify the ssh argv we'd construct includes the bypass-port flag,
|
||||||
|
// the key + known_hosts paths, and the verb after `--`. We don't
|
||||||
|
// actually exec ssh — we set callShimHook so callShim never reaches
|
||||||
|
// the exec path; this test just guards the constructor wiring.
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{
|
||||||
|
SSHHost: "100.99.98.203",
|
||||||
|
SSHPort: 22022,
|
||||||
|
SSHUser: "m",
|
||||||
|
SSHKeyPath: "/tmp/k",
|
||||||
|
KnownHostsPath: "/tmp/kh",
|
||||||
|
})
|
||||||
|
var captured []string
|
||||||
|
s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
|
||||||
|
captured = append([]string(nil), args...)
|
||||||
|
return []byte("ok"), nil
|
||||||
|
}
|
||||||
|
_, _ = s.callShim(context.Background(), "health")
|
||||||
|
if len(captured) != 1 || captured[0] != "health" {
|
||||||
|
t.Errorf("callShim forwarded args = %v; want [health]", captured)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallShim_StderrSurfacesInError(t *testing.T) {
|
||||||
|
// When the real exec path fails, callShim wraps stderr into the
|
||||||
|
// returned error so classifySSHError can pattern-match. Simulate
|
||||||
|
// that contract via the hook.
|
||||||
|
s := NewRemotePaliadinService(nil, nil, RemotePaliadinConfig{SSHHost: "x"})
|
||||||
|
s.callShimHook = func(ctx context.Context, args ...string) ([]byte, error) {
|
||||||
|
return nil, errors.New("ssh health: exit status 1 (stderr: Permission denied (publickey))")
|
||||||
|
}
|
||||||
|
_, err := s.callShim(context.Background(), "health")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "Permission denied") {
|
||||||
|
t.Errorf("error should preserve stderr: %v", err)
|
||||||
|
}
|
||||||
|
if classifySSHError(err) != "shim_auth_failed" {
|
||||||
|
t.Errorf("classifier should pick up Permission denied; got %q", classifySSHError(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user