Files
Pulse/internal/api/saml_service.go
rcourtman 97701297c4 feat(sso): add SAML 2.0 and multi-provider SSO support
- Add SAML 2.0 Service Provider implementation using crewjam/saml
- Support IdP metadata from URL or raw XML
- Add multi-provider SSO configuration model
- Implement provider management API (CRUD operations)
- Add provider connection testing endpoint
- Add IdP metadata preview endpoint
- Add SSOProvidersPanel component for settings UI
- Support attribute-based role mapping (groups → Pulse roles)

API endpoints:
- GET/POST /api/security/sso/providers - List/create providers
- GET/PUT/DELETE /api/security/sso/providers/{id} - Provider CRUD
- POST /api/security/sso/providers/test - Test connection
- POST /api/security/sso/providers/metadata/preview - Preview metadata
- /api/saml/{id}/login, /acs, /metadata, /logout, /slo - SAML endpoints
2026-01-12 15:19:59 +00:00

610 lines
16 KiB
Go

package api
import (
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/crewjam/saml"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"github.com/rs/zerolog/log"
)
// SAMLService manages SAML Service Provider functionality for a single provider.
// It bundles the provider's configuration, the IdP metadata loaded for it and
// the crewjam/saml ServiceProvider built from both.
type SAMLService struct {
// mu is write-locked by loadIDPMetadata and read-locked by the methods that
// use sp or idpMetadata.
mu sync.RWMutex
// providerID identifies the provider; it appears in default URLs and log fields.
providerID string
// config is the provider configuration passed to NewSAMLService.
config *config.SAMLProviderConfig
// sp is the service provider built by initServiceProvider; nil until then.
sp *saml.ServiceProvider
// idpMetadata is the IdP metadata most recently stored by loadIDPMetadata.
idpMetadata *saml.EntityDescriptor
// httpClient is used to fetch IdP metadata from a URL.
httpClient *http.Client
// baseURL is the external base URL with trailing slashes removed.
baseURL string
// lastRefresh is the time at which idpMetadata was last stored.
lastRefresh time.Time
}
// SAMLAuthResult contains the result of a successful SAML authentication,
// as populated by ProcessResponse from a validated assertion.
type SAMLAuthResult struct {
// Username is the value of the configured username attribute, or NameID when
// that attribute is not configured or not present.
Username string
// Email is the value of the configured email attribute, or "" if absent.
Email string
// Groups holds all values of the configured groups attribute, if present.
Groups []string
// FirstName is the value of the configured first-name attribute, or "".
FirstName string
// LastName is the value of the configured last-name attribute, or "".
LastName string
// NameID is the subject NameID value from the assertion, if any.
NameID string
// SessionIdx is the first non-empty SessionIndex among the AuthnStatements.
SessionIdx string
// Attributes maps every assertion attribute to its values, keyed by Name and
// additionally by FriendlyName when one is set.
Attributes map[string][]string
}
// NewSAMLService builds a ready-to-use SAML service for one provider.
//
// It rejects a nil cfg, strips trailing slashes from baseURL, loads the IdP
// metadata and then initializes the service provider. Any failure in those
// two steps is returned wrapped and no service is returned.
func NewSAMLService(ctx context.Context, providerID string, cfg *config.SAMLProviderConfig, baseURL string) (*SAMLService, error) {
	if cfg == nil {
		return nil, errors.New("saml configuration is nil")
	}

	svc := new(SAMLService)
	svc.providerID = providerID
	svc.config = cfg
	svc.baseURL = strings.TrimRight(baseURL, "/")
	svc.httpClient = newSAMLHTTPClient()

	// Metadata comes first: the service provider setup reads it.
	err := svc.loadIDPMetadata(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load idp metadata: %w", err)
	}

	err = svc.initServiceProvider()
	if err != nil {
		return nil, fmt.Errorf("failed to initialize service provider: %w", err)
	}

	return svc, nil
}
// newSAMLHTTPClient returns an HTTP client with a 30 second overall timeout
// whose transport requires TLS 1.2 or newer.
//
// The transport is a clone of http.DefaultTransport when that is a
// *http.Transport. http.DefaultTransport is a replaceable interface value
// (instrumentation libraries commonly wrap it), so the type is checked
// instead of asserted; otherwise a fresh transport honoring the proxy
// environment variables is used rather than panicking.
func newSAMLHTTPClient() *http.Client {
	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	} else {
		transport = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	transport.TLSClientConfig = &tls.Config{
		MinVersion: tls.VersionTLS12,
	}
	return &http.Client{
		Transport: transport,
		Timeout:   30 * time.Second,
	}
}
// loadIDPMetadata loads Identity Provider metadata and stores it on the service.
//
// The source is chosen from the configuration in this order: the metadata URL
// (fetched over HTTP), inline metadata XML, and finally metadata synthesized
// from the manually configured IdP fields. On success idpMetadata and
// lastRefresh are updated; on failure the previously stored metadata is left
// untouched and a wrapped error is returned.
//
// The metadata is obtained before the write lock is taken so that a slow or
// hanging metadata endpoint does not block readers of the service for the
// duration of the HTTP request. Only config and httpClient are read while
// unlocked; within this file both are assigned only in NewSAMLService.
func (s *SAMLService) loadIDPMetadata(ctx context.Context) error {
	var (
		metadata *saml.EntityDescriptor
		err      error
	)

	switch {
	case s.config.IDPMetadataURL != "":
		metadata, err = s.fetchIDPMetadataFromURL(ctx, s.config.IDPMetadataURL)
		if err != nil {
			return fmt.Errorf("failed to fetch idp metadata from url: %w", err)
		}
	case s.config.IDPMetadataXML != "":
		metadata, err = parseIDPMetadataXML([]byte(s.config.IDPMetadataXML))
		if err != nil {
			return fmt.Errorf("failed to parse idp metadata xml: %w", err)
		}
	default:
		// Build metadata from manual configuration.
		metadata, err = s.buildManualMetadata()
		if err != nil {
			return fmt.Errorf("failed to build manual metadata: %w", err)
		}
	}

	// Publish the result; this is the only part that touches shared state.
	s.mu.Lock()
	s.idpMetadata = metadata
	s.lastRefresh = time.Now()
	s.mu.Unlock()

	log.Info().
		Str("provider_id", s.providerID).
		Str("entity_id", metadata.EntityID).
		Msg("Loaded SAML IdP metadata")
	return nil
}
// fetchIDPMetadataFromURL downloads IdP metadata from metadataURL using the
// service's HTTP client and parses it.
//
// It returns an error if the request cannot be built or sent, if the response
// status is not 200 OK, if the body cannot be read, if the body is larger than
// 1 MiB, or if the XML cannot be parsed.
func (s *SAMLService) fetchIDPMetadataFromURL(ctx context.Context, metadataURL string) (*saml.EntityDescriptor, error) {
	const maxMetadataBytes = 1 << 20 // 1MB limit

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
	if err != nil {
		return nil, err
	}
	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("metadata request returned status %d", resp.StatusCode)
	}

	// Read one byte past the limit so an oversized document is reported as
	// such instead of being truncated and failing later as malformed XML.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxMetadataBytes+1))
	if err != nil {
		return nil, err
	}
	if len(body) > maxMetadataBytes {
		return nil, fmt.Errorf("metadata response exceeds %d byte limit", maxMetadataBytes)
	}
	return parseIDPMetadataXML(body)
}
// parseIDPMetadataXML parses SAML metadata given either as a single
// EntityDescriptor or as an EntitiesDescriptor wrapping several entities.
//
// For an EntitiesDescriptor the first entity that carries an IDPSSODescriptor
// is returned, since aggregates may list service-provider entities ahead of
// the identity provider. If no entity has one, the first entity is returned.
// When neither form parses, the error from the EntityDescriptor attempt is
// wrapped and returned.
func parseIDPMetadataXML(data []byte) (*saml.EntityDescriptor, error) {
	var metadata saml.EntityDescriptor
	err := xml.Unmarshal(data, &metadata)
	if err == nil {
		return &metadata, nil
	}

	// Try parsing as EntityDescriptor wrapped in EntitiesDescriptor.
	var entities saml.EntitiesDescriptor
	if err2 := xml.Unmarshal(data, &entities); err2 != nil {
		return nil, fmt.Errorf("failed to parse metadata: %w", err)
	}
	if len(entities.EntityDescriptors) == 0 {
		return nil, errors.New("no entity descriptors found in metadata")
	}

	for i := range entities.EntityDescriptors {
		if len(entities.EntityDescriptors[i].IDPSSODescriptors) > 0 {
			return &entities.EntityDescriptors[i], nil
		}
	}
	return &entities.EntityDescriptors[0], nil
}
// buildManualMetadata synthesizes IdP metadata from the individually
// configured IdP fields, for providers that supply neither a metadata URL nor
// metadata XML.
//
// The SSO URL is required and is advertised for both the HTTP-Redirect and
// HTTP-POST bindings. The entity ID falls back from IDPEntityID to IDPIssuer
// to the SSO URL. A configured SLO URL is added as an HTTP-Redirect logout
// endpoint; if it does not parse, a warning is logged and single logout is
// left unconfigured rather than failing the whole provider. The IdP signing
// certificate, if configured, is attached via addIDPCertificate.
func (s *SAMLService) buildManualMetadata() (*saml.EntityDescriptor, error) {
	if s.config.IDPSSOURL == "" {
		return nil, errors.New("idp sso url is required for manual configuration")
	}
	ssoURL, err := url.Parse(s.config.IDPSSOURL)
	if err != nil {
		return nil, fmt.Errorf("invalid idp sso url: %w", err)
	}

	entityID := s.config.IDPEntityID
	if entityID == "" {
		entityID = s.config.IDPIssuer
	}
	if entityID == "" {
		entityID = s.config.IDPSSOURL
	}

	metadata := &saml.EntityDescriptor{
		EntityID: entityID,
		IDPSSODescriptors: []saml.IDPSSODescriptor{
			{
				SSODescriptor: saml.SSODescriptor{
					RoleDescriptor: saml.RoleDescriptor{
						ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
					},
				},
				SingleSignOnServices: []saml.Endpoint{
					{
						Binding:  saml.HTTPRedirectBinding,
						Location: ssoURL.String(),
					},
					{
						Binding:  saml.HTTPPostBinding,
						Location: ssoURL.String(),
					},
				},
			},
		},
	}

	// Add SLO endpoint if configured.
	if s.config.IDPSLOUrl != "" {
		sloURL, sloErr := url.Parse(s.config.IDPSLOUrl)
		if sloErr != nil {
			// SLO is optional, so sign-on keeps working; surface the
			// misconfiguration instead of dropping it silently.
			log.Warn().
				Err(sloErr).
				Str("provider_id", s.providerID).
				Msg("Ignoring invalid SAML IdP SLO URL")
		} else {
			metadata.IDPSSODescriptors[0].SingleLogoutServices = []saml.Endpoint{
				{
					Binding:  saml.HTTPRedirectBinding,
					Location: sloURL.String(),
				},
			}
		}
	}

	// Add IdP certificate if provided.
	if err := s.addIDPCertificate(metadata); err != nil {
		return nil, err
	}
	return metadata, nil
}
// addIDPCertificate attaches the configured IdP signing certificate to the
// first IDPSSODescriptor of metadata.
//
// The certificate is read from IDPCertFile if set, otherwise taken from the
// inline IDPCertificate value; when neither is configured the metadata is left
// unchanged and nil is returned. The input must be a PEM block containing a
// parseable X.509 certificate, otherwise an error is returned.
func (s *SAMLService) addIDPCertificate(metadata *saml.EntityDescriptor) error {
	var certData []byte
	switch {
	case s.config.IDPCertFile != "":
		data, err := os.ReadFile(s.config.IDPCertFile)
		if err != nil {
			return fmt.Errorf("failed to read idp certificate file: %w", err)
		}
		certData = data
	case s.config.IDPCertificate != "":
		certData = []byte(s.config.IDPCertificate)
	default:
		return nil // No certificate provided
	}

	// Parse PEM certificate.
	block, _ := pem.Decode(certData)
	if block == nil {
		return errors.New("failed to decode idp certificate pem")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return fmt.Errorf("failed to parse idp certificate: %w", err)
	}

	if len(metadata.IDPSSODescriptors) == 0 {
		return nil
	}

	// The ds:X509Certificate element carries the bare base64 of the DER
	// certificate, and crewjam/saml base64-decodes this field as-is when it
	// collects the IdP signing certificates. PEM armor ("-----BEGIN ...")
	// is not valid base64, so strip the armor and line breaks; the result is
	// the standard base64 encoding of cert.Raw.
	armored := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})))
	armored = strings.TrimPrefix(armored, "-----BEGIN CERTIFICATE-----")
	armored = strings.TrimSuffix(armored, "-----END CERTIFICATE-----")
	certB64 := strings.Join(strings.Fields(armored), "")

	metadata.IDPSSODescriptors[0].KeyDescriptors = []saml.KeyDescriptor{
		{
			Use: "signing",
			KeyInfo: saml.KeyInfo{
				X509Data: saml.X509Data{
					X509Certificates: []saml.X509Certificate{
						{Data: certB64},
					},
				},
			},
		},
	}
	return nil
}
// initServiceProvider builds the crewjam/saml ServiceProvider from the
// configuration and the currently stored IdP metadata and assigns it to s.sp.
//
// Entity ID, ACS path and metadata path come from the configuration when set
// and otherwise default to "<baseURL>/saml/<id>", "/api/saml/<id>/acs" and
// "/api/saml/<id>/metadata". An SLO URL is set only when the first IdP
// descriptor advertises a logout service. When SignRequests is enabled the SP
// certificate and key are loaded and a failure to do so is returned.
//
// The method does not take s.mu itself.
func (s *SAMLService) initServiceProvider() error {
	pick := func(configured, fallback string) string {
		if configured != "" {
			return configured
		}
		return fallback
	}

	entityID := pick(s.config.SPEntityID, fmt.Sprintf("%s/saml/%s", s.baseURL, s.providerID))

	acsURL, err := url.Parse(s.baseURL + pick(s.config.SPACSPath, fmt.Sprintf("/api/saml/%s/acs", s.providerID)))
	if err != nil {
		return fmt.Errorf("failed to parse acs url: %w", err)
	}

	metadataURL, err := url.Parse(s.baseURL + pick(s.config.SPMetadataPath, fmt.Sprintf("/api/saml/%s/metadata", s.providerID)))
	if err != nil {
		return fmt.Errorf("failed to parse metadata url: %w", err)
	}

	// ForceAuthn is a pointer field; point it at a local copy of the setting.
	forceAuthn := s.config.ForceAuthn
	provider := &saml.ServiceProvider{
		EntityID:          entityID,
		AcsURL:            *acsURL,
		MetadataURL:       *metadataURL,
		IDPMetadata:       s.idpMetadata,
		AllowIDPInitiated: s.config.AllowIDPInitiated,
		ForceAuthn:        &forceAuthn,
	}

	// Only advertise an SLO endpoint when the IdP offers one.
	idpDescriptors := s.idpMetadata.IDPSSODescriptors
	if len(idpDescriptors) > 0 && len(idpDescriptors[0].SingleLogoutServices) > 0 {
		if sloURL, sloErr := url.Parse(s.baseURL + fmt.Sprintf("/api/saml/%s/slo", s.providerID)); sloErr == nil {
			provider.SloURL = *sloURL
		}
	}

	// NOTE(review): only Key and Certificate are set for signing; the
	// provider's SignatureMethod is left at its zero value. Confirm the
	// library actually signs requests in that configuration.
	if s.config.SignRequests {
		cert, key, credErr := s.loadSPCredentials()
		if credErr != nil {
			return fmt.Errorf("failed to load sp credentials: %w", credErr)
		}
		provider.Key = key
		provider.Certificate = cert
	}

	s.sp = provider

	log.Info().
		Str("provider_id", s.providerID).
		Str("entity_id", entityID).
		Str("acs_url", acsURL.String()).
		Bool("sign_requests", s.config.SignRequests).
		Msg("Initialized SAML Service Provider")
	return nil
}
// loadSPCredentials returns the service provider's signing certificate and
// RSA private key.
//
// Each is read from its configured file if one is set, otherwise from the
// inline PEM value; a missing certificate or key is an error. The key may be
// PEM type "RSA PRIVATE KEY" (PKCS#1) or "PRIVATE KEY" (PKCS#8 holding an RSA
// key); any other PEM type is rejected.
func (s *SAMLService) loadSPCredentials() (*x509.Certificate, *rsa.PrivateKey, error) {
	var (
		certPEM []byte
		keyPEM  []byte
		err     error
	)

	// Locate the certificate material.
	switch {
	case s.config.SPCertFile != "":
		if certPEM, err = os.ReadFile(s.config.SPCertFile); err != nil {
			return nil, nil, fmt.Errorf("failed to read sp certificate file: %w", err)
		}
	case s.config.SPCertificate != "":
		certPEM = []byte(s.config.SPCertificate)
	default:
		return nil, nil, errors.New("sp certificate is required for signing")
	}

	// Locate the private key material.
	switch {
	case s.config.SPKeyFile != "":
		if keyPEM, err = os.ReadFile(s.config.SPKeyFile); err != nil {
			return nil, nil, fmt.Errorf("failed to read sp private key file: %w", err)
		}
	case s.config.SPPrivateKey != "":
		keyPEM = []byte(s.config.SPPrivateKey)
	default:
		return nil, nil, errors.New("sp private key is required for signing")
	}

	// Decode and parse the certificate.
	certBlock, _ := pem.Decode(certPEM)
	if certBlock == nil {
		return nil, nil, errors.New("failed to decode sp certificate pem")
	}
	cert, err := x509.ParseCertificate(certBlock.Bytes)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to parse sp certificate: %w", err)
	}

	// Decode the key and parse it according to its PEM type.
	keyBlock, _ := pem.Decode(keyPEM)
	if keyBlock == nil {
		return nil, nil, errors.New("failed to decode sp private key pem")
	}

	if keyBlock.Type == "RSA PRIVATE KEY" {
		rsaKey, parseErr := x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
		if parseErr != nil {
			return nil, nil, fmt.Errorf("failed to parse sp private key: %w", parseErr)
		}
		return cert, rsaKey, nil
	}

	if keyBlock.Type == "PRIVATE KEY" {
		anyKey, parseErr := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
		if parseErr != nil {
			return nil, nil, fmt.Errorf("failed to parse pkcs8 private key: %w", parseErr)
		}
		rsaKey, isRSA := anyKey.(*rsa.PrivateKey)
		if !isRSA {
			return nil, nil, errors.New("sp private key is not rsa")
		}
		return cert, rsaKey, nil
	}

	return nil, nil, fmt.Errorf("unsupported private key type: %s", keyBlock.Type)
}
// MakeAuthRequest builds a SAML AuthnRequest for the HTTP-Redirect binding and
// returns the URL the browser should be sent to.
//
// An empty relayState is replaced by "/". An error is returned when the
// service provider has not been initialized or the request cannot be built.
func (s *SAMLService) MakeAuthRequest(relayState string) (string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	provider := s.sp
	if provider == nil {
		return "", errors.New("service provider not initialized")
	}

	state := relayState
	if state == "" {
		state = "/"
	}

	target, err := provider.MakeRedirectAuthenticationRequest(state)
	if err != nil {
		return "", fmt.Errorf("failed to create auth request: %w", err)
	}
	location := target.String()

	log.Debug().
		Str("provider_id", s.providerID).
		Str("redirect_url", location).
		Msg("Created SAML AuthnRequest")
	return location, nil
}
// ProcessResponse validates the SAML response carried by r and turns the
// resulting assertion into a SAMLAuthResult.
//
// It returns the result, the RelayState form value, and an error. The relay
// state is also returned when validation of the response fails; it is empty
// when the service provider is uninitialized or the form cannot be parsed.
func (s *SAMLService) ProcessResponse(r *http.Request) (*SAMLAuthResult, string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if s.sp == nil {
		return nil, "", errors.New("service provider not initialized")
	}

	// SAMLResponse and RelayState arrive as form fields.
	if err := r.ParseForm(); err != nil {
		return nil, "", fmt.Errorf("failed to parse form: %w", err)
	}
	relayState := r.FormValue("RelayState")

	// NOTE(review): no request IDs issued by MakeAuthRequest are recorded in
	// this file, so this list is empty unless IdP-initiated login is allowed.
	// Confirm that SP-initiated responses still validate in that case.
	requestIDs := make([]string, 0, 1)
	if s.sp.AllowIDPInitiated {
		requestIDs = append(requestIDs, "")
	}

	assertion, err := s.sp.ParseResponse(r, requestIDs)
	if err != nil {
		return nil, relayState, fmt.Errorf("failed to validate saml response: %w", err)
	}

	// Subject identifier, when the assertion carries one.
	var nameID string
	if subject := assertion.Subject; subject != nil && subject.NameID != nil {
		nameID = subject.NameID.Value
	}

	// First non-empty session index wins.
	var sessionIdx string
	for i := range assertion.AuthnStatements {
		if idx := assertion.AuthnStatements[i].SessionIndex; idx != "" {
			sessionIdx = idx
			break
		}
	}

	// Flatten every attribute, indexed by Name and by FriendlyName if set.
	attrs := make(map[string][]string)
	for _, stmt := range assertion.AttributeStatements {
		for _, attr := range stmt.Attributes {
			vals := make([]string, len(attr.Values))
			for i := range attr.Values {
				vals[i] = attr.Values[i].Value
			}
			attrs[attr.Name] = vals
			if attr.FriendlyName != "" {
				attrs[attr.FriendlyName] = vals
			}
		}
	}

	result := &SAMLAuthResult{
		NameID:     nameID,
		SessionIdx: sessionIdx,
		Attributes: attrs,
		// The username falls back to the NameID; the rest default to "".
		Username:  s.extractAttribute(attrs, s.config.UsernameAttr, nameID),
		Email:     s.extractAttribute(attrs, s.config.EmailAttr, ""),
		FirstName: s.extractAttribute(attrs, s.config.FirstNameAttr, ""),
		LastName:  s.extractAttribute(attrs, s.config.LastNameAttr, ""),
	}

	if groupsKey := s.config.GroupsAttr; groupsKey != "" {
		if groups, found := attrs[groupsKey]; found {
			result.Groups = groups
		}
	}

	log.Info().
		Str("provider_id", s.providerID).
		Str("username", result.Username).
		Str("email", result.Email).
		Int("groups", len(result.Groups)).
		Msg("Processed SAML assertion")
	return result, relayState, nil
}
// extractAttribute returns the first value stored under attrName in attrs.
// It returns defaultValue when attrName is empty, when the attribute is not
// present, or when it has no values.
func (s *SAMLService) extractAttribute(attrs map[string][]string, attrName, defaultValue string) string {
	if attrName != "" {
		if values := attrs[attrName]; len(values) > 0 {
			return values[0]
		}
	}
	return defaultValue
}
// GetMetadata renders the service provider's own metadata as indented XML.
// It returns an error if the service provider has not been initialized or if
// marshaling fails.
func (s *SAMLService) GetMetadata() ([]byte, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if provider := s.sp; provider != nil {
		return xml.MarshalIndent(provider.Metadata(), "", " ")
	}
	return nil, errors.New("service provider not initialized")
}
// MakeLogoutRequest builds a SAML LogoutRequest for nameID addressed to the
// IdP's first advertised single-logout endpoint and returns the redirect URL
// (with an empty relay state).
//
// It fails when the service provider is not initialized, when the IdP
// metadata lists no logout service, or when the request cannot be created.
// The sessionIdx argument is accepted but not used by this implementation.
func (s *SAMLService) MakeLogoutRequest(nameID, sessionIdx string) (string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if s.sp == nil {
		return "", errors.New("service provider not initialized")
	}

	// Single logout needs at least one logout service on the first descriptor.
	descriptors := s.idpMetadata.IDPSSODescriptors
	if len(descriptors) == 0 || len(descriptors[0].SingleLogoutServices) == 0 {
		return "", errors.New("idp does not support single logout")
	}
	destination := descriptors[0].SingleLogoutServices[0].Location

	logoutReq, err := s.sp.MakeLogoutRequest(destination, nameID)
	if err != nil {
		return "", fmt.Errorf("failed to create logout request: %w", err)
	}

	return logoutReq.Redirect("").String(), nil
}
// RefreshMetadata reloads IdP metadata from the configured metadata URL and
// rebuilds the service provider from it (useful for IdP key rotation).
//
// It returns an error when no metadata URL is configured, when the metadata
// cannot be loaded, or when the service provider cannot be re-initialized.
func (s *SAMLService) RefreshMetadata(ctx context.Context) error {
	if s.config.IDPMetadataURL == "" {
		return errors.New("cannot refresh metadata without url")
	}

	// loadIDPMetadata takes the lock itself, so it must run before we do.
	if err := s.loadIDPMetadata(ctx); err != nil {
		return err
	}

	// initServiceProvider replaces s.sp and does no locking of its own, while
	// the request-handling methods read s.sp under mu.RLock. Hold the write
	// lock so the swap cannot race with an in-flight login or logout.
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.initServiceProvider()
}
// ProviderID reports the identifier of the provider this service was created for.
func (s *SAMLService) ProviderID() string {
	id := s.providerID
	return id
}
// GetSPEntityID reports the entity ID of the service provider, or "" when the
// service provider has not been initialized.
func (s *SAMLService) GetSPEntityID() string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if provider := s.sp; provider != nil {
		return provider.EntityID
	}
	return ""
}
// GetIDPEntityID reports the entity ID from the stored IdP metadata, or ""
// when no metadata has been loaded.
func (s *SAMLService) GetIDPEntityID() string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if md := s.idpMetadata; md != nil {
		return md.EntityID
	}
	return ""
}