Test: add coverage for auth and security handlers

Add additional tests for OIDC, SAML, and tenant middleware to improve coverage of security-critical paths.
This commit is contained in:
rcourtman
2026-02-02 22:02:11 +00:00
parent 97a985efb8
commit 43d7fffeef
6 changed files with 1063 additions and 0 deletions

View File

@@ -0,0 +1,139 @@
package api
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
)
func TestNewTenantMiddlewareWithConfig(t *testing.T) {
persistence := config.NewMultiTenantPersistence(t.TempDir())
checker := stubAuthorizationChecker{}
mw := NewTenantMiddlewareWithConfig(TenantMiddlewareConfig{
Persistence: persistence,
AuthChecker: checker,
})
if mw.persistence != persistence {
t.Fatalf("expected persistence to be set")
}
if mw.authChecker == nil {
t.Fatalf("expected auth checker to be set")
}
}
type stubAuthorizationChecker struct{}
func (stubAuthorizationChecker) TokenCanAccessOrg(*config.APITokenRecord, string) bool {
return true
}
func (stubAuthorizationChecker) UserCanAccessOrg(string, string) bool {
return true
}
func (stubAuthorizationChecker) CheckAccess(*config.APITokenRecord, string, string) AuthorizationResult {
return AuthorizationResult{Allowed: true}
}
func TestWriteJSONError(t *testing.T) {
rec := httptest.NewRecorder()
writeJSONError(rec, http.StatusBadRequest, "bad", "message")
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", rec.Code)
}
var payload map[string]string
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if payload["error"] != "bad" || payload["message"] != "message" {
t.Fatalf("unexpected payload: %+v", payload)
}
}
func TestTenantMiddleware_OrgExtraction(t *testing.T) {
mw := NewTenantMiddleware(nil)
handler := mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orgID := GetOrgID(r.Context())
org := GetOrganization(r.Context())
if orgID == "" || org == nil || org.ID != orgID {
t.Fatalf("unexpected org context: %q %+v", orgID, org)
}
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Pulse-Org-ID", "header-org")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{Name: "pulse_org_id", Value: "cookie-org"})
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
}
func TestTenantMiddleware_InvalidOrg(t *testing.T) {
persistence := config.NewMultiTenantPersistence(t.TempDir())
mw := NewTenantMiddleware(persistence)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Pulse-Org-ID", "../bad")
rec := httptest.NewRecorder()
mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next handler should not be called")
})).ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", rec.Code)
}
}
func TestTenantMiddleware_MultiTenantDisabled(t *testing.T) {
orig := IsMultiTenantEnabled()
SetMultiTenantEnabled(false)
t.Cleanup(func() { SetMultiTenantEnabled(orig) })
mw := NewTenantMiddleware(nil)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Pulse-Org-ID", "tenant-1")
rec := httptest.NewRecorder()
mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next handler should not be called")
})).ServeHTTP(rec, req)
if rec.Code != http.StatusNotImplemented {
t.Fatalf("expected 501, got %d", rec.Code)
}
}
func TestTenantMiddleware_MultiTenantLicenseRequired(t *testing.T) {
orig := IsMultiTenantEnabled()
SetMultiTenantEnabled(true)
t.Cleanup(func() { SetMultiTenantEnabled(orig) })
mw := NewTenantMiddleware(nil)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Pulse-Org-ID", "tenant-2")
rec := httptest.NewRecorder()
mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next handler should not be called")
})).ServeHTTP(rec, req)
if rec.Code != http.StatusPaymentRequired {
t.Fatalf("expected 402, got %d", rec.Code)
}
}

View File

