Files
Pulse/internal/api/oidc_service.go
rcourtman b6140cd6e8 feat(oidc): Add refresh token support for long-lived sessions
When offline_access scope is configured, Pulse now stores and uses
OIDC refresh tokens to automatically extend sessions. Sessions remain
valid as long as the IdP allows token refresh (typically 30-90 days).

Changes:
- Store OIDC tokens (refresh token, expiry, issuer) alongside sessions
- Automatically refresh tokens when access token nears expiry
- Invalidate session if IdP revokes access (forces re-login)
- Add background token refresh with concurrency protection
- Persist OIDC tokens across restarts

Related to #854
2025-12-20 10:45:46 +00:00

417 lines
9.9 KiB
Go

package api
import (
"context"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
)
// oidcStateTTL bounds how long an OIDC login attempt may take. State entries are
// stamped with now+oidcStateTTL when created; after that instant they are rejected
// by oidcStateStore.Consume and removed by the store's periodic cleanup.
const oidcStateTTL time.Duration = 10 * time.Minute
// OIDCService caches provider metadata and manages transient state for authorization flows.
//
// Instances are built by NewOIDCService. Matches compares a service against a
// configuration so a caller can tell whether the cached instance is still usable.
type OIDCService struct {
// snapshot holds the configuration values the service was built from; Matches compares against it.
snapshot oidcSnapshot
// provider is the discovered OIDC provider (endpoints and key material source).
provider *oidc.Provider
// oauth2Cfg drives the authorization-code exchange and refresh-token grants.
oauth2Cfg *oauth2.Config
// verifier validates ID tokens; it is configured with the client ID as the expected audience.
verifier *oidc.IDTokenVerifier
// stateStore tracks in-flight login attempts (state -> nonce/PKCE data).
stateStore *oidcStateStore
// httpClient is used for IdP requests (it carries the optional custom CA bundle);
// contextWithHTTPClient injects it into request contexts.
httpClient *http.Client
}
// oidcSnapshot records the configuration an OIDCService was constructed from, so that
// Matches can detect when the live configuration has diverged and the service must be rebuilt.
type oidcSnapshot struct {
issuer string
clientID string
clientSecret string
redirectURL string
// scopes is a private copy of the configured scopes (order-sensitive in Matches).
scopes []string
// caBundle is the configured CA bundle path ("" when unset).
caBundle string
// caBundleHash is the hex SHA-256 of the bundle's contents at construction time,
// letting Matches notice when the file at the same path has changed.
caBundleHash string
}
// NewOIDCService discovers the OIDC provider described by cfg and assembles everything
// needed to run authorization-code logins against it.
//
// Discovery is bounded to 15 seconds on top of any deadline already carried by ctx.
// The HTTP client built from cfg.CABundle is used for discovery and is kept on the
// service for later IdP requests.
//
// It returns an error when cfg is nil or disabled, when the CA bundle cannot be loaded,
// or when provider discovery fails.
func NewOIDCService(ctx context.Context, cfg *config.OIDCConfig) (*OIDCService, error) {
	if cfg == nil || !cfg.Enabled {
		return nil, errors.New("oidc is not enabled")
	}

	discoveryCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
	defer cancel()

	log.Debug().
		Str("issuer", cfg.IssuerURL).
		Str("redirect_url", cfg.RedirectURL).
		Strs("scopes", cfg.Scopes).
		Str("ca_bundle", cfg.CABundle).
		Msg("Initializing OIDC provider")

	client, bundleHash, err := newOIDCHTTPClient(cfg.CABundle)
	if err != nil {
		return nil, fmt.Errorf("failed to build OIDC HTTP client: %w", err)
	}

	// Route discovery traffic through the custom client (and its CA pool).
	discoveryCtx = oidc.ClientContext(discoveryCtx, client)
	provider, err := oidc.NewProvider(discoveryCtx, cfg.IssuerURL)
	if err != nil {
		return nil, fmt.Errorf("failed to discover OIDC provider: %w", err)
	}

	endpoint := provider.Endpoint()
	log.Debug().
		Str("issuer", cfg.IssuerURL).
		Str("auth_endpoint", endpoint.AuthURL).
		Str("token_endpoint", endpoint.TokenURL).
		Msg("OIDC provider discovery successful")

	// Scopes are copied independently for the OAuth2 config and the snapshot so that
	// neither can be mutated through the caller's slice or through each other.
	return &OIDCService{
		snapshot: oidcSnapshot{
			issuer:       cfg.IssuerURL,
			clientID:     cfg.ClientID,
			clientSecret: cfg.ClientSecret,
			redirectURL:  cfg.RedirectURL,
			scopes:       append([]string{}, cfg.Scopes...),
			caBundle:     cfg.CABundle,
			caBundleHash: bundleHash,
		},
		provider: provider,
		oauth2Cfg: &oauth2.Config{
			ClientID:     cfg.ClientID,
			ClientSecret: cfg.ClientSecret,
			RedirectURL:  cfg.RedirectURL,
			Endpoint:     endpoint,
			Scopes:       append([]string{}, cfg.Scopes...),
		},
		verifier:   provider.Verifier(&oidc.Config{ClientID: cfg.ClientID}),
		stateStore: newOIDCStateStore(),
		httpClient: client,
	}, nil
}
// Matches reports whether this service was built from settings equivalent to cfg.
//
// Issuer, client ID/secret, redirect URL, CA bundle path and the ordered scope list must
// all be equal. When a CA bundle is configured, its current file contents are re-hashed
// and compared with the hash taken at construction. A read failure counts as a mismatch.
// A nil service or nil cfg never matches.
func (s *OIDCService) Matches(cfg *config.OIDCConfig) bool {
	if s == nil || cfg == nil {
		return false
	}
	snap := s.snapshot

	switch {
	case snap.issuer != cfg.IssuerURL,
		snap.clientID != cfg.ClientID,
		snap.clientSecret != cfg.ClientSecret,
		snap.redirectURL != cfg.RedirectURL,
		snap.caBundle != cfg.CABundle:
		return false
	}

	if cfg.CABundle != "" {
		currentHash, err := hashCABundle(cfg.CABundle)
		if err != nil || currentHash != snap.caBundleHash {
			return false
		}
	}

	if len(snap.scopes) != len(cfg.Scopes) {
		return false
	}
	for i := range snap.scopes {
		if snap.scopes[i] != cfg.Scopes[i] {
			return false
		}
	}
	return true
}
// newStateEntry starts a login attempt. It does the following:
//   - creates a random state key and a random nonce
//   - creates a PKCE verifier/challenge pair
//   - stores them, with returnTo, in the state store, expiring oidcStateTTL from now
//
// It returns the state key and the stored entry. If generating randomness fails, the
// error is returned unchanged and nothing is stored.
func (s *OIDCService) newStateEntry(returnTo string) (string, *oidcStateEntry, error) {
	var state, nonce string
	for _, target := range []*string{&state, &nonce} {
		value, err := generateRandomURLString(32)
		if err != nil {
			return "", nil, err
		}
		*target = value
	}

	verifier, challenge, err := generatePKCEPair()
	if err != nil {
		return "", nil, err
	}

	entry := &oidcStateEntry{
		Nonce:         nonce,
		CodeVerifier:  verifier,
		CodeChallenge: challenge,
		ReturnTo:      returnTo,
		ExpiresAt:     time.Now().Add(oidcStateTTL),
	}
	s.stateStore.Put(state, entry)
	return state, entry, nil
}
// consumeState removes and returns the entry stored under state. ok is false when the
// state is unknown or its entry has expired.
func (s *OIDCService) consumeState(state string) (*oidcStateEntry, bool) {
	store := s.stateStore
	entry, ok := store.Consume(state)
	return entry, ok
}
// authCodeURL builds the IdP authorization URL for state. It always includes the entry's
// nonce. When the entry has a PKCE challenge, it also adds the S256 code_challenge parameters.
func (s *OIDCService) authCodeURL(state string, entry *oidcStateEntry) string {
	opts := make([]oauth2.AuthCodeOption, 0, 3)
	opts = append(opts, oidc.Nonce(entry.Nonce))

	if challenge := entry.CodeChallenge; challenge != "" {
		opts = append(opts,
			oauth2.SetAuthURLParam("code_challenge_method", "S256"),
			oauth2.SetAuthURLParam("code_challenge", challenge),
		)
	}
	return s.oauth2Cfg.AuthCodeURL(state, opts...)
}
// exchangeCode trades an authorization code for tokens at the IdP's token endpoint.
// The request goes through the service's HTTP client. The PKCE code_verifier is sent
// when the entry carries one.
func (s *OIDCService) exchangeCode(ctx context.Context, code string, entry *oidcStateEntry) (*oauth2.Token, error) {
	requestCtx := s.contextWithHTTPClient(ctx)

	var opts []oauth2.AuthCodeOption
	if verifier := entry.CodeVerifier; verifier != "" {
		opts = append(opts, oauth2.SetAuthURLParam("code_verifier", verifier))
	}
	return s.oauth2Cfg.Exchange(requestCtx, code, opts...)
}
// contextWithHTTPClient attaches the service's HTTP client to ctx so OIDC/OAuth2 calls
// use it. ctx is returned untouched when no client is configured.
func (s *OIDCService) contextWithHTTPClient(ctx context.Context) context.Context {
	if client := s.httpClient; client != nil {
		return oidc.ClientContext(ctx, client)
	}
	return ctx
}
// OIDCRefreshResult contains the result of a token refresh operation
// performed by OIDCService.RefreshToken.
type OIDCRefreshResult struct {
// AccessToken is the access token returned by the IdP.
AccessToken string
// RefreshToken is the refresh token to use next time: the IdP-issued one when present,
// otherwise the token that was passed to RefreshToken.
RefreshToken string
// Expiry is the expiry reported on the new token.
Expiry time.Time
}
// RefreshToken performs a refresh_token grant against the IdP and returns the new token set.
//
// The returned RefreshToken is always usable for the next refresh. It is the rotated
// token when the IdP issued a new one; otherwise it is the token passed in.
//
// It returns an error when refreshToken is empty or the IdP rejects the grant (for
// example because access was revoked). The HTTP request uses ctx and the service's
// HTTP client.
func (s *OIDCService) RefreshToken(ctx context.Context, refreshToken string) (*OIDCRefreshResult, error) {
	if refreshToken == "" {
		return nil, errors.New("no refresh token provided")
	}
	ctx = s.contextWithHTTPClient(ctx)

	// A token with no access token and an expiry in the past is never valid, so the
	// token source must perform a refresh_token grant on its first Token() call.
	expired := &oauth2.Token{
		RefreshToken: refreshToken,
		Expiry:       time.Now().Add(-time.Hour),
	}
	newToken, err := s.oauth2Cfg.TokenSource(ctx, expired).Token()
	if err != nil {
		log.Warn().Err(err).Msg("OIDC token refresh failed")
		return nil, fmt.Errorf("failed to refresh token: %w", err)
	}

	// golang.org/x/oauth2 back-fills an empty RefreshToken in a refresh response with
	// the token that was sent. A non-empty value therefore does not imply rotation;
	// compare against the original to find out whether the IdP issued a new one.
	rotated := newToken.RefreshToken != "" && newToken.RefreshToken != refreshToken

	result := &OIDCRefreshResult{
		AccessToken:  newToken.AccessToken,
		RefreshToken: refreshToken,
		Expiry:       newToken.Expiry,
	}
	if rotated {
		result.RefreshToken = newToken.RefreshToken
	}

	log.Debug().
		Time("new_expiry", result.Expiry).
		Bool("new_refresh_token", rotated).
		Msg("OIDC token refresh successful")
	return result, nil
}
// newOIDCHTTPClient builds the HTTP client used for IdP traffic.
//
// The transport is cloned from http.DefaultTransport when that is an *http.Transport.
// Otherwise a minimal transport honouring proxy environment variables is used. When
// caBundle names a PEM file, its certificates are added to the system pool (or to an
// empty pool if the system pool cannot be loaded), and TLS 1.2 is set as the minimum
// version.
//
// It returns the client and the hex SHA-256 of the bundle contents. The hash is ""
// when caBundle is blank. It returns an error when the bundle cannot be read or holds
// no certificates.
func newOIDCHTTPClient(caBundle string) (*http.Client, string, error) {
	// Upper bound for a single IdP round trip. Without it, an IdP that accepts the
	// connection but never responds would block any caller whose context has no deadline.
	const requestTimeout = 30 * time.Second

	var transport *http.Transport
	if base, isTransport := http.DefaultTransport.(*http.Transport); isTransport && base != nil {
		transport = base.Clone()
	} else {
		transport = &http.Transport{
			Proxy: http.ProxyFromEnvironment,
		}
	}

	if strings.TrimSpace(caBundle) == "" {
		return &http.Client{Transport: transport, Timeout: requestTimeout}, "", nil
	}

	caData, err := os.ReadFile(caBundle)
	if err != nil {
		return nil, "", fmt.Errorf("failed to read OIDC CA bundle: %w", err)
	}

	pool, err := x509.SystemCertPool()
	if err != nil || pool == nil {
		pool = x509.NewCertPool()
	}
	if !pool.AppendCertsFromPEM(caData) {
		return nil, "", errors.New("OIDC CA bundle does not contain any certificates")
	}

	if transport.TLSClientConfig == nil {
		transport.TLSClientConfig = &tls.Config{}
	}
	transport.TLSClientConfig.MinVersion = tls.VersionTLS12
	transport.TLSClientConfig.RootCAs = pool

	sum := sha256.Sum256(caData)
	caHash := fmt.Sprintf("%x", sum[:])
	return &http.Client{Transport: transport, Timeout: requestTimeout}, caHash, nil
}
// hashCABundle returns the hex-encoded SHA-256 of the file at path.
// A blank path (empty or whitespace only) yields "" with no error, matching the
// "no bundle configured" case in newOIDCHTTPClient. Read errors are returned as-is.
func hashCABundle(path string) (string, error) {
	if len(strings.TrimSpace(path)) == 0 {
		return "", nil
	}
	contents, readErr := os.ReadFile(path)
	if readErr != nil {
		return "", readErr
	}
	digest := sha256.Sum256(contents)
	return fmt.Sprintf("%x", digest[:]), nil
}
// oidcStateStore keeps short-lived authorization state tokens, keyed by the state value.
// Entries are single-use: Consume deletes them. A background goroutine, started by
// newOIDCStateStore and ended by Stop, periodically removes entries that were never
// consumed.
type oidcStateStore struct {
	mu          sync.RWMutex
	entries     map[string]*oidcStateEntry
	stopCleanup chan struct{}
}

