From 43d7fffeefcf2867740ab2ee7637fa33b10c50c3 Mon Sep 17 00:00:00 2001 From: rcourtman Date: Mon, 2 Feb 2026 22:02:11 +0000 Subject: [PATCH] Test: add coverage for auth and security handlers Add additional tests for OIDC, SAML, and tenant middleware to improve coverage of security-critical paths. --- .../api/middleware_tenant_additional_test.go | 139 ++++++++ internal/api/oidc_handlers_more_test.go | 318 ++++++++++++++++++ internal/api/router_auth_additional_test.go | 161 +++++++++ internal/api/saml_handlers_more_test.go | 217 ++++++++++++ internal/api/saml_service_additional_test.go | 133 ++++++++ .../security_oidc_handlers_additional_test.go | 95 ++++++ 6 files changed, 1063 insertions(+) create mode 100644 internal/api/middleware_tenant_additional_test.go create mode 100644 internal/api/oidc_handlers_more_test.go create mode 100644 internal/api/router_auth_additional_test.go create mode 100644 internal/api/saml_handlers_more_test.go create mode 100644 internal/api/saml_service_additional_test.go create mode 100644 internal/api/security_oidc_handlers_additional_test.go diff --git a/internal/api/middleware_tenant_additional_test.go b/internal/api/middleware_tenant_additional_test.go new file mode 100644 index 000000000..b4de3a153 --- /dev/null +++ b/internal/api/middleware_tenant_additional_test.go @@ -0,0 +1,139 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestNewTenantMiddlewareWithConfig(t *testing.T) { + persistence := config.NewMultiTenantPersistence(t.TempDir()) + checker := stubAuthorizationChecker{} + + mw := NewTenantMiddlewareWithConfig(TenantMiddlewareConfig{ + Persistence: persistence, + AuthChecker: checker, + }) + + if mw.persistence != persistence { + t.Fatalf("expected persistence to be set") + } + if mw.authChecker == nil { + t.Fatalf("expected auth checker to be set") + } +} + +type stubAuthorizationChecker struct{} + +func (stubAuthorizationChecker) TokenCanAccessOrg(*config.APITokenRecord, string) bool { + return true +} + +func (stubAuthorizationChecker) UserCanAccessOrg(string, string) bool { + return true +} + +func (stubAuthorizationChecker) CheckAccess(*config.APITokenRecord, string, string) AuthorizationResult { + return AuthorizationResult{Allowed: true} +} + +func TestWriteJSONError(t *testing.T) { + rec := httptest.NewRecorder() + writeJSONError(rec, http.StatusBadRequest, "bad", "message") + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rec.Code) + } + + var payload map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload["error"] != "bad" || payload["message"] != "message" { + t.Fatalf("unexpected payload: %+v", payload) + } +} + +func TestTenantMiddleware_OrgExtraction(t *testing.T) { + mw := NewTenantMiddleware(nil) + + handler := mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + orgID := GetOrgID(r.Context()) + org := GetOrganization(r.Context()) + if orgID == "" || org == nil || org.ID != orgID { + t.Fatalf("unexpected org context: %q %+v", orgID, org) + } + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Pulse-Org-ID", "header-org") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "pulse_org_id", Value: "cookie-org"}) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + req = httptest.NewRequest(http.MethodGet, "/", nil) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) +} + +func TestTenantMiddleware_InvalidOrg(t *testing.T) { + persistence := config.NewMultiTenantPersistence(t.TempDir()) + mw := NewTenantMiddleware(persistence) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Pulse-Org-ID", "../bad") + rec := httptest.NewRecorder() + + mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("next handler should not be called") + })).ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rec.Code) + } +} + +func TestTenantMiddleware_MultiTenantDisabled(t *testing.T) { + orig := IsMultiTenantEnabled() + SetMultiTenantEnabled(false) + t.Cleanup(func() { SetMultiTenantEnabled(orig) }) + + mw := NewTenantMiddleware(nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Pulse-Org-ID", "tenant-1") + rec := httptest.NewRecorder() + + mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("next handler should not be called") + })).ServeHTTP(rec, req) + + if rec.Code != http.StatusNotImplemented { + t.Fatalf("expected 501, got %d", rec.Code) + } +} + +func TestTenantMiddleware_MultiTenantLicenseRequired(t *testing.T) { + orig := IsMultiTenantEnabled() + SetMultiTenantEnabled(true) + t.Cleanup(func() { SetMultiTenantEnabled(orig) }) + + mw := NewTenantMiddleware(nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Pulse-Org-ID", "tenant-2") + rec := httptest.NewRecorder() + + mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("next handler should not be called") + })).ServeHTTP(rec, req) + + if rec.Code != http.StatusPaymentRequired { + t.Fatalf("expected 402, got %d", rec.Code) + } +} diff --git a/internal/api/oidc_handlers_more_test.go b/internal/api/oidc_handlers_more_test.go new file mode 100644 index 000000000..0295aecd0 --- /dev/null +++ b/internal/api/oidc_handlers_more_test.go @@ -0,0 +1,318 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "golang.org/x/oauth2" +) + +func newTestOIDCConfig() *config.OIDCConfig { + return &config.OIDCConfig{ + Enabled: true, + IssuerURL: "https://issuer.example.com", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURL: "https://app.example.com/api/oidc/callback", + Scopes: []string{"openid", "email"}, + UsernameClaim: "preferred_username", + EmailClaim: "email", + GroupsClaim: "groups", + } +} + +func newTestOIDCService(cfg *config.OIDCConfig, authURL, tokenURL string) *OIDCService { + return &OIDCService{ + snapshot: oidcSnapshot{ + issuer: cfg.IssuerURL, + clientID: cfg.ClientID, + clientSecret: cfg.ClientSecret, + redirectURL: cfg.RedirectURL, + scopes: append([]string{}, cfg.Scopes...), + caBundle: cfg.CABundle, + }, + oauth2Cfg: &oauth2.Config{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + RedirectURL: cfg.RedirectURL, + Endpoint: oauth2.Endpoint{ + AuthURL: authURL, + TokenURL: tokenURL, + }, + Scopes: append([]string{}, cfg.Scopes...), + }, + stateStore: newOIDCStateStore(), + } +} + +func newOIDCRouterWithService(t *testing.T, authURL, tokenURL string) (*Router, *OIDCService) { + t.Helper() + cfg := newTestOIDCConfig() + svc := newTestOIDCService(cfg, authURL, tokenURL) + router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc} + t.Cleanup(func() { + if svc.stateStore != nil { + svc.stateStore.Stop() + } + }) + return router, svc +} + +func TestHandleOIDCLogin_MethodNotAllowed(t *testing.T) { + router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}} + req := httptest.NewRequest(http.MethodPut, "/api/oidc/login", nil) + rec := httptest.NewRecorder() + + router.handleOIDCLogin(rec, req) + + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed) + } +} + +func TestHandleOIDCLogin_InvalidJSON(t *testing.T) { + router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "") + req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", strings.NewReader("{")) + rec := httptest.NewRecorder() + + router.handleOIDCLogin(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + var payload map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload["code"] != "invalid_request" { + t.Fatalf("code = %v, want invalid_request", payload["code"]) + } +} + +func TestHandleOIDCLogin_GetSuccess(t *testing.T) { + router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "") + req := httptest.NewRequest(http.MethodGet, "/api/oidc/login?returnTo=/dashboard", nil) + rec := httptest.NewRecorder() + + router.handleOIDCLogin(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound) + } + location := rec.Header().Get("Location") + u, err := url.Parse(location) + if err != nil { + t.Fatalf("parse location: %v", err) + } + if u.Host != "auth.example.com" { + t.Fatalf("unexpected auth host: %q", u.Host) + } + state := u.Query().Get("state") + if state == "" { + t.Fatalf("expected state param in redirect") + } + entry, ok := svc.consumeState(state) + if !ok { + t.Fatalf("expected state entry to be stored") + } + if entry.ReturnTo != "/dashboard" { + t.Fatalf("returnTo = %q, want /dashboard", entry.ReturnTo) + } +} + +func TestHandleOIDCLogin_PostSuccess(t *testing.T) { + router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "") + body := strings.NewReader(`{"returnTo":"/home"}`) + req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", body) + rec := httptest.NewRecorder() + + router.handleOIDCLogin(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + var payload struct { + AuthorizationURL string `json:"authorizationUrl"` + } + if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload.AuthorizationURL == "" { + t.Fatalf("expected authorizationUrl in response") + } + u, err := url.Parse(payload.AuthorizationURL) + if err != nil { + t.Fatalf("parse authorizationUrl: %v", err) + } + state := u.Query().Get("state") + if state == "" { + t.Fatalf("expected state param in authorizationUrl") + } + entry, ok := svc.consumeState(state) + if !ok { + t.Fatalf("expected state entry to be stored") + } + if entry.ReturnTo != "/home" { + t.Fatalf("returnTo = %q, want /home", entry.ReturnTo) + } +} + +func TestGetOIDCService_ReturnsCachedService(t *testing.T) { + cfg := newTestOIDCConfig() + svc := newTestOIDCService(cfg, "https://auth.example.com/authorize", "https://token.example.com") + router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc} + defer svc.stateStore.Stop() + + got, err := router.getOIDCService(context.Background(), cfg.RedirectURL) + if err != nil { + t.Fatalf("getOIDCService error: %v", err) + } + if got != svc { + t.Fatalf("expected cached service to be returned") + } +} + +func TestHandleOIDCCallback_MethodNotAllowed(t *testing.T) { + router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}} + req := httptest.NewRequest(http.MethodPost, "/api/oidc/callback", nil) + rec := httptest.NewRecorder() + + router.handleOIDCCallback(rec, req) + + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed) + } +} + +func TestHandleOIDCCallback_ErrorParam(t *testing.T) { + router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "") + req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?error=access_denied", nil) + rec := httptest.NewRecorder() + + router.handleOIDCCallback(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound) + } + location := rec.Header().Get("Location") + if !strings.Contains(location, "oidc_error=access_denied") { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestHandleOIDCCallback_MissingState(t *testing.T) { + router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "") + req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?code=abc", nil) + rec := httptest.NewRecorder() + + router.handleOIDCCallback(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound) + } + location := rec.Header().Get("Location") + if !strings.Contains(location, "oidc_error=missing_state") { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestHandleOIDCCallback_InvalidState(t *testing.T) { + router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "") + req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state=invalid", nil) + rec := httptest.NewRecorder() + + router.handleOIDCCallback(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound) + } + location := rec.Header().Get("Location") + if !strings.Contains(location, "oidc_error=invalid_state") { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestHandleOIDCCallback_MissingCode(t *testing.T) { + router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "") + state, _, err := svc.newStateEntry("/dashboard") + if err != nil { + t.Fatalf("newStateEntry error: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state), nil) + rec := httptest.NewRecorder() + + router.handleOIDCCallback(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound) + } + location := rec.Header().Get("Location") + if !strings.Contains(location, "oidc_error=missing_code") { + t.Fatalf("unexpected redirect location: %q", location) + } + if !strings.HasPrefix(location, "/dashboard") { + t.Fatalf("expected redirect back to /dashboard, got %q", location) + } +} + +func TestHandleOIDCCallback_ExchangeFailed(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer tokenServer.Close() + + router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL) + svc.httpClient = tokenServer.Client() + state, _, err := svc.newStateEntry("/dashboard") + if err != nil { + t.Fatalf("newStateEntry error: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil) + rec := httptest.NewRecorder() + + router.handleOIDCCallback(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound) + } + location := rec.Header().Get("Location") + if !strings.Contains(location, "oidc_error=exchange_failed") { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestHandleOIDCCallback_MissingIDToken(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"access","token_type":"Bearer","expires_in":3600}`)) + })) + defer tokenServer.Close() + + router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL) + svc.httpClient = tokenServer.Client() + state, _, err := svc.newStateEntry("/dashboard") + if err != nil { + t.Fatalf("newStateEntry error: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil) + rec := httptest.NewRecorder() + + router.handleOIDCCallback(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound) + } + location := rec.Header().Get("Location") + if !strings.Contains(location, "oidc_error=missing_id_token") { + t.Fatalf("unexpected redirect location: %q", location) + } +} diff --git a/internal/api/router_auth_additional_test.go b/internal/api/router_auth_additional_test.go new file mode 100644 index 000000000..ceea12745 --- /dev/null +++ b/internal/api/router_auth_additional_test.go @@ -0,0 +1,161 @@ +package api + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/pkg/auth" +) + +func newAuthRouter(t *testing.T) *Router { + t.Helper() + hashed, err := auth.HashPassword("currentpassword") + if err != nil { + t.Fatalf("hash password: %v", err) + } + return &Router{ + config: &config.Config{ + AuthUser: "admin", + AuthPass: hashed, + ConfigPath: t.TempDir(), + }, + } +} + +func TestHandleChangePassword_InvalidJSON(t *testing.T) { + router := newAuthRouter(t) + req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader("{")) + rec := httptest.NewRecorder() + + router.handleChangePassword(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + var payload map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload["code"] != "invalid_request" { + t.Fatalf("expected invalid_request, got %#v", payload["code"]) + } +} + +func TestHandleChangePassword_InvalidPassword(t *testing.T) { + router := newAuthRouter(t) + body := `{"currentPassword":"currentpassword","newPassword":"short"}` + req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body)) + rec := httptest.NewRecorder() + + router.handleChangePassword(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + var payload map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload["code"] != "invalid_password" { + t.Fatalf("expected invalid_password, got %#v", payload["code"]) + } +} + +func TestHandleChangePassword_MissingCurrent(t *testing.T) { + router := newAuthRouter(t) + body := `{"currentPassword":"","newPassword":"newpassword123"}` + req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body)) + rec := httptest.NewRecorder() + + router.handleChangePassword(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } + var payload map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload["code"] != "unauthorized" { + t.Fatalf("expected unauthorized, got %#v", payload["code"]) + } +} + +func TestHandleChangePassword_SuccessDocker(t *testing.T) { + router := newAuthRouter(t) + t.Setenv("PULSE_DOCKER", "true") + + body := `{"currentPassword":"currentpassword","newPassword":"newpassword123"}` + req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body)) + rec := httptest.NewRecorder() + + router.handleChangePassword(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + var payload map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if ok, _ := payload["success"].(bool); !ok { + t.Fatalf("expected success=true, got %#v", payload["success"]) + } + envPath := filepath.Join(router.config.ConfigPath, ".env") + if _, err := os.Stat(envPath); err != nil { + t.Fatalf("expected .env to be written, got error: %v", err) + } +} + +func TestHandleResetLockout_InvalidJSON(t *testing.T) { + router := newAuthRouter(t) + req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader("{")) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword"))) + rec := httptest.NewRecorder() + + router.handleResetLockout(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestHandleResetLockout_MissingIdentifier(t *testing.T) { + router := newAuthRouter(t) + req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader(`{"identifier":""}`)) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword"))) + rec := httptest.NewRecorder() + + router.handleResetLockout(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestHandleResetLockout_Success(t *testing.T) { + router := newAuthRouter(t) + req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader(`{"identifier":"user1"}`)) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword"))) + rec := httptest.NewRecorder() + + router.handleResetLockout(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + var payload map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if ok, _ := payload["success"].(bool); !ok { + t.Fatalf("expected success=true, got %#v", payload["success"]) + } +} diff --git a/internal/api/saml_handlers_more_test.go b/internal/api/saml_handlers_more_test.go new file mode 100644 index 000000000..0ffd08acc --- /dev/null +++ b/internal/api/saml_handlers_more_test.go @@ -0,0 +1,217 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func newTestSAMLService(t *testing.T, providerID string, metadataXML string) *SAMLService { + t.Helper() + service, err := NewSAMLService(context.Background(), providerID, &config.SAMLProviderConfig{ + IDPMetadataXML: metadataXML, + }, "https://pulse.example.com") + if err != nil { + t.Fatalf("NewSAMLService: %v", err) + } + return service +} + +func TestHandleSAMLACS_ProcessResponseError(t *testing.T) { + router := newSAMLRouter(t, testSAMLProvider("okta", true)) + router.samlManager.services["okta"] = &SAMLService{} + + req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil) + rec := httptest.NewRecorder() + + router.handleSAMLACS(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code) + } + if loc := rec.Header().Get("Location"); !strings.Contains(loc, "saml_error=saml_validation_failed") { + t.Fatalf("expected validation failed redirect, got %q", loc) + } +} + +func TestHandleSAMLMetadata_InvalidMethod(t *testing.T) { + router := newSAMLRouter(t, testSAMLProvider("okta", true)) + req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/metadata", nil) + rec := httptest.NewRecorder() + + router.handleSAMLMetadata(rec, req) + + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code) + } +} + +func TestHandleSAMLMetadata_InvalidProviderID(t *testing.T) { + router := newSAMLRouter(t, testSAMLProvider("okta", true)) + req := httptest.NewRequest(http.MethodGet, "/api/saml/invalid$id/metadata", nil) + rec := httptest.NewRecorder() + + router.handleSAMLMetadata(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code) + } +} + +func TestGetSAMLSessionInfo_NoCookie(t *testing.T) { + router := &Router{} + req := httptest.NewRequest(http.MethodGet, "/", nil) + + if info := router.getSAMLSessionInfo(req); info != nil { + t.Fatalf("expected nil session info without cookie") + } +} + +func TestGetSAMLSessionInfo_ReturnsInfo(t *testing.T) { + InitSessionStore(t.TempDir()) + + token := generateSessionToken() + GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{ + ProviderID: "okta", + NameID: "name-id", + SessionIndex: "sess-1", + }) + + router := &Router{} + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token}) + + info := router.getSAMLSessionInfo(req) + if info == nil { + t.Fatalf("expected session info") + } + if info.ProviderID != "okta" || info.NameID != "name-id" || info.SessionIndex != "sess-1" { + t.Fatalf("unexpected session info: %#v", info) + } +} + +func TestClearSession(t *testing.T) { + router := &Router{} + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + router.clearSession(rec, req) + + cookies := rec.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("expected 1 cookie, got %d", len(cookies)) + } + cookie := cookies[0] + if cookie.Name != "session" { + t.Fatalf("expected session cookie name, got %q", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Fatalf("expected MaxAge -1, got %d", cookie.MaxAge) + } + if !cookie.HttpOnly { + t.Fatalf("expected HttpOnly cookie") + } +} + +func TestHandleSAMLSLO_Redirects(t *testing.T) { + router := &Router{} + req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/slo", nil) + rec := httptest.NewRecorder() + + router.handleSAMLSLO(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/?logout=success" { + t.Fatalf("unexpected redirect location %q", loc) + } +} + +func TestHandleSAMLLogout_SLOUnavailable(t *testing.T) { + InitSessionStore(t.TempDir()) + + router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")} + metadataXML := ` + + + + +` + router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML) + + token := generateSessionToken() + GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{ + ProviderID: "okta", + NameID: "name-id", + SessionIndex: "sess-1", + }) + + req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil) + req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token}) + rec := httptest.NewRecorder() + + router.handleSAMLLogout(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/?logout=success" { + t.Fatalf("unexpected redirect location %q", loc) + } +} + +func TestHandleSAMLLogout_SLOSuccess(t *testing.T) { + InitSessionStore(t.TempDir()) + + router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")} + metadataXML := ` + + + + + +` + router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML) + + token := generateSessionToken() + GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{ + ProviderID: "okta", + NameID: "name-id", + SessionIndex: "sess-1", + }) + + req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil) + req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token}) + rec := httptest.NewRecorder() + + router.handleSAMLLogout(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.Contains(loc, "https://idp.example.com/slo") || !strings.Contains(loc, "SAMLRequest=") { + t.Fatalf("unexpected SLO redirect location %q", loc) + } +} + +func TestExtractSAMLProviderID(t *testing.T) { + if got := extractSAMLProviderID("/api/saml/okta/login", "login"); got != "okta" { + t.Fatalf("expected okta, got %q", got) + } + if got := extractSAMLProviderID("/api/saml/okta/logout", "login"); got != "" { + t.Fatalf("expected empty provider, got %q", got) + } + if got := extractSAMLProviderID("/api/saml/okta/login/extra", "login"); got != "okta" { + t.Fatalf("expected okta for extra path, got %q", got) + } + if got := extractSAMLProviderID("/api/other/okta/login", "login"); got != "" { + t.Fatalf("expected empty provider for non-saml path, got %q", got) + } +} diff --git a/internal/api/saml_service_additional_test.go b/internal/api/saml_service_additional_test.go new file mode 100644 index 000000000..9c85987d8 --- /dev/null +++ b/internal/api/saml_service_additional_test.go @@ -0,0 +1,133 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/crewjam/saml" + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestParseIDPMetadataXML_EmptyEntities(t *testing.T) { + wrapped := ` + +` + if _, err := parseIDPMetadataXML([]byte(wrapped)); err == nil { + t.Fatal("expected error for empty entities descriptor") + } +} + +func TestExtractAttribute(t *testing.T) { + service := &SAMLService{} + attrs := map[string][]string{ + "email": {"user@example.com"}, + } + + if got := service.extractAttribute(attrs, "", "fallback"); got != "fallback" { + t.Fatalf("expected fallback value, got %q", got) + } + if got := service.extractAttribute(attrs, "email", ""); got != "user@example.com" { + t.Fatalf("unexpected attribute value: %q", got) + } + if got := service.extractAttribute(attrs, "missing", "default"); got != "default" { + t.Fatalf("unexpected missing attribute value: %q", got) + } +} + +func TestProcessResponse_InvalidResponse(t *testing.T) { + service := &SAMLService{ + sp: &saml.ServiceProvider{}, + } + + body := strings.NewReader("RelayState=/dashboard&SAMLResponse=") + req := httptest.NewRequest(http.MethodPost, "/acs", body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + _, relay, err := service.ProcessResponse(req) + if err == nil { + t.Fatal("expected error for invalid response") + } + if relay != "/dashboard" { + t.Fatalf("unexpected relay state: %q", relay) + } +} + +func TestSAMLServiceIdentifiers(t *testing.T) { + service := &SAMLService{ + providerID: "provider-1", + sp: &saml.ServiceProvider{EntityID: "sp-entity"}, + idpMetadata: &saml.EntityDescriptor{ + EntityID: "idp-entity", + }, + } + + if service.ProviderID() != "provider-1" { + t.Fatalf("unexpected provider id") + } + if service.GetSPEntityID() != "sp-entity" { + t.Fatalf("unexpected sp entity id") + } + if service.GetIDPEntityID() != "idp-entity" { + t.Fatalf("unexpected idp entity id") + } + + service.sp = nil + if service.GetSPEntityID() != "" { + t.Fatalf("expected empty sp entity id when nil") + } + service.idpMetadata = nil + if service.GetIDPEntityID() != "" { + t.Fatalf("expected empty idp entity id when nil") + } +} + +func TestRefreshMetadata_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(` + + +`)) + })) + defer server.Close() + + cfg := &config.SAMLProviderConfig{IDPMetadataURL: server.URL} + service := &SAMLService{ + providerID: "idp", + config: cfg, + baseURL: "http://localhost", + httpClient: server.Client(), + } + + if err := service.RefreshMetadata(context.Background()); err != nil { + t.Fatalf("refresh metadata: %v", err) + } + if service.sp == nil { + t.Fatal("expected service provider to be initialized") + } +} + +func TestFetchIDPMetadataFromURL_NonOK(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + service := &SAMLService{httpClient: server.Client()} + if _, err := service.fetchIDPMetadataFromURL(context.Background(), server.URL); err == nil { + t.Fatal("expected error for non-200 status") + } +} + +func TestAddIDPCertificate_InvalidPEM(t *testing.T) { + service := &SAMLService{config: &config.SAMLProviderConfig{IDPCertificate: "not-pem"}} + metadata := &saml.EntityDescriptor{ + IDPSSODescriptors: []saml.IDPSSODescriptor{{}}, + } + if err := service.addIDPCertificate(metadata); err == nil { + t.Fatal("expected error for invalid certificate pem") + } +} diff --git a/internal/api/security_oidc_handlers_additional_test.go b/internal/api/security_oidc_handlers_additional_test.go new file mode 100644 index 000000000..072f1c841 --- /dev/null +++ b/internal/api/security_oidc_handlers_additional_test.go @@ -0,0 +1,95 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestSecurityOIDCHandlers_GetConfig(t *testing.T) { + cfg := &config.Config{ + PublicURL: "https://pulse.example.com", + OIDC: &config.OIDCConfig{ + Enabled: true, + IssuerURL: "https://issuer.example.com", + ClientID: "client-id", + ClientSecret: "super-secret", + RedirectURL: "https://pulse.example.com/oidc/callback", + Scopes: []string{"openid"}, + UsernameClaim: "sub", + EmailClaim: "email", + GroupsClaim: "groups", + }, + } + router := &Router{config: cfg} + + req := httptest.NewRequest(http.MethodGet, "/api/security/oidc", nil) + rr := httptest.NewRecorder() + + router.handleOIDCConfig(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) + } + + var resp oidcResponse + if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if !resp.Enabled { + t.Fatalf("expected OIDC to be enabled") + } + if resp.IssuerURL != cfg.OIDC.IssuerURL { + t.Fatalf("unexpected issuer url") + } + if !resp.ClientSecretSet { + t.Fatalf("expected client secret to be marked as set") + } +} + +func TestSecurityOIDCHandlers_UpdateSaveFailure(t *testing.T) { + cfg := &config.Config{ + PublicURL: "https://pulse.example.com", + OIDC: &config.OIDCConfig{ + ClientSecret: "old-secret", + CABundle: "old-bundle", + }, + } + router := &Router{config: cfg} + + payload := map[string]any{ + "enabled": true, + "issuerUrl": "https://issuer.example.com", + "clientId": "client-id", + "clientSecret": "new-secret", + "redirectUrl": "https://pulse.example.com/oidc/callback", + "scopes": []string{"openid"}, + "usernameClaim": "sub", + "emailClaim": "email", + "groupsClaim": "groups", + "clearClientSecret": true, + "caBundle": "new-bundle", + } + body, _ := json.Marshal(payload) + + req := httptest.NewRequest(http.MethodPut, "/api/security/oidc", bytes.NewReader(body)) + rr := httptest.NewRecorder() + + router.handleOIDCConfig(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Fatalf("expected status %d, got %d", http.StatusInternalServerError, rr.Code) + } + + var apiErr APIError + if err := json.NewDecoder(rr.Body).Decode(&apiErr); err != nil { + t.Fatalf("decode error response: %v", err) + } + if apiErr.Code != "save_failed" { + t.Fatalf("unexpected error code: %s", apiErr.Code) + } +}