@@ -0,0 +1,318 @@
package api
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"golang.org/x/oauth2"
)
func newTestOIDCConfig() *config.OIDCConfig {
return &config.OIDCConfig{
Enabled: true,
IssuerURL: "https://issuer.example.com",
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "https://app.example.com/api/oidc/callback",
Scopes: []string{"openid", "email"},
UsernameClaim: "preferred_username",
EmailClaim: "email",
GroupsClaim: "groups",
}
}
func newTestOIDCService(cfg *config.OIDCConfig, authURL, tokenURL string) *OIDCService {
return &OIDCService{
snapshot: oidcSnapshot{
issuer: cfg.IssuerURL,
clientID: cfg.ClientID,
clientSecret: cfg.ClientSecret,
redirectURL: cfg.RedirectURL,
scopes: append([]string{}, cfg.Scopes...),
caBundle: cfg.CABundle,
},
oauth2Cfg: &oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
RedirectURL: cfg.RedirectURL,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenURL,
},
Scopes: append([]string{}, cfg.Scopes...),
},
stateStore: newOIDCStateStore(),
}
}
func newOIDCRouterWithService(t *testing.T, authURL, tokenURL string) (*Router, *OIDCService) {
t.Helper()
cfg := newTestOIDCConfig()
svc := newTestOIDCService(cfg, authURL, tokenURL)
router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc}
t.Cleanup(func() {
if svc.stateStore != nil {
svc.stateStore.Stop()
}
})
return router, svc
}
func TestHandleOIDCLogin_MethodNotAllowed(t *testing.T) {
router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}
req := httptest.NewRequest(http.MethodPut, "/api/oidc/login", nil)
rec := httptest.NewRecorder()
router.handleOIDCLogin(rec, req)
if rec.Code != http.StatusMethodNotAllowed {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
}
}
func TestHandleOIDCLogin_InvalidJSON(t *testing.T) {
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", strings.NewReader("{"))
rec := httptest.NewRecorder()
router.handleOIDCLogin(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
var payload map[string]interface{}
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if payload["code"] != "invalid_request" {
t.Fatalf("code = %v, want invalid_request", payload["code"])
}
}
func TestHandleOIDCLogin_GetSuccess(t *testing.T) {
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
req := httptest.NewRequest(http.MethodGet, "/api/oidc/login?returnTo=/dashboard", nil)
rec := httptest.NewRecorder()
router.handleOIDCLogin(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
}
location := rec.Header().Get("Location")
u, err := url.Parse(location)
if err != nil {
t.Fatalf("parse location: %v", err)
}
if u.Host != "auth.example.com" {
t.Fatalf("unexpected auth host: %q", u.Host)
}
state := u.Query().Get("state")
if state == "" {
t.Fatalf("expected state param in redirect")
}
entry, ok := svc.consumeState(state)
if !ok {
t.Fatalf("expected state entry to be stored")
}
if entry.ReturnTo != "/dashboard" {
t.Fatalf("returnTo = %q, want /dashboard", entry.ReturnTo)
}
}
func TestHandleOIDCLogin_PostSuccess(t *testing.T) {
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
body := strings.NewReader(`{"returnTo":"/home"}`)
req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", body)
rec := httptest.NewRecorder()
router.handleOIDCLogin(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var payload struct {
AuthorizationURL string `json:"authorizationUrl"`
}
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if payload.AuthorizationURL == "" {
t.Fatalf("expected authorizationUrl in response")
}
u, err := url.Parse(payload.AuthorizationURL)
if err != nil {
t.Fatalf("parse authorizationUrl: %v", err)
}
state := u.Query().Get("state")
if state == "" {
t.Fatalf("expected state param in authorizationUrl")
}
entry, ok := svc.consumeState(state)
if !ok {
t.Fatalf("expected state entry to be stored")
}
if entry.ReturnTo != "/home" {
t.Fatalf("returnTo = %q, want /home", entry.ReturnTo)
}
}
func TestGetOIDCService_ReturnsCachedService(t *testing.T) {
cfg := newTestOIDCConfig()
svc := newTestOIDCService(cfg, "https://auth.example.com/authorize", "https://token.example.com")
router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc}
defer svc.stateStore.Stop()
got, err := router.getOIDCService(context.Background(), cfg.RedirectURL)
if err != nil {
t.Fatalf("getOIDCService error: %v", err)
}
if got != svc {
t.Fatalf("expected cached service to be returned")
}
}
func TestHandleOIDCCallback_MethodNotAllowed(t *testing.T) {
router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}
req := httptest.NewRequest(http.MethodPost, "/api/oidc/callback", nil)
rec := httptest.NewRecorder()
router.handleOIDCCallback(rec, req)
if rec.Code != http.StatusMethodNotAllowed {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
}
}
func TestHandleOIDCCallback_ErrorParam(t *testing.T) {
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?error=access_denied", nil)
rec := httptest.NewRecorder()
router.handleOIDCCallback(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
}
location := rec.Header().Get("Location")
if !strings.Contains(location, "oidc_error=access_denied") {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestHandleOIDCCallback_MissingState(t *testing.T) {
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?code=abc", nil)
rec := httptest.NewRecorder()
router.handleOIDCCallback(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
}
location := rec.Header().Get("Location")
if !strings.Contains(location, "oidc_error=missing_state") {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestHandleOIDCCallback_InvalidState(t *testing.T) {
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state=invalid", nil)
rec := httptest.NewRecorder()
router.handleOIDCCallback(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
}
location := rec.Header().Get("Location")
if !strings.Contains(location, "oidc_error=invalid_state") {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestHandleOIDCCallback_MissingCode(t *testing.T) {
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
state, _, err := svc.newStateEntry("/dashboard")
if err != nil {
t.Fatalf("newStateEntry error: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
router.handleOIDCCallback(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
}
location := rec.Header().Get("Location")
if !strings.Contains(location, "oidc_error=missing_code") {
t.Fatalf("unexpected redirect location: %q", location)
}
if !strings.HasPrefix(location, "/dashboard") {
t.Fatalf("expected redirect back to /dashboard, got %q", location)
}
}
func TestHandleOIDCCallback_ExchangeFailed(t *testing.T) {
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer tokenServer.Close()
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
svc.httpClient = tokenServer.Client()
state, _, err := svc.newStateEntry("/dashboard")
if err != nil {
t.Fatalf("newStateEntry error: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
rec := httptest.NewRecorder()
router.handleOIDCCallback(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
}
location := rec.Header().Get("Location")
if !strings.Contains(location, "oidc_error=exchange_failed") {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestHandleOIDCCallback_MissingIDToken(t *testing.T) {
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":"access","token_type":"Bearer","expires_in":3600}`))
}))
defer tokenServer.Close()
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
svc.httpClient = tokenServer.Client()
state, _, err := svc.newStateEntry("/dashboard")
if err != nil {
t.Fatalf("newStateEntry error: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
rec := httptest.NewRecorder()
router.handleOIDCCallback(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
}
location := rec.Header().Get("Location")
if !strings.Contains(location, "oidc_error=missing_id_token") {
t.Fatalf("unexpected redirect location: %q", location)
}
}

View File

@@ -0,0 +1,161 @@
package api
import (
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"github.com/rcourtman/pulse-go-rewrite/pkg/auth"
)
func newAuthRouter(t *testing.T) *Router {
t.Helper()
hashed, err := auth.HashPassword("currentpassword")
if err != nil {
t.Fatalf("hash password: %v", err)
}
return &Router{
config: &config.Config{
AuthUser: "admin",
AuthPass: hashed,
ConfigPath: t.TempDir(),
},
}
}
func TestHandleChangePassword_InvalidJSON(t *testing.T) {
router := newAuthRouter(t)
req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader("{"))
rec := httptest.NewRecorder()
router.handleChangePassword(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
var payload map[string]interface{}
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if payload["code"] != "invalid_request" {
t.Fatalf("expected invalid_request, got %#v", payload["code"])
}
}
func TestHandleChangePassword_InvalidPassword(t *testing.T) {
router := newAuthRouter(t)
body := `{"currentPassword":"currentpassword","newPassword":"short"}`
req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
rec := httptest.NewRecorder()
router.handleChangePassword(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
var payload map[string]interface{}
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if payload["code"] != "invalid_password" {
t.Fatalf("expected invalid_password, got %#v", payload["code"])
}
}
func TestHandleChangePassword_MissingCurrent(t *testing.T) {
router := newAuthRouter(t)
body := `{"currentPassword":"","newPassword":"newpassword123"}`
req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
rec := httptest.NewRecorder()
router.handleChangePassword(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
var payload map[string]interface{}
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if payload["code"] != "unauthorized" {
t.Fatalf("expected unauthorized, got %#v", payload["code"])
}
}
func TestHandleChangePassword_SuccessDocker(t *testing.T) {
router := newAuthRouter(t)
t.Setenv("PULSE_DOCKER", "true")
body := `{"currentPassword":"currentpassword","newPassword":"newpassword123"}`
req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
rec := httptest.NewRecorder()
router.handleChangePassword(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var payload map[string]interface{}
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if ok, _ := payload["success"].(bool); !ok {
t.Fatalf("expected success=true, got %#v", payload["success"])
}
envPath := filepath.Join(router.config.ConfigPath, ".env")
if _, err := os.Stat(envPath); err != nil {
t.Fatalf("expected .env to be written, got error: %v", err)
}
}
func TestHandleResetLockout_InvalidJSON(t *testing.T) {
router := newAuthRouter(t)
req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader("{"))
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
rec := httptest.NewRecorder()
router.handleResetLockout(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestHandleResetLockout_MissingIdentifier(t *testing.T) {
router := newAuthRouter(t)
req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader(`{"identifier":""}`))
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
rec := httptest.NewRecorder()
router.handleResetLockout(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestHandleResetLockout_Success(t *testing.T) {
router := newAuthRouter(t)
req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader(`{"identifier":"user1"}`))
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
rec := httptest.NewRecorder()
router.handleResetLockout(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var payload map[string]interface{}
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if ok, _ := payload["success"].(bool); !ok {
t.Fatalf("expected success=true, got %#v", payload["success"])
}
}

View File

@@ -0,0 +1,217 @@
package api
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
)
func newTestSAMLService(t *testing.T, providerID string, metadataXML string) *SAMLService {
t.Helper()
service, err := NewSAMLService(context.Background(), providerID, &config.SAMLProviderConfig{
IDPMetadataXML: metadataXML,
}, "https://pulse.example.com")
if err != nil {
t.Fatalf("NewSAMLService: %v", err)
}
return service
}
func TestHandleSAMLACS_ProcessResponseError(t *testing.T) {
router := newSAMLRouter(t, testSAMLProvider("okta", true))
router.samlManager.services["okta"] = &SAMLService{}
req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil)
rec := httptest.NewRecorder()
router.handleSAMLACS(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
}
if loc := rec.Header().Get("Location"); !strings.Contains(loc, "saml_error=saml_validation_failed") {
t.Fatalf("expected validation failed redirect, got %q", loc)
}
}
func TestHandleSAMLMetadata_InvalidMethod(t *testing.T) {
router := newSAMLRouter(t, testSAMLProvider("okta", true))
req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/metadata", nil)
rec := httptest.NewRecorder()
router.handleSAMLMetadata(rec, req)
if rec.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
}
}
func TestHandleSAMLMetadata_InvalidProviderID(t *testing.T) {
router := newSAMLRouter(t, testSAMLProvider("okta", true))
req := httptest.NewRequest(http.MethodGet, "/api/saml/invalid$id/metadata", nil)
rec := httptest.NewRecorder()
router.handleSAMLMetadata(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
}
}
func TestGetSAMLSessionInfo_NoCookie(t *testing.T) {
router := &Router{}
req := httptest.NewRequest(http.MethodGet, "/", nil)
if info := router.getSAMLSessionInfo(req); info != nil {
t.Fatalf("expected nil session info without cookie")
}
}
func TestGetSAMLSessionInfo_ReturnsInfo(t *testing.T) {
InitSessionStore(t.TempDir())
token := generateSessionToken()
GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
ProviderID: "okta",
NameID: "name-id",
SessionIndex: "sess-1",
})
router := &Router{}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
info := router.getSAMLSessionInfo(req)
if info == nil {
t.Fatalf("expected session info")
}
if info.ProviderID != "okta" || info.NameID != "name-id" || info.SessionIndex != "sess-1" {
t.Fatalf("unexpected session info: %#v", info)
}
}
func TestClearSession(t *testing.T) {
router := &Router{}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
router.clearSession(rec, req)
cookies := rec.Result().Cookies()
if len(cookies) != 1 {
t.Fatalf("expected 1 cookie, got %d", len(cookies))
}
cookie := cookies[0]
if cookie.Name != "session" {
t.Fatalf("expected session cookie name, got %q", cookie.Name)
}
if cookie.MaxAge != -1 {
t.Fatalf("expected MaxAge -1, got %d", cookie.MaxAge)
}
if !cookie.HttpOnly {
t.Fatalf("expected HttpOnly cookie")
}
}
func TestHandleSAMLSLO_Redirects(t *testing.T) {
router := &Router{}
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/slo", nil)
rec := httptest.NewRecorder()
router.handleSAMLSLO(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
}
if loc := rec.Header().Get("Location"); loc != "/?logout=success" {
t.Fatalf("unexpected redirect location %q", loc)
}
}
func TestHandleSAMLLogout_SLOUnavailable(t *testing.T) {
InitSessionStore(t.TempDir())
router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
metadataXML := `<?xml version="1.0"?>
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp">
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
</IDPSSODescriptor>
</EntityDescriptor>`
router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML)
token := generateSessionToken()
GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
ProviderID: "okta",
NameID: "name-id",
SessionIndex: "sess-1",
})
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil)
req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
rec := httptest.NewRecorder()
router.handleSAMLLogout(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
}
if loc := rec.Header().Get("Location"); loc != "/?logout=success" {
t.Fatalf("unexpected redirect location %q", loc)
}
}
func TestHandleSAMLLogout_SLOSuccess(t *testing.T) {
InitSessionStore(t.TempDir())
router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
metadataXML := `<?xml version="1.0"?>
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp">
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
<SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
</IDPSSODescriptor>
</EntityDescriptor>`
router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML)
token := generateSessionToken()
GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
ProviderID: "okta",
NameID: "name-id",
SessionIndex: "sess-1",
})
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil)
req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
rec := httptest.NewRecorder()
router.handleSAMLLogout(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
}
loc := rec.Header().Get("Location")
if !strings.Contains(loc, "https://idp.example.com/slo") || !strings.Contains(loc, "SAMLRequest=") {
t.Fatalf("unexpected SLO redirect location %q", loc)
}
}
func TestExtractSAMLProviderID(t *testing.T) {
if got := extractSAMLProviderID("/api/saml/okta/login", "login"); got != "okta" {
t.Fatalf("expected okta, got %q", got)
}
if got := extractSAMLProviderID("/api/saml/okta/logout", "login"); got != "" {
t.Fatalf("expected empty provider, got %q", got)
}
if got := extractSAMLProviderID("/api/saml/okta/login/extra", "login"); got != "okta" {
t.Fatalf("expected okta for extra path, got %q", got)
}
if got := extractSAMLProviderID("/api/other/okta/login", "login"); got != "" {
t.Fatalf("expected empty provider for non-saml path, got %q", got)
}
}

View File

@@ -0,0 +1,133 @@
package api
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/crewjam/saml"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
)
func TestParseIDPMetadataXML_EmptyEntities(t *testing.T) {
wrapped := `<?xml version="1.0"?>
<EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata">
</EntitiesDescriptor>`
if _, err := parseIDPMetadataXML([]byte(wrapped)); err == nil {
t.Fatal("expected error for empty entities descriptor")
}
}
func TestExtractAttribute(t *testing.T) {
service := &SAMLService{}
attrs := map[string][]string{
"email": {"user@example.com"},
}
if got := service.extractAttribute(attrs, "", "fallback"); got != "fallback" {
t.Fatalf("expected fallback value, got %q", got)
}
if got := service.extractAttribute(attrs, "email", ""); got != "user@example.com" {
t.Fatalf("unexpected attribute value: %q", got)
}
if got := service.extractAttribute(attrs, "missing", "default"); got != "default" {
t.Fatalf("unexpected missing attribute value: %q", got)
}
}
func TestProcessResponse_InvalidResponse(t *testing.T) {
service := &SAMLService{
sp: &saml.ServiceProvider{},
}
body := strings.NewReader("RelayState=/dashboard&SAMLResponse=")
req := httptest.NewRequest(http.MethodPost, "/acs", body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
_, relay, err := service.ProcessResponse(req)
if err == nil {
t.Fatal("expected error for invalid response")
}
if relay != "/dashboard" {
t.Fatalf("unexpected relay state: %q", relay)
}
}
func TestSAMLServiceIdentifiers(t *testing.T) {
service := &SAMLService{
providerID: "provider-1",
sp: &saml.ServiceProvider{EntityID: "sp-entity"},
idpMetadata: &saml.EntityDescriptor{
EntityID: "idp-entity",
},
}
if service.ProviderID() != "provider-1" {
t.Fatalf("unexpected provider id")
}
if service.GetSPEntityID() != "sp-entity" {
t.Fatalf("unexpected sp entity id")
}
if service.GetIDPEntityID() != "idp-entity" {
t.Fatalf("unexpected idp entity id")
}
service.sp = nil
if service.GetSPEntityID() != "" {
t.Fatalf("expected empty sp entity id when nil")
}
service.idpMetadata = nil
if service.GetIDPEntityID() != "" {
t.Fatalf("expected empty idp entity id when nil")
}
}
func TestRefreshMetadata_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<?xml version="1.0"?>
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp-refresh">
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"></IDPSSODescriptor>
</EntityDescriptor>`))
}))
defer server.Close()
cfg := &config.SAMLProviderConfig{IDPMetadataURL: server.URL}
service := &SAMLService{
providerID: "idp",
config: cfg,
baseURL: "http://localhost",
httpClient: server.Client(),
}
if err := service.RefreshMetadata(context.Background()); err != nil {
t.Fatalf("refresh metadata: %v", err)
}
if service.sp == nil {
t.Fatal("expected service provider to be initialized")
}
}
func TestFetchIDPMetadataFromURL_NonOK(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
service := &SAMLService{httpClient: server.Client()}
if _, err := service.fetchIDPMetadataFromURL(context.Background(), server.URL); err == nil {
t.Fatal("expected error for non-200 status")
}
}
func TestAddIDPCertificate_InvalidPEM(t *testing.T) {
service := &SAMLService{config: &config.SAMLProviderConfig{IDPCertificate: "not-pem"}}
metadata := &saml.EntityDescriptor{
IDPSSODescriptors: []saml.IDPSSODescriptor{{}},
}
if err := service.addIDPCertificate(metadata); err == nil {
t.Fatal("expected error for invalid certificate pem")
}
}

View File

@@ -0,0 +1,95 @@
package api
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
)
func TestSecurityOIDCHandlers_GetConfig(t *testing.T) {
cfg := &config.Config{
PublicURL: "https://pulse.example.com",
OIDC: &config.OIDCConfig{
Enabled: true,
IssuerURL: "https://issuer.example.com",
ClientID: "client-id",
ClientSecret: "super-secret",
RedirectURL: "https://pulse.example.com/oidc/callback",
Scopes: []string{"openid"},
UsernameClaim: "sub",
EmailClaim: "email",
GroupsClaim: "groups",
},
}
router := &Router{config: cfg}
req := httptest.NewRequest(http.MethodGet, "/api/security/oidc", nil)
rr := httptest.NewRecorder()
router.handleOIDCConfig(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
}
var resp oidcResponse
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
t.Fatalf("decode response: %v", err)
}
if !resp.Enabled {
t.Fatalf("expected OIDC to be enabled")
}
if resp.IssuerURL != cfg.OIDC.IssuerURL {
t.Fatalf("unexpected issuer url")
}
if !resp.ClientSecretSet {
t.Fatalf("expected client secret to be marked as set")
}
}
func TestSecurityOIDCHandlers_UpdateSaveFailure(t *testing.T) {
cfg := &config.Config{
PublicURL: "https://pulse.example.com",
OIDC: &config.OIDCConfig{
ClientSecret: "old-secret",
CABundle: "old-bundle",
},
}
router := &Router{config: cfg}
payload := map[string]any{
"enabled": true,
"issuerUrl": "https://issuer.example.com",
"clientId": "client-id",
"clientSecret": "new-secret",
"redirectUrl": "https://pulse.example.com/oidc/callback",
"scopes": []string{"openid"},
"usernameClaim": "sub",
"emailClaim": "email",
"groupsClaim": "groups",
"clearClientSecret": true,
"caBundle": "new-bundle",
}
body, _ := json.Marshal(payload)
req := httptest.NewRequest(http.MethodPut, "/api/security/oidc", bytes.NewReader(body))
rr := httptest.NewRecorder()
router.handleOIDCConfig(rr, req)
if rr.Code != http.StatusInternalServerError {
t.Fatalf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
}
var apiErr APIError
if err := json.NewDecoder(rr.Body).Decode(&apiErr); err != nil {
t.Fatalf("decode error response: %v", err)
}
if apiErr.Code != "save_failed" {
t.Fatalf("unexpected error code: %s", apiErr.Code)
}
}