// oidcStateEntry holds the per-login values generated when an authorization flow starts.
type oidcStateEntry struct {
	// Nonce is sent as the OIDC nonce parameter in the authorization URL.
	Nonce string
	// CodeVerifier is the PKCE verifier sent during the code exchange.
	CodeVerifier string
	// CodeChallenge is the PKCE S256 challenge sent in the authorization URL.
	CodeChallenge string
	// ReturnTo is the caller-supplied value stored with the state
	// (presumably the post-login redirect target — confirm with callers).
	ReturnTo string
	// ExpiresAt is the instant after which the entry is rejected and swept.
	ExpiresAt time.Time
}

// newOIDCStateStore returns an empty store and starts its cleanup goroutine, which runs
// every 5 minutes until Stop is called.
func newOIDCStateStore() *oidcStateStore {
	s := &oidcStateStore{
		entries:     make(map[string]*oidcStateEntry),
		stopCleanup: make(chan struct{}),
	}
	// Start cleanup routine to prevent memory leak from abandoned OIDC flows
	go func() {
		ticker := time.NewTicker(5 * time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				s.cleanup()
			case <-s.stopCleanup:
				return
			}
		}
	}()
	return s
}

// cleanup removes expired state entries
func (s *oidcStateStore) cleanup() {
	s.mu.Lock()
	defer s.mu.Unlock()
	now := time.Now()
	for state, entry := range s.entries {
		if now.After(entry.ExpiresAt) {
			delete(s.entries, state)
		}
	}
}

