package web

import (
	"encoding/json"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
)

// fakeSupabase stubs the /auth/v1 endpoints we touch: /auth/v1/user and
// /auth/v1/token (for the "password" and "refresh_token" grant types).
// The exported fields hold the credentials and tokens the stub accepts and
// hands out, so tests can reference them instead of repeating literals.
type fakeSupabase struct {
	*httptest.Server
	ValidAccess  string
	ValidRefresh string
	NewAccess    string
	NewRefresh   string
	ValidEmail   string
	ValidPass    string
	IssuedAccess string
	IssuedRefr   string
}

// newFakeSupabase starts an httptest.Server that mimics the auth endpoints
// and registers its shutdown with t.Cleanup.
func newFakeSupabase(t *testing.T) *fakeSupabase {
	t.Helper()
	f := &fakeSupabase{
		ValidAccess:  "good-access",
		ValidRefresh: "good-refresh",
		NewAccess:    "rotated-access",
		NewRefresh:   "rotated-refresh",
		ValidEmail:   "m@example",
		ValidPass:    "correct-horse-battery-staple",
		IssuedAccess: "issued-access",
		IssuedRefr:   "issued-refresh",
	}

	mux := http.NewServeMux()

	// Token validation: only the exact "Bearer <ValidAccess>" header passes.
	mux.HandleFunc("/auth/v1/user", func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Authorization") != "Bearer "+f.ValidAccess {
			http.Error(w, `{"msg":"invalid token"}`, http.StatusUnauthorized)
			return
		}
		// Encode errors are deliberately ignored in this stub.
		_ = json.NewEncoder(w).Encode(map[string]string{"id": "u-1", "email": f.ValidEmail})
	})

	// Token issuance: dispatches on the grant_type query parameter.
	mux.HandleFunc("/auth/v1/token", func(w http.ResponseWriter, r *http.Request) {
		grant := r.URL.Query().Get("grant_type")
		var body map[string]string
		// A decode failure leaves body nil; every lookup below then yields ""
		// and the request is rejected by the credential checks.
		_ = json.NewDecoder(r.Body).Decode(&body)
		switch grant {
		case "password":
			if body["email"] != f.ValidEmail || body["password"] != f.ValidPass {
				http.Error(w, `{"error_description":"Invalid login credentials"}`, http.StatusBadRequest)
				return
			}
			_ = json.NewEncoder(w).Encode(map[string]any{
				"access_token":  f.IssuedAccess,
				"refresh_token": f.IssuedRefr,
				"user":          map[string]string{"id": "u-1"},
			})
		case "refresh_token":
			if body["refresh_token"] != f.ValidRefresh {
				http.Error(w, `{"msg":"bad refresh"}`, http.StatusBadRequest)
				return
			}
			_ = json.NewEncoder(w).Encode(map[string]any{
				"access_token":  f.NewAccess,
				"refresh_token": f.NewRefresh,
				"user":          map[string]string{"id": "u-1"},
			})
		default:
			http.Error(w, "bad grant", http.StatusBadRequest)
		}
	})

	f.Server = httptest.NewServer(mux)
	t.Cleanup(f.Server.Close)
	return f
}

// gatedMux wires a tiny app behind authMiddleware. It exposes /, /healthz,
// /login (always-open), /logout (always-open) so the middleware tests can
// exercise the gate without spinning up the real Server.
func gatedMux(t *testing.T, supaURL, anon string) http.Handler {
	t.Helper()
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, "tree-page")
	})
	mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, "ok")
	})
	mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, "login-form")
	})
	mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
		http.SetCookie(w, clearCookie(accessTokenCookie))
		http.SetCookie(w, clearCookie(refreshTokenCookie))
		http.Redirect(w, r, "/login", http.StatusFound)
	})
	cfg := AuthConfig{SupabaseURL: supaURL, AnonKey: anon}
	return authMiddleware(cfg, slog.New(slog.NewTextHandler(io.Discard, nil)), mux)
}

