// Source file: Pulse/internal/api/saml_handlers_additional_test.go

package api
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
)
// TestSAMLServiceManager_RemoveProvider checks that a service stored under
// "okta" is no longer returned by GetService once RemoveProvider has been
// called for that ID.
func TestSAMLServiceManager_RemoveProvider(t *testing.T) {
	const providerID = "okta"

	mgr := NewSAMLServiceManager("https://pulse.example.com")
	// Seed the manager's service map directly with an empty service value.
	mgr.services[providerID] = &SAMLService{}

	mgr.RemoveProvider(providerID)

	if mgr.GetService(providerID) != nil {
		t.Fatalf("expected provider to be removed")
	}
}
// TestHandleSAMLACS_MethodNotAllowed verifies that handleSAMLACS responds to a
// GET request on the ACS path with 405 Method Not Allowed.
func TestHandleSAMLACS_MethodNotAllowed(t *testing.T) {
	r := newSAMLRouter(t, testSAMLProvider("okta", true))
	// Guarantee a SAML manager is present in case the helper left it nil.
	if r.samlManager == nil {
		r.samlManager = NewSAMLServiceManager("https://pulse.example.com")
	}

	w := httptest.NewRecorder()
	r.handleSAMLACS(w, httptest.NewRequest(http.MethodGet, "/api/saml/okta/acs", nil))

	if got, want := w.Code, http.StatusMethodNotAllowed; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}
}
// TestHandleSAMLACS_InvalidProviderID posts to an ACS path whose provider
// segment is "invalid$id" and expects a 302 redirect whose Location carries
// saml_error=invalid_provider.
func TestHandleSAMLACS_InvalidProviderID(t *testing.T) {
	r := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
	w := httptest.NewRecorder()
	post := httptest.NewRequest(http.MethodPost, "/api/saml/invalid$id/acs", nil)

	r.handleSAMLACS(w, post)

	if got, want := w.Code, http.StatusFound; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}

	location := w.Header().Get("Location")
	if !strings.Contains(location, "saml_error=invalid_provider") {
		t.Fatalf("expected invalid_provider redirect, got %q", location)
	}
}
// TestHandleSAMLACS_ProviderNotFound posts to the "okta" ACS path on a Router
// that only has a freshly created SAML manager, and expects a 302 redirect
// whose Location carries saml_error=provider_not_found.
func TestHandleSAMLACS_ProviderNotFound(t *testing.T) {
	r := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
	w := httptest.NewRecorder()
	post := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil)

	r.handleSAMLACS(w, post)

	if got, want := w.Code, http.StatusFound; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}

	location := w.Header().Get("Location")
	if !strings.Contains(location, "saml_error=provider_not_found") {
		t.Fatalf("expected provider_not_found redirect, got %q", location)
	}
}
// TestHandleSAMLACS_ServiceNotInitialized posts to the "okta" ACS path on a
// router built by newSAMLRouter for an "okta" provider, without registering a
// SAML service in this test, and expects a 302 redirect whose Location carries
// saml_error=provider_not_initialized.
func TestHandleSAMLACS_ServiceNotInitialized(t *testing.T) {
	r := newSAMLRouter(t, testSAMLProvider("okta", true))
	w := httptest.NewRecorder()
	post := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil)

	r.handleSAMLACS(w, post)

	if got, want := w.Code, http.StatusFound; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}

	location := w.Header().Get("Location")
	if !strings.Contains(location, "saml_error=provider_not_initialized") {
		t.Fatalf("expected provider_not_initialized redirect, got %q", location)
	}
}
// TestHandleSAMLLogout_Fallback sends a GET request for the "okta" logout path
// to handleSAMLLogout on a Router holding only a freshly created SAML manager,
// and expects 405 Method Not Allowed.
//
// NOTE(review): the test name mentions a fallback, but only the status code is
// asserted here; confirm this is the intended expectation for the handler.
func TestHandleSAMLLogout_Fallback(t *testing.T) {
	r := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
	w := httptest.NewRecorder()

	r.handleSAMLLogout(w, httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil))

	if got, want := w.Code, http.StatusMethodNotAllowed; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}
}
func TestEstablishSAMLSession(t *testing.T) {
router := &Router{}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
samlInfo := &SAMLSessionInfo{ProviderID: "okta", NameID: "user", SessionIndex: "sess-1"}
if err := router.establishSAMLSession(rec, req, "admin", samlInfo); err != nil {
t.Fatalf("establishSAMLSession error: %v", err)
}
cookies := rec.Result().Cookies()
if len(cookies) < 2 {
t.Fatalf("expected session and csrf cookies, got %d", len(cookies))
}
}
// TestRedirectSAMLError calls redirectSAMLError with return path "/dashboard"
// and error code "session_failed", then expects a 302 response whose Location
// contains both "saml=error" and "saml_error=session_failed".
func TestRedirectSAMLError(t *testing.T) {
	r := &Router{}
	w := httptest.NewRecorder()

	r.redirectSAMLError(w, httptest.NewRequest(http.MethodGet, "/", nil), "/dashboard", "session_failed")

	if got, want := w.Code, http.StatusFound; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}

	location := w.Header().Get("Location")
	hasFlag := strings.Contains(location, "saml=error")
	hasCode := strings.Contains(location, "saml_error=session_failed")
	if !(hasFlag && hasCode) {
		t.Fatalf("unexpected redirect location %q", location)
	}
}
// TestInitializeSAMLProviders gives the "okta" test provider a SAML config
// with an IdP SSO URL and entity ID, runs InitializeSAMLProviders, and expects
// it to succeed and leave a service retrievable under "okta".
func TestInitializeSAMLProviders(t *testing.T) {
	idp := testSAMLProvider("okta", true)
	idp.SAML = &config.SAMLProviderConfig{
		IDPSSOURL:   "https://idp.example.com/sso",
		IDPEntityID: "https://idp.example.com/metadata",
	}

	r := newSAMLRouter(t, idp)
	// Guarantee a SAML manager is present in case the helper left it nil.
	if r.samlManager == nil {
		r.samlManager = NewSAMLServiceManager("https://pulse.example.com")
	}

	err := r.InitializeSAMLProviders(reqContext(t))
	if err != nil {
		t.Fatalf("InitializeSAMLProviders error: %v", err)
	}

	if r.samlManager.GetService("okta") == nil {
		t.Fatalf("expected SAML service to be initialized")
	}
}