diff --git a/internal/api/middleware_tenant_additional_test.go b/internal/api/middleware_tenant_additional_test.go
new file mode 100644
index 000000000..b4de3a153
--- /dev/null
+++ b/internal/api/middleware_tenant_additional_test.go
@@ -0,0 +1,139 @@
+package api
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/rcourtman/pulse-go-rewrite/internal/config"
+)
+
+func TestNewTenantMiddlewareWithConfig(t *testing.T) {
+ persistence := config.NewMultiTenantPersistence(t.TempDir())
+ checker := stubAuthorizationChecker{}
+
+ mw := NewTenantMiddlewareWithConfig(TenantMiddlewareConfig{
+ Persistence: persistence,
+ AuthChecker: checker,
+ })
+
+ if mw.persistence != persistence {
+ t.Fatalf("expected persistence to be set")
+ }
+ if mw.authChecker == nil {
+ t.Fatalf("expected auth checker to be set")
+ }
+}
+
+type stubAuthorizationChecker struct{}
+
+func (stubAuthorizationChecker) TokenCanAccessOrg(*config.APITokenRecord, string) bool {
+ return true
+}
+
+func (stubAuthorizationChecker) UserCanAccessOrg(string, string) bool {
+ return true
+}
+
+func (stubAuthorizationChecker) CheckAccess(*config.APITokenRecord, string, string) AuthorizationResult {
+ return AuthorizationResult{Allowed: true}
+}
+
+func TestWriteJSONError(t *testing.T) {
+ rec := httptest.NewRecorder()
+ writeJSONError(rec, http.StatusBadRequest, "bad", "message")
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("expected 400, got %d", rec.Code)
+ }
+
+ var payload map[string]string
+ if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if payload["error"] != "bad" || payload["message"] != "message" {
+ t.Fatalf("unexpected payload: %+v", payload)
+ }
+}
+
+func TestTenantMiddleware_OrgExtraction(t *testing.T) {
+ mw := NewTenantMiddleware(nil)
+
+ handler := mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ orgID := GetOrgID(r.Context())
+ org := GetOrganization(r.Context())
+ if orgID == "" || org == nil || org.ID != orgID {
+ t.Fatalf("unexpected org context: %q %+v", orgID, org)
+ }
+ }))
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("X-Pulse-Org-ID", "header-org")
+ rec := httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+
+ req = httptest.NewRequest(http.MethodGet, "/", nil)
+ req.AddCookie(&http.Cookie{Name: "pulse_org_id", Value: "cookie-org"})
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+
+ req = httptest.NewRequest(http.MethodGet, "/", nil)
+ rec = httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+}
+
+func TestTenantMiddleware_InvalidOrg(t *testing.T) {
+ persistence := config.NewMultiTenantPersistence(t.TempDir())
+ mw := NewTenantMiddleware(persistence)
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("X-Pulse-Org-ID", "../bad")
+ rec := httptest.NewRecorder()
+
+ mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ t.Fatalf("next handler should not be called")
+ })).ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("expected 400, got %d", rec.Code)
+ }
+}
+
+func TestTenantMiddleware_MultiTenantDisabled(t *testing.T) {
+ orig := IsMultiTenantEnabled()
+ SetMultiTenantEnabled(false)
+ t.Cleanup(func() { SetMultiTenantEnabled(orig) })
+
+ mw := NewTenantMiddleware(nil)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("X-Pulse-Org-ID", "tenant-1")
+ rec := httptest.NewRecorder()
+
+ mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ t.Fatalf("next handler should not be called")
+ })).ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusNotImplemented {
+ t.Fatalf("expected 501, got %d", rec.Code)
+ }
+}
+
+func TestTenantMiddleware_MultiTenantLicenseRequired(t *testing.T) {
+ orig := IsMultiTenantEnabled()
+ SetMultiTenantEnabled(true)
+ t.Cleanup(func() { SetMultiTenantEnabled(orig) })
+
+ mw := NewTenantMiddleware(nil)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("X-Pulse-Org-ID", "tenant-2")
+ rec := httptest.NewRecorder()
+
+ mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ t.Fatalf("next handler should not be called")
+ })).ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusPaymentRequired {
+ t.Fatalf("expected 402, got %d", rec.Code)
+ }
+}
diff --git a/internal/api/oidc_handlers_more_test.go b/internal/api/oidc_handlers_more_test.go
new file mode 100644
index 000000000..0295aecd0
--- /dev/null
+++ b/internal/api/oidc_handlers_more_test.go
@@ -0,0 +1,318 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/rcourtman/pulse-go-rewrite/internal/config"
+ "golang.org/x/oauth2"
+)
+
+func newTestOIDCConfig() *config.OIDCConfig {
+ return &config.OIDCConfig{
+ Enabled: true,
+ IssuerURL: "https://issuer.example.com",
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURL: "https://app.example.com/api/oidc/callback",
+ Scopes: []string{"openid", "email"},
+ UsernameClaim: "preferred_username",
+ EmailClaim: "email",
+ GroupsClaim: "groups",
+ }
+}
+
+func newTestOIDCService(cfg *config.OIDCConfig, authURL, tokenURL string) *OIDCService {
+ return &OIDCService{
+ snapshot: oidcSnapshot{
+ issuer: cfg.IssuerURL,
+ clientID: cfg.ClientID,
+ clientSecret: cfg.ClientSecret,
+ redirectURL: cfg.RedirectURL,
+ scopes: append([]string{}, cfg.Scopes...),
+ caBundle: cfg.CABundle,
+ },
+ oauth2Cfg: &oauth2.Config{
+ ClientID: cfg.ClientID,
+ ClientSecret: cfg.ClientSecret,
+ RedirectURL: cfg.RedirectURL,
+ Endpoint: oauth2.Endpoint{
+ AuthURL: authURL,
+ TokenURL: tokenURL,
+ },
+ Scopes: append([]string{}, cfg.Scopes...),
+ },
+ stateStore: newOIDCStateStore(),
+ }
+}
+
+func newOIDCRouterWithService(t *testing.T, authURL, tokenURL string) (*Router, *OIDCService) {
+ t.Helper()
+ cfg := newTestOIDCConfig()
+ svc := newTestOIDCService(cfg, authURL, tokenURL)
+ router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc}
+ t.Cleanup(func() {
+ if svc.stateStore != nil {
+ svc.stateStore.Stop()
+ }
+ })
+ return router, svc
+}
+
+func TestHandleOIDCLogin_MethodNotAllowed(t *testing.T) {
+ router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}
+ req := httptest.NewRequest(http.MethodPut, "/api/oidc/login", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCLogin(rec, req)
+
+ if rec.Code != http.StatusMethodNotAllowed {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
+ }
+}
+
+func TestHandleOIDCLogin_InvalidJSON(t *testing.T) {
+ router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
+ req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", strings.NewReader("{"))
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCLogin(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
+ }
+ var payload map[string]interface{}
+ if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if payload["code"] != "invalid_request" {
+ t.Fatalf("code = %v, want invalid_request", payload["code"])
+ }
+}
+
+func TestHandleOIDCLogin_GetSuccess(t *testing.T) {
+ router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
+ req := httptest.NewRequest(http.MethodGet, "/api/oidc/login?returnTo=/dashboard", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCLogin(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
+ }
+ location := rec.Header().Get("Location")
+ u, err := url.Parse(location)
+ if err != nil {
+ t.Fatalf("parse location: %v", err)
+ }
+ if u.Host != "auth.example.com" {
+ t.Fatalf("unexpected auth host: %q", u.Host)
+ }
+ state := u.Query().Get("state")
+ if state == "" {
+ t.Fatalf("expected state param in redirect")
+ }
+ entry, ok := svc.consumeState(state)
+ if !ok {
+ t.Fatalf("expected state entry to be stored")
+ }
+ if entry.ReturnTo != "/dashboard" {
+ t.Fatalf("returnTo = %q, want /dashboard", entry.ReturnTo)
+ }
+}
+
+func TestHandleOIDCLogin_PostSuccess(t *testing.T) {
+ router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
+ body := strings.NewReader(`{"returnTo":"/home"}`)
+ req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", body)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCLogin(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+ var payload struct {
+ AuthorizationURL string `json:"authorizationUrl"`
+ }
+ if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if payload.AuthorizationURL == "" {
+ t.Fatalf("expected authorizationUrl in response")
+ }
+ u, err := url.Parse(payload.AuthorizationURL)
+ if err != nil {
+ t.Fatalf("parse authorizationUrl: %v", err)
+ }
+ state := u.Query().Get("state")
+ if state == "" {
+ t.Fatalf("expected state param in authorizationUrl")
+ }
+ entry, ok := svc.consumeState(state)
+ if !ok {
+ t.Fatalf("expected state entry to be stored")
+ }
+ if entry.ReturnTo != "/home" {
+ t.Fatalf("returnTo = %q, want /home", entry.ReturnTo)
+ }
+}
+
+func TestGetOIDCService_ReturnsCachedService(t *testing.T) {
+ cfg := newTestOIDCConfig()
+ svc := newTestOIDCService(cfg, "https://auth.example.com/authorize", "https://token.example.com")
+ router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc}
+ defer svc.stateStore.Stop()
+
+ got, err := router.getOIDCService(context.Background(), cfg.RedirectURL)
+ if err != nil {
+ t.Fatalf("getOIDCService error: %v", err)
+ }
+ if got != svc {
+ t.Fatalf("expected cached service to be returned")
+ }
+}
+
+func TestHandleOIDCCallback_MethodNotAllowed(t *testing.T) {
+ router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}
+ req := httptest.NewRequest(http.MethodPost, "/api/oidc/callback", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCCallback(rec, req)
+
+ if rec.Code != http.StatusMethodNotAllowed {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
+ }
+}
+
+func TestHandleOIDCCallback_ErrorParam(t *testing.T) {
+ router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
+ req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?error=access_denied", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCCallback(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
+ }
+ location := rec.Header().Get("Location")
+ if !strings.Contains(location, "oidc_error=access_denied") {
+ t.Fatalf("unexpected redirect location: %q", location)
+ }
+}
+
+func TestHandleOIDCCallback_MissingState(t *testing.T) {
+ router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
+ req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?code=abc", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCCallback(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
+ }
+ location := rec.Header().Get("Location")
+ if !strings.Contains(location, "oidc_error=missing_state") {
+ t.Fatalf("unexpected redirect location: %q", location)
+ }
+}
+
+func TestHandleOIDCCallback_InvalidState(t *testing.T) {
+ router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
+ req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state=invalid", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCCallback(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
+ }
+ location := rec.Header().Get("Location")
+ if !strings.Contains(location, "oidc_error=invalid_state") {
+ t.Fatalf("unexpected redirect location: %q", location)
+ }
+}
+
+func TestHandleOIDCCallback_MissingCode(t *testing.T) {
+ router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
+ state, _, err := svc.newStateEntry("/dashboard")
+ if err != nil {
+ t.Fatalf("newStateEntry error: %v", err)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state), nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCCallback(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
+ }
+ location := rec.Header().Get("Location")
+ if !strings.Contains(location, "oidc_error=missing_code") {
+ t.Fatalf("unexpected redirect location: %q", location)
+ }
+ if !strings.HasPrefix(location, "/dashboard") {
+ t.Fatalf("expected redirect back to /dashboard, got %q", location)
+ }
+}
+
+func TestHandleOIDCCallback_ExchangeFailed(t *testing.T) {
+ tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer tokenServer.Close()
+
+ router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
+ svc.httpClient = tokenServer.Client()
+ state, _, err := svc.newStateEntry("/dashboard")
+ if err != nil {
+ t.Fatalf("newStateEntry error: %v", err)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCCallback(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
+ }
+ location := rec.Header().Get("Location")
+ if !strings.Contains(location, "oidc_error=exchange_failed") {
+ t.Fatalf("unexpected redirect location: %q", location)
+ }
+}
+
+func TestHandleOIDCCallback_MissingIDToken(t *testing.T) {
+ tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{"access_token":"access","token_type":"Bearer","expires_in":3600}`))
+ }))
+ defer tokenServer.Close()
+
+ router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
+ svc.httpClient = tokenServer.Client()
+ state, _, err := svc.newStateEntry("/dashboard")
+ if err != nil {
+ t.Fatalf("newStateEntry error: %v", err)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleOIDCCallback(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
+ }
+ location := rec.Header().Get("Location")
+ if !strings.Contains(location, "oidc_error=missing_id_token") {
+ t.Fatalf("unexpected redirect location: %q", location)
+ }
+}
diff --git a/internal/api/router_auth_additional_test.go b/internal/api/router_auth_additional_test.go
new file mode 100644
index 000000000..ceea12745
--- /dev/null
+++ b/internal/api/router_auth_additional_test.go
@@ -0,0 +1,161 @@
+package api
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/rcourtman/pulse-go-rewrite/internal/config"
+ "github.com/rcourtman/pulse-go-rewrite/pkg/auth"
+)
+
+func newAuthRouter(t *testing.T) *Router {
+ t.Helper()
+ hashed, err := auth.HashPassword("currentpassword")
+ if err != nil {
+ t.Fatalf("hash password: %v", err)
+ }
+ return &Router{
+ config: &config.Config{
+ AuthUser: "admin",
+ AuthPass: hashed,
+ ConfigPath: t.TempDir(),
+ },
+ }
+}
+
+func TestHandleChangePassword_InvalidJSON(t *testing.T) {
+ router := newAuthRouter(t)
+ req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader("{"))
+ rec := httptest.NewRecorder()
+
+ router.handleChangePassword(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
+ }
+ var payload map[string]interface{}
+ if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if payload["code"] != "invalid_request" {
+ t.Fatalf("expected invalid_request, got %#v", payload["code"])
+ }
+}
+
+func TestHandleChangePassword_InvalidPassword(t *testing.T) {
+ router := newAuthRouter(t)
+ body := `{"currentPassword":"currentpassword","newPassword":"short"}`
+ req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
+ rec := httptest.NewRecorder()
+
+ router.handleChangePassword(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
+ }
+ var payload map[string]interface{}
+ if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if payload["code"] != "invalid_password" {
+ t.Fatalf("expected invalid_password, got %#v", payload["code"])
+ }
+}
+
+func TestHandleChangePassword_MissingCurrent(t *testing.T) {
+ router := newAuthRouter(t)
+ body := `{"currentPassword":"","newPassword":"newpassword123"}`
+ req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
+ rec := httptest.NewRecorder()
+
+ router.handleChangePassword(rec, req)
+
+ if rec.Code != http.StatusUnauthorized {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
+ }
+ var payload map[string]interface{}
+ if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if payload["code"] != "unauthorized" {
+ t.Fatalf("expected unauthorized, got %#v", payload["code"])
+ }
+}
+
+func TestHandleChangePassword_SuccessDocker(t *testing.T) {
+ router := newAuthRouter(t)
+ t.Setenv("PULSE_DOCKER", "true")
+
+ body := `{"currentPassword":"currentpassword","newPassword":"newpassword123"}`
+ req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
+ rec := httptest.NewRecorder()
+
+ router.handleChangePassword(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+ var payload map[string]interface{}
+ if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if ok, _ := payload["success"].(bool); !ok {
+ t.Fatalf("expected success=true, got %#v", payload["success"])
+ }
+ envPath := filepath.Join(router.config.ConfigPath, ".env")
+ if _, err := os.Stat(envPath); err != nil {
+ t.Fatalf("expected .env to be written, got error: %v", err)
+ }
+}
+
+func TestHandleResetLockout_InvalidJSON(t *testing.T) {
+ router := newAuthRouter(t)
+ req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader("{"))
+ req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
+ rec := httptest.NewRecorder()
+
+ router.handleResetLockout(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
+ }
+}
+
+func TestHandleResetLockout_MissingIdentifier(t *testing.T) {
+ router := newAuthRouter(t)
+ req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader(`{"identifier":""}`))
+ req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
+ rec := httptest.NewRecorder()
+
+ router.handleResetLockout(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
+ }
+}
+
+func TestHandleResetLockout_Success(t *testing.T) {
+ router := newAuthRouter(t)
+ req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader(`{"identifier":"user1"}`))
+ req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
+ rec := httptest.NewRecorder()
+
+ router.handleResetLockout(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+ var payload map[string]interface{}
+ if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if ok, _ := payload["success"].(bool); !ok {
+ t.Fatalf("expected success=true, got %#v", payload["success"])
+ }
+}
diff --git a/internal/api/saml_handlers_more_test.go b/internal/api/saml_handlers_more_test.go
new file mode 100644
index 000000000..0ffd08acc
--- /dev/null
+++ b/internal/api/saml_handlers_more_test.go
@@ -0,0 +1,217 @@
+package api
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/rcourtman/pulse-go-rewrite/internal/config"
+)
+
+func newTestSAMLService(t *testing.T, providerID string, metadataXML string) *SAMLService {
+ t.Helper()
+ service, err := NewSAMLService(context.Background(), providerID, &config.SAMLProviderConfig{
+ IDPMetadataXML: metadataXML,
+ }, "https://pulse.example.com")
+ if err != nil {
+ t.Fatalf("NewSAMLService: %v", err)
+ }
+ return service
+}
+
+func TestHandleSAMLACS_ProcessResponseError(t *testing.T) {
+ router := newSAMLRouter(t, testSAMLProvider("okta", true))
+ router.samlManager.services["okta"] = &SAMLService{}
+
+ req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleSAMLACS(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
+ }
+ if loc := rec.Header().Get("Location"); !strings.Contains(loc, "saml_error=saml_validation_failed") {
+ t.Fatalf("expected validation failed redirect, got %q", loc)
+ }
+}
+
+func TestHandleSAMLMetadata_InvalidMethod(t *testing.T) {
+ router := newSAMLRouter(t, testSAMLProvider("okta", true))
+ req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/metadata", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleSAMLMetadata(rec, req)
+
+ if rec.Code != http.StatusMethodNotAllowed {
+ t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
+ }
+}
+
+func TestHandleSAMLMetadata_InvalidProviderID(t *testing.T) {
+ router := newSAMLRouter(t, testSAMLProvider("okta", true))
+ req := httptest.NewRequest(http.MethodGet, "/api/saml/invalid$id/metadata", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleSAMLMetadata(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
+ }
+}
+
+func TestGetSAMLSessionInfo_NoCookie(t *testing.T) {
+ router := &Router{}
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+
+ if info := router.getSAMLSessionInfo(req); info != nil {
+ t.Fatalf("expected nil session info without cookie")
+ }
+}
+
+func TestGetSAMLSessionInfo_ReturnsInfo(t *testing.T) {
+ InitSessionStore(t.TempDir())
+
+ token := generateSessionToken()
+ GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
+ ProviderID: "okta",
+ NameID: "name-id",
+ SessionIndex: "sess-1",
+ })
+
+ router := &Router{}
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
+
+ info := router.getSAMLSessionInfo(req)
+ if info == nil {
+ t.Fatalf("expected session info")
+ }
+ if info.ProviderID != "okta" || info.NameID != "name-id" || info.SessionIndex != "sess-1" {
+ t.Fatalf("unexpected session info: %#v", info)
+ }
+}
+
+func TestClearSession(t *testing.T) {
+ router := &Router{}
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ router.clearSession(rec, req)
+
+ cookies := rec.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("expected 1 cookie, got %d", len(cookies))
+ }
+ cookie := cookies[0]
+ if cookie.Name != "session" {
+ t.Fatalf("expected session cookie name, got %q", cookie.Name)
+ }
+ if cookie.MaxAge != -1 {
+ t.Fatalf("expected MaxAge -1, got %d", cookie.MaxAge)
+ }
+ if !cookie.HttpOnly {
+ t.Fatalf("expected HttpOnly cookie")
+ }
+}
+
+func TestHandleSAMLSLO_Redirects(t *testing.T) {
+ router := &Router{}
+ req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/slo", nil)
+ rec := httptest.NewRecorder()
+
+ router.handleSAMLSLO(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
+ }
+ if loc := rec.Header().Get("Location"); loc != "/?logout=success" {
+ t.Fatalf("unexpected redirect location %q", loc)
+ }
+}
+
+func TestHandleSAMLLogout_SLOUnavailable(t *testing.T) {
+ InitSessionStore(t.TempDir())
+
+ router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
+ metadataXML := `
+
+
+
+
+`
+ router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML)
+
+ token := generateSessionToken()
+ GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
+ ProviderID: "okta",
+ NameID: "name-id",
+ SessionIndex: "sess-1",
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil)
+ req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
+ rec := httptest.NewRecorder()
+
+ router.handleSAMLLogout(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
+ }
+ if loc := rec.Header().Get("Location"); loc != "/?logout=success" {
+ t.Fatalf("unexpected redirect location %q", loc)
+ }
+}
+
+func TestHandleSAMLLogout_SLOSuccess(t *testing.T) {
+ InitSessionStore(t.TempDir())
+
+ router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
+ metadataXML := `
+
+
+
+
+
+`
+ router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML)
+
+ token := generateSessionToken()
+ GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
+ ProviderID: "okta",
+ NameID: "name-id",
+ SessionIndex: "sess-1",
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil)
+ req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
+ rec := httptest.NewRecorder()
+
+ router.handleSAMLLogout(rec, req)
+
+ if rec.Code != http.StatusFound {
+ t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
+ }
+ loc := rec.Header().Get("Location")
+ if !strings.Contains(loc, "https://idp.example.com/slo") || !strings.Contains(loc, "SAMLRequest=") {
+ t.Fatalf("unexpected SLO redirect location %q", loc)
+ }
+}
+
+func TestExtractSAMLProviderID(t *testing.T) {
+ if got := extractSAMLProviderID("/api/saml/okta/login", "login"); got != "okta" {
+ t.Fatalf("expected okta, got %q", got)
+ }
+ if got := extractSAMLProviderID("/api/saml/okta/logout", "login"); got != "" {
+ t.Fatalf("expected empty provider, got %q", got)
+ }
+ if got := extractSAMLProviderID("/api/saml/okta/login/extra", "login"); got != "okta" {
+ t.Fatalf("expected okta for extra path, got %q", got)
+ }
+ if got := extractSAMLProviderID("/api/other/okta/login", "login"); got != "" {
+ t.Fatalf("expected empty provider for non-saml path, got %q", got)
+ }
+}
diff --git a/internal/api/saml_service_additional_test.go b/internal/api/saml_service_additional_test.go
new file mode 100644
index 000000000..9c85987d8
--- /dev/null
+++ b/internal/api/saml_service_additional_test.go
@@ -0,0 +1,133 @@
+package api
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/crewjam/saml"
+ "github.com/rcourtman/pulse-go-rewrite/internal/config"
+)
+
+func TestParseIDPMetadataXML_EmptyEntities(t *testing.T) {
+ wrapped := `
+
+`
+ if _, err := parseIDPMetadataXML([]byte(wrapped)); err == nil {
+ t.Fatal("expected error for empty entities descriptor")
+ }
+}
+
+func TestExtractAttribute(t *testing.T) {
+ service := &SAMLService{}
+ attrs := map[string][]string{
+ "email": {"user@example.com"},
+ }
+
+ if got := service.extractAttribute(attrs, "", "fallback"); got != "fallback" {
+ t.Fatalf("expected fallback value, got %q", got)
+ }
+ if got := service.extractAttribute(attrs, "email", ""); got != "user@example.com" {
+ t.Fatalf("unexpected attribute value: %q", got)
+ }
+ if got := service.extractAttribute(attrs, "missing", "default"); got != "default" {
+ t.Fatalf("unexpected missing attribute value: %q", got)
+ }
+}
+
+func TestProcessResponse_InvalidResponse(t *testing.T) {
+ service := &SAMLService{
+ sp: &saml.ServiceProvider{},
+ }
+
+ body := strings.NewReader("RelayState=/dashboard&SAMLResponse=")
+ req := httptest.NewRequest(http.MethodPost, "/acs", body)
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ _, relay, err := service.ProcessResponse(req)
+ if err == nil {
+ t.Fatal("expected error for invalid response")
+ }
+ if relay != "/dashboard" {
+ t.Fatalf("unexpected relay state: %q", relay)
+ }
+}
+
+func TestSAMLServiceIdentifiers(t *testing.T) {
+ service := &SAMLService{
+ providerID: "provider-1",
+ sp: &saml.ServiceProvider{EntityID: "sp-entity"},
+ idpMetadata: &saml.EntityDescriptor{
+ EntityID: "idp-entity",
+ },
+ }
+
+ if service.ProviderID() != "provider-1" {
+ t.Fatalf("unexpected provider id")
+ }
+ if service.GetSPEntityID() != "sp-entity" {
+ t.Fatalf("unexpected sp entity id")
+ }
+ if service.GetIDPEntityID() != "idp-entity" {
+ t.Fatalf("unexpected idp entity id")
+ }
+
+ service.sp = nil
+ if service.GetSPEntityID() != "" {
+ t.Fatalf("expected empty sp entity id when nil")
+ }
+ service.idpMetadata = nil
+ if service.GetIDPEntityID() != "" {
+ t.Fatalf("expected empty idp entity id when nil")
+ }
+}
+
+func TestRefreshMetadata_Success(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`
+
+
+`))
+ }))
+ defer server.Close()
+
+ cfg := &config.SAMLProviderConfig{IDPMetadataURL: server.URL}
+ service := &SAMLService{
+ providerID: "idp",
+ config: cfg,
+ baseURL: "http://localhost",
+ httpClient: server.Client(),
+ }
+
+ if err := service.RefreshMetadata(context.Background()); err != nil {
+ t.Fatalf("refresh metadata: %v", err)
+ }
+ if service.sp == nil {
+ t.Fatal("expected service provider to be initialized")
+ }
+}
+
+func TestFetchIDPMetadataFromURL_NonOK(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer server.Close()
+
+ service := &SAMLService{httpClient: server.Client()}
+ if _, err := service.fetchIDPMetadataFromURL(context.Background(), server.URL); err == nil {
+ t.Fatal("expected error for non-200 status")
+ }
+}
+
+func TestAddIDPCertificate_InvalidPEM(t *testing.T) {
+ service := &SAMLService{config: &config.SAMLProviderConfig{IDPCertificate: "not-pem"}}
+ metadata := &saml.EntityDescriptor{
+ IDPSSODescriptors: []saml.IDPSSODescriptor{{}},
+ }
+ if err := service.addIDPCertificate(metadata); err == nil {
+ t.Fatal("expected error for invalid certificate pem")
+ }
+}
diff --git a/internal/api/security_oidc_handlers_additional_test.go b/internal/api/security_oidc_handlers_additional_test.go
new file mode 100644
index 000000000..072f1c841
--- /dev/null
+++ b/internal/api/security_oidc_handlers_additional_test.go
@@ -0,0 +1,95 @@
+package api
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/rcourtman/pulse-go-rewrite/internal/config"
+)
+
+func TestSecurityOIDCHandlers_GetConfig(t *testing.T) {
+ cfg := &config.Config{
+ PublicURL: "https://pulse.example.com",
+ OIDC: &config.OIDCConfig{
+ Enabled: true,
+ IssuerURL: "https://issuer.example.com",
+ ClientID: "client-id",
+ ClientSecret: "super-secret",
+ RedirectURL: "https://pulse.example.com/oidc/callback",
+ Scopes: []string{"openid"},
+ UsernameClaim: "sub",
+ EmailClaim: "email",
+ GroupsClaim: "groups",
+ },
+ }
+ router := &Router{config: cfg}
+
+ req := httptest.NewRequest(http.MethodGet, "/api/security/oidc", nil)
+ rr := httptest.NewRecorder()
+
+ router.handleOIDCConfig(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
+ }
+
+ var resp oidcResponse
+ if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+ if !resp.Enabled {
+ t.Fatalf("expected OIDC to be enabled")
+ }
+ if resp.IssuerURL != cfg.OIDC.IssuerURL {
+ t.Fatalf("unexpected issuer url")
+ }
+ if !resp.ClientSecretSet {
+ t.Fatalf("expected client secret to be marked as set")
+ }
+}
+
+func TestSecurityOIDCHandlers_UpdateSaveFailure(t *testing.T) {
+ cfg := &config.Config{
+ PublicURL: "https://pulse.example.com",
+ OIDC: &config.OIDCConfig{
+ ClientSecret: "old-secret",
+ CABundle: "old-bundle",
+ },
+ }
+ router := &Router{config: cfg}
+
+ payload := map[string]any{
+ "enabled": true,
+ "issuerUrl": "https://issuer.example.com",
+ "clientId": "client-id",
+ "clientSecret": "new-secret",
+ "redirectUrl": "https://pulse.example.com/oidc/callback",
+ "scopes": []string{"openid"},
+ "usernameClaim": "sub",
+ "emailClaim": "email",
+ "groupsClaim": "groups",
+ "clearClientSecret": true,
+ "caBundle": "new-bundle",
+ }
+ body, _ := json.Marshal(payload)
+
+ req := httptest.NewRequest(http.MethodPut, "/api/security/oidc", bytes.NewReader(body))
+ rr := httptest.NewRecorder()
+
+ router.handleOIDCConfig(rr, req)
+
+ if rr.Code != http.StatusInternalServerError {
+ t.Fatalf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
+ }
+
+ var apiErr APIError
+ if err := json.NewDecoder(rr.Body).Decode(&apiErr); err != nil {
+ t.Fatalf("decode error response: %v", err)
+ }
+ if apiErr.Code != "save_failed" {
+ t.Fatalf("unexpected error code: %s", apiErr.Code)
+ }
+}