// TestSafeRedirect checks that safeRedirect passes local paths through and
// maps everything else to the empty string.
func TestSafeRedirect(t *testing.T) {
	cases := map[string]string{
		"/i/dev":             "/i/dev",
		"/":                  "/",
		"":                   "",
		"//evil.com":         "",
		"https://evil.com":   "",
		"javascript:alert":   "",
		// Raw string: this is a literal backslash followed by 'n', not a newline.
		`/path\nset-cookie`: "",
		`\evil`:             "",
	}
	for in, want := range cases {
		if got := safeRedirect(in); got != want {
			t.Errorf("safeRedirect(%q) = %q, want %q", in, got, want)
		}
	}
}

// TestHealthzAlwaysOpen checks that /healthz is served without credentials.
func TestHealthzAlwaysOpen(t *testing.T) {
	f := newFakeSupabase(t)
	h := gatedMux(t, f.URL, "anon")
	r := httptest.NewRequest(http.MethodGet, "/healthz", nil)
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Fatalf("healthz: %d", w.Code)
	}
}

// TestLoginPathBypassesAuth checks that unauthenticated requests to /login
// and /logout reach the wrapped handlers instead of being stopped by the gate.
func TestLoginPathBypassesAuth(t *testing.T) {
	f := newFakeSupabase(t)
	h := gatedMux(t, f.URL, "anon")

	cases := []struct {
		path     string
		status   int
		location string // expected Location header; "" means not checked
	}{
		// /login → 200 (the test mux serves a form).
		{path: "/login", status: http.StatusOK},
		// /logout → 302 (clears cookies). The gate's own redirect for
		// unauthenticated requests is also a 302, so the status alone cannot
		// prove the bypass; the stub handler redirects to exactly "/login",
		// which tells the two apart.
		{path: "/logout", status: http.StatusFound, location: "/login"},
	}
	for _, tc := range cases {
		r := httptest.NewRequest(http.MethodGet, tc.path, nil)
		w := httptest.NewRecorder()
		h.ServeHTTP(w, r)
		if w.Code != tc.status {
			t.Fatalf("%s: status %d, want %d", tc.path, w.Code, tc.status)
		}
		if tc.location == "" {
			continue
		}
		if loc := w.Header().Get("Location"); loc != tc.location {
			t.Fatalf("%s: Location = %q, want %q", tc.path, loc, tc.location)
		}
	}
}

// TestUnauthedRedirectsToLocalLogin checks that a request without credentials
// is redirected to the local /login page with a redirectTo parameter.
func TestUnauthedRedirectsToLocalLogin(t *testing.T) {
	f := newFakeSupabase(t)
	h := gatedMux(t, f.URL, "anon")
	r := httptest.NewRequest(http.MethodGet, "/i/dev", nil)
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusFound {
		t.Fatalf("status %d, want 302", w.Code)
	}
	loc := w.Header().Get("Location")
	if !strings.HasPrefix(loc, "/login?") {
		t.Fatalf("Location = %q, want /login? prefix", loc)
	}
	if !strings.Contains(loc, "redirectTo=") {
		t.Fatalf("Location missing redirectTo: %q", loc)
	}
	// Must NOT bounce to another host.
	if strings.Contains(loc, "msbls.de") || strings.HasPrefix(loc, "http") {
		t.Fatalf("Location should be relative to projax: %q", loc)
	}
}

// TestValidCookieAuthorizes checks that a valid access-token cookie lets the
// request through to the wrapped handler.
func TestValidCookieAuthorizes(t *testing.T) {
	f := newFakeSupabase(t)
	h := gatedMux(t, f.URL, "anon")
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.AddCookie(&http.Cookie{Name: accessTokenCookie, Value: f.ValidAccess})
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Fatalf("status %d", w.Code)
	}
	if strings.TrimSpace(w.Body.String()) != "tree-page" {
		t.Fatalf("body %q", w.Body.String())
	}
}

// TestBearerHeaderAuthorizes checks that a valid Authorization: Bearer header
// is accepted in place of the cookie.
func TestBearerHeaderAuthorizes(t *testing.T) {
	f := newFakeSupabase(t)
	h := gatedMux(t, f.URL, "anon")
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("Authorization", "Bearer "+f.ValidAccess)
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Fatalf("status %d", w.Code)
	}
}