// Stop stops the cleanup routine. Calling it multiple times, including concurrently, is
// safe; every call after the first is a no-op. Previously a second call panicked on the
// double close.
func (s *oidcStateStore) Stop() {
	// The mutex serialises concurrent Stop calls. The cleanup goroutine never holds it
	// while waiting on stopCleanup, so this cannot deadlock.
	s.mu.Lock()
	defer s.mu.Unlock()
	select {
	case <-s.stopCleanup:
		// Already closed.
	default:
		close(s.stopCleanup)
	}
}

// Put records entry under state, replacing any existing entry for that key.
func (s *oidcStateStore) Put(state string, entry *oidcStateEntry) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.entries[state] = entry
}

// Consume removes the entry for state and returns it. ok is false when the state is
// unknown or the entry has expired. The entry is deleted in both cases, so a state
// value can be used at most once.
func (s *oidcStateStore) Consume(state string) (*oidcStateEntry, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, exists := s.entries[state]
	if !exists {
		return nil, false
	}
	delete(s.entries, state)
	if time.Now().After(entry.ExpiresAt) {
		return nil, false
	}
	return entry, true
}
// generateRandomURLString reads size bytes from crypto/rand and returns them encoded with
// unpadded URL-safe base64, making the result safe to embed in URLs as-is.
func generateRandomURLString(size int) (string, error) {
	raw := make([]byte, size)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded, nil
}
// generatePKCEPair creates a PKCE code verifier from 32 random bytes (unpadded URL-safe
// base64, 43 characters). It also returns the matching S256 challenge, which is the
// unpadded URL-safe base64 of SHA-256(verifier).
func generatePKCEPair() (verifier string, challenge string, err error) {
	entropy := make([]byte, 32)
	if _, readErr := rand.Read(entropy); readErr != nil {
		return "", "", readErr
	}
	enc := base64.RawURLEncoding
	v := enc.EncodeToString(entropy)
	digest := sha256.Sum256([]byte(v))
	return v, enc.EncodeToString(digest[:]), nil
}