mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-18 00:17:39 +01:00
147 lines
4.6 KiB
Go
147 lines
4.6 KiB
Go
package api
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
)
|
|
|
|
func TestSAMLServiceManager_RemoveProvider(t *testing.T) {
|
|
manager := NewSAMLServiceManager("https://pulse.example.com")
|
|
manager.services["okta"] = &SAMLService{}
|
|
|
|
manager.RemoveProvider("okta")
|
|
|
|
if svc := manager.GetService("okta"); svc != nil {
|
|
t.Fatalf("expected provider to be removed")
|
|
}
|
|
}
|
|
|
|
func TestHandleSAMLACS_MethodNotAllowed(t *testing.T) {
|
|
router := newSAMLRouter(t, testSAMLProvider("okta", true))
|
|
if router.samlManager == nil {
|
|
router.samlManager = NewSAMLServiceManager("https://pulse.example.com")
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/acs", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleSAMLACS(rec, req)
|
|
|
|
if rec.Code != http.StatusMethodNotAllowed {
|
|
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHandleSAMLACS_InvalidProviderID(t *testing.T) {
|
|
router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
|
|
req := httptest.NewRequest(http.MethodPost, "/api/saml/invalid$id/acs", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleSAMLACS(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
|
|
}
|
|
if loc := rec.Header().Get("Location"); !strings.Contains(loc, "saml_error=invalid_provider") {
|
|
t.Fatalf("expected invalid_provider redirect, got %q", loc)
|
|
}
|
|
}
|
|
|
|
func TestHandleSAMLACS_ProviderNotFound(t *testing.T) {
|
|
router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
|
|
req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleSAMLACS(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
|
|
}
|
|
if loc := rec.Header().Get("Location"); !strings.Contains(loc, "saml_error=provider_not_found") {
|
|
t.Fatalf("expected provider_not_found redirect, got %q", loc)
|
|
}
|
|
}
|
|
|
|
func TestHandleSAMLACS_ServiceNotInitialized(t *testing.T) {
|
|
router := newSAMLRouter(t, testSAMLProvider("okta", true))
|
|
req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleSAMLACS(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
|
|
}
|
|
if loc := rec.Header().Get("Location"); !strings.Contains(loc, "saml_error=provider_not_initialized") {
|
|
t.Fatalf("expected provider_not_initialized redirect, got %q", loc)
|
|
}
|
|
}
|
|
|
|
func TestHandleSAMLLogout_Fallback(t *testing.T) {
|
|
router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
|
|
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleSAMLLogout(rec, req)
|
|
|
|
if rec.Code != http.StatusMethodNotAllowed {
|
|
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestEstablishSAMLSession(t *testing.T) {
|
|
router := &Router{}
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
samlInfo := &SAMLSessionInfo{ProviderID: "okta", NameID: "user", SessionIndex: "sess-1"}
|
|
if err := router.establishSAMLSession(rec, req, "admin", samlInfo); err != nil {
|
|
t.Fatalf("establishSAMLSession error: %v", err)
|
|
}
|
|
|
|
cookies := rec.Result().Cookies()
|
|
if len(cookies) < 2 {
|
|
t.Fatalf("expected session and csrf cookies, got %d", len(cookies))
|
|
}
|
|
}
|
|
|
|
func TestRedirectSAMLError(t *testing.T) {
|
|
router := &Router{}
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.redirectSAMLError(rec, req, "/dashboard", "session_failed")
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
|
|
}
|
|
loc := rec.Header().Get("Location")
|
|
if !strings.Contains(loc, "saml=error") || !strings.Contains(loc, "saml_error=session_failed") {
|
|
t.Fatalf("unexpected redirect location %q", loc)
|
|
}
|
|
}
|
|
|
|
func TestInitializeSAMLProviders(t *testing.T) {
|
|
provider := testSAMLProvider("okta", true)
|
|
provider.SAML = &config.SAMLProviderConfig{
|
|
IDPSSOURL: "https://idp.example.com/sso",
|
|
IDPEntityID: "https://idp.example.com/metadata",
|
|
}
|
|
|
|
router := newSAMLRouter(t, provider)
|
|
if router.samlManager == nil {
|
|
router.samlManager = NewSAMLServiceManager("https://pulse.example.com")
|
|
}
|
|
|
|
if err := router.InitializeSAMLProviders(reqContext(t)); err != nil {
|
|
t.Fatalf("InitializeSAMLProviders error: %v", err)
|
|
}
|
|
if svc := router.samlManager.GetService("okta"); svc == nil {
|
|
t.Fatalf("expected SAML service to be initialized")
|
|
}
|
|
}
|