// TestStaleAccessRefreshesAndIssuesCookies checks that a rejected access token
// plus a valid refresh token yields a 200 and freshly issued, host-scoped
// session cookies.
func TestStaleAccessRefreshesAndIssuesCookies(t *testing.T) {
	f := newFakeSupabase(t)
	h := gatedMux(t, f.URL, "anon")
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.AddCookie(&http.Cookie{Name: accessTokenCookie, Value: "stale"})
	r.AddCookie(&http.Cookie{Name: refreshTokenCookie, Value: f.ValidRefresh})
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Fatalf("status %d", w.Code)
	}
	got := map[string]*http.Cookie{}
	for _, c := range w.Result().Cookies() {
		got[c.Name] = c
	}
	if c := got[accessTokenCookie]; c == nil || c.Value != f.NewAccess {
		t.Fatalf("access cookie missing or wrong value")
	}
	if c := got[refreshTokenCookie]; c == nil || c.Value != f.NewRefresh {
		t.Fatalf("refresh cookie missing or wrong value")
	}
	// Per-host scope: NO Domain attribute.
	if got[accessTokenCookie].Domain != "" {
		t.Fatalf("access cookie has Domain=%q, want empty (per-host)", got[accessTokenCookie].Domain)
	}
	if got[refreshTokenCookie].Domain != "" {
		t.Fatalf("refresh cookie has Domain=%q, want empty", got[refreshTokenCookie].Domain)
	}
	for _, c := range got {
		if !c.HttpOnly || !c.Secure || c.SameSite != http.SameSiteLaxMode {
			t.Errorf("flags wrong for %q: httponly=%v secure=%v samesite=%v",
				c.Name, c.HttpOnly, c.Secure, c.SameSite)
		}
	}
}

// TestBadRefreshFinallyRedirects checks that a request whose access and
// refresh tokens are both rejected ends in a redirect.
func TestBadRefreshFinallyRedirects(t *testing.T) {
	f := newFakeSupabase(t)
	h := gatedMux(t, f.URL, "anon")
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.AddCookie(&http.Cookie{Name: accessTokenCookie, Value: "stale"})
	r.AddCookie(&http.Cookie{Name: refreshTokenCookie, Value: "no-good"})
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusFound {
		t.Fatalf("status %d, want 302", w.Code)
	}
}

// --- /login + /logout handler tests (use the real Server) ---

// makeServerWithStub builds a Server via New and points its Auth config at
// the fake Supabase instance f.
func makeServerWithStub(t *testing.T, f *fakeSupabase) *Server {
	t.Helper()
	srv, err := New(nil, slog.New(slog.NewTextHandler(io.Discard, nil)))
	if err != nil {
		t.Fatalf("server: %v", err)
	}
	srv.Auth = &AuthConfig{SupabaseURL: f.URL, AnonKey: "anon"}
	return srv
}

// TestLoginGETRendersForm checks that GET /login renders the credential
// fields and carries redirectTo through as a hidden input.
func TestLoginGETRendersForm(t *testing.T) {
	f := newFakeSupabase(t)
	srv := makeServerWithStub(t, f)
	r := httptest.NewRequest(http.MethodGet, "/login?redirectTo=/i/dev", nil)
	w := httptest.NewRecorder()
	srv.handleLoginForm(w, r)
	if w.Code != http.StatusOK {
		t.Fatalf("status %d", w.Code)
	}
	body := w.Body.String()
	if !strings.Contains(body, `name="email"`) || !strings.Contains(body, `name="password"`) {
		t.Errorf("body missing email/password fields")
	}
	if !strings.Contains(body, `name="redirectTo" value="/i/dev"`) {
		t.Errorf("body missing redirectTo hidden input")
	}
}

// TestLoginGETShortCircuitsWhenAlreadySignedIn checks that GET /login with a
// valid access cookie redirects straight to redirectTo.
func TestLoginGETShortCircuitsWhenAlreadySignedIn(t *testing.T) {
	f := newFakeSupabase(t)
	srv := makeServerWithStub(t, f)
	r := httptest.NewRequest(http.MethodGet, "/login?redirectTo=/i/dev", nil)
	r.AddCookie(&http.Cookie{Name: accessTokenCookie, Value: f.ValidAccess})
	w := httptest.NewRecorder()
	srv.handleLoginForm(w, r)
	if w.Code != http.StatusFound {
		t.Fatalf("status %d, want 302", w.Code)
	}
	if loc := w.Header().Get("Location"); loc != "/i/dev" {
		t.Errorf("Location = %q, want /i/dev", loc)
	}
}

// TestLoginPOSTSuccessSetsCookiesAndRedirects checks that valid credentials
// produce a redirect to redirectTo plus both session cookies.
func TestLoginPOSTSuccessSetsCookiesAndRedirects(t *testing.T) {
	f := newFakeSupabase(t)
	srv := makeServerWithStub(t, f)
	form := url.Values{}
	form.Set("email", f.ValidEmail)
	form.Set("password", f.ValidPass)
	form.Set("redirectTo", "/i/dev")
	r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	w := httptest.NewRecorder()
	srv.handleLoginSubmit(w, r)
	if w.Code != http.StatusFound {
		// The recorder buffers the body, so no fallible read is needed here.
		t.Fatalf("status %d body=%s", w.Code, w.Body.String())
	}
	if loc := w.Header().Get("Location"); loc != "/i/dev" {
		t.Errorf("Location = %q, want /i/dev", loc)
	}
	var sawAccess, sawRefresh bool
	for _, c := range w.Result().Cookies() {
		switch c.Name {
		case accessTokenCookie:
			sawAccess = true
			if c.Domain != "" {
				t.Errorf("access cookie has Domain=%q, want empty", c.Domain)
			}
			if c.Value != f.IssuedAccess {
				t.Errorf("access cookie value %q, want %q", c.Value, f.IssuedAccess)
			}
		case refreshTokenCookie:
			sawRefresh = true
		}
	}
	if !sawAccess || !sawRefresh {
		t.Errorf("missing session cookies after login")
	}
}

// TestLoginPOSTBadCredsRerendersWithError checks that a wrong password yields
// a 401 whose body surfaces the upstream error message.
func TestLoginPOSTBadCredsRerendersWithError(t *testing.T) {
	f := newFakeSupabase(t)
	srv := makeServerWithStub(t, f)
	form := url.Values{}
	form.Set("email", f.ValidEmail)
	form.Set("password", "wrong")
	r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	w := httptest.NewRecorder()
	srv.handleLoginSubmit(w, r)
	if w.Code != http.StatusUnauthorized {
		t.Fatalf("status %d, want 401", w.Code)
	}
	if !strings.Contains(w.Body.String(), "Invalid login credentials") {
		t.Errorf("form did not surface error message: %q", w.Body.String())
	}
}

// TestLoginRedirectToRejectedWhenUnsafe checks that a hostile redirectTo is
// replaced by "/" after a successful login.
func TestLoginRedirectToRejectedWhenUnsafe(t *testing.T) {
	f := newFakeSupabase(t)
	srv := makeServerWithStub(t, f)
	for _, hostile := range []string{"//evil.com", "https://evil.com", `\evil`} {
		form := url.Values{}
		form.Set("email", f.ValidEmail)
		form.Set("password", f.ValidPass)
		form.Set("redirectTo", hostile)
		r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		w := httptest.NewRecorder()
		srv.handleLoginSubmit(w, r)
		if loc := w.Header().Get("Location"); loc != "/" {
			t.Errorf("hostile redirectTo %q -> Location %q, want /", hostile, loc)
		}
	}
}

// TestLogoutClearsCookies checks that handleLogout redirects to /login and
// expires both session cookies.
func TestLogoutClearsCookies(t *testing.T) {
	f := newFakeSupabase(t)
	srv := makeServerWithStub(t, f)
	r := httptest.NewRequest(http.MethodPost, "/logout", nil)
	w := httptest.NewRecorder()
	srv.handleLogout(w, r)
	if w.Code != http.StatusFound {
		t.Fatalf("status %d, want 302", w.Code)
	}
	if loc := w.Header().Get("Location"); loc != "/login" {
		t.Errorf("Location = %q, want /login", loc)
	}
	cleared := 0
	for _, c := range w.Result().Cookies() {
		if c.Name != accessTokenCookie && c.Name != refreshTokenCookie {
			continue
		}
		// A cleared cookie carries a negative MaxAge and an empty value.
		if c.MaxAge >= 0 || c.Value != "" {
			t.Errorf("cookie %q not cleared: maxAge=%d value=%q", c.Name, c.MaxAge, c.Value)
		}
		cleared++
	}
	if cleared != 2 {
		t.Errorf("cleared %d cookies, want 2", cleared)
	}
}