mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-19 07:50:43 +01:00
- Add Access-Control-Expose-Headers to allow frontend to read X-CSRF-Token response header - Implement proactive CSRF token issuance on GET requests when session exists but CSRF cookie is missing - Ensures frontend always has valid CSRF token before making POST requests - Fixes 403 Forbidden errors when toggling system settings This resolves CSRF validation failures that occurred when CSRF tokens expired or were missing while valid sessions existed.
567 lines
13 KiB
Go
567 lines
13 KiB
Go
package api
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// Security improvements for Pulse
|
|
|
|
// generateCSRFToken creates a new CSRF token for a session
|
|
func generateCSRFToken(sessionID string) string {
|
|
return GetCSRFStore().GenerateCSRFToken(sessionID)
|
|
}
|
|
|
|
// validateCSRFToken checks if a CSRF token is valid for a session
|
|
func validateCSRFToken(sessionID, token string) bool {
|
|
return GetCSRFStore().ValidateCSRFToken(sessionID, token)
|
|
}
|
|
|
|
// CheckCSRF validates CSRF token for state-changing requests
|
|
func CheckCSRF(w http.ResponseWriter, r *http.Request) bool {
|
|
// Skip CSRF check for safe methods
|
|
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
|
return true
|
|
}
|
|
|
|
// Skip CSRF for API token auth (API clients don't have sessions)
|
|
if r.Header.Get("X-API-Token") != "" {
|
|
return true
|
|
}
|
|
|
|
// Skip CSRF for Basic Auth (doesn't use sessions, not vulnerable to CSRF)
|
|
if r.Header.Get("Authorization") != "" {
|
|
return true
|
|
}
|
|
|
|
// Get session from cookie
|
|
cookie, err := r.Cookie("pulse_session")
|
|
if err != nil {
|
|
// No session cookie means no CSRF check needed
|
|
// (either no auth configured or using basic auth which doesn't use sessions)
|
|
return true
|
|
}
|
|
|
|
// Get CSRF token from header or form
|
|
csrfToken := r.Header.Get("X-CSRF-Token")
|
|
if csrfToken == "" {
|
|
csrfToken = r.FormValue("csrf_token")
|
|
}
|
|
|
|
// No CSRF token means request is not eligible for mutation
|
|
if csrfToken == "" {
|
|
log.Warn().
|
|
Str("path", r.URL.Path).
|
|
Str("session", cookie.Value[:8]+"...").
|
|
Msg("Missing CSRF token")
|
|
clearCSRFCookie(w)
|
|
if newToken := issueNewCSRFCookie(w, r, cookie.Value); newToken != "" {
|
|
w.Header().Set("X-CSRF-Token", newToken)
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Check if the CSRF token validates
|
|
if !validateCSRFToken(cookie.Value, csrfToken) {
|
|
log.Warn().
|
|
Str("path", r.URL.Path).
|
|
Str("session", cookie.Value[:8]+"...").
|
|
Str("provided_token", csrfToken[:8]+"...").
|
|
Msg("Invalid CSRF token")
|
|
clearCSRFCookie(w)
|
|
if newToken := issueNewCSRFCookie(w, r, cookie.Value); newToken != "" {
|
|
w.Header().Set("X-CSRF-Token", newToken)
|
|
}
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// clearCSRFCookie instructs the client to discard its pulse_csrf cookie.
// It sends an empty, already-expired cookie of the same name scoped to "/".
// Passing a nil writer does nothing.
func clearCSRFCookie(w http.ResponseWriter) {
	if w != nil {
		expired := &http.Cookie{
			Name:     "pulse_csrf",
			Value:    "",
			Path:     "/",
			MaxAge:   -1, // negative MaxAge deletes the cookie immediately
			HttpOnly: false,
		}
		http.SetCookie(w, expired)
	}
}
|
|
|
|
func issueNewCSRFCookie(w http.ResponseWriter, r *http.Request, sessionID string) string {
|
|
if w == nil || r == nil {
|
|
return ""
|
|
}
|
|
if strings.TrimSpace(sessionID) == "" {
|
|
return ""
|
|
}
|
|
|
|
newToken := generateCSRFToken(sessionID)
|
|
secure, sameSite := getCookieSettings(r)
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "pulse_csrf",
|
|
Value: newToken,
|
|
Path: "/",
|
|
Secure: secure,
|
|
SameSite: sameSite,
|
|
MaxAge: 86400,
|
|
})
|
|
return newToken
|
|
}
|
|
|
|
// Rate Limiting - using existing RateLimiter from ratelimit.go
|
|
var (
|
|
// Auth endpoints: 10 attempts per minute
|
|
authLimiter = NewRateLimiter(10, 1*time.Minute)
|
|
|
|
// General API: 500 requests per minute (increased for metadata endpoints)
|
|
apiLimiter = NewRateLimiter(500, 1*time.Minute)
|
|
)
|
|
|
|
// GetClientIP extracts the client IP from the request
|
|
func GetClientIP(r *http.Request) string {
|
|
rawRemoteIP := extractRemoteIP(r.RemoteAddr)
|
|
if rawRemoteIP == "" {
|
|
return ""
|
|
}
|
|
|
|
// Only trust proxy headers when the immediate peer is trusted.
|
|
if isTrustedProxyIP(rawRemoteIP) {
|
|
if forwarded := firstValidForwardedIP(r.Header.Get("X-Forwarded-For")); forwarded != "" {
|
|
return forwarded
|
|
}
|
|
|
|
if realIP := strings.TrimSpace(strings.Trim(r.Header.Get("X-Real-IP"), "[]")); realIP != "" && net.ParseIP(realIP) != nil {
|
|
return realIP
|
|
}
|
|
}
|
|
|
|
return rawRemoteIP
|
|
}
|
|
|
|
// Failed login tracking.

// FailedLogin is the failed-authentication record kept for one login
// identifier.
type FailedLogin struct {
	// Count is the number of failures recorded for the identifier.
	Count int
	// LastAttempt is when the latest failure was recorded. LockedUntil is
	// when the current lockout ends; it is set once Count reaches the
	// lockout threshold.
	LastAttempt, LockedUntil time.Time
}
|
|
|
|
var (
|
|
failedLogins = make(map[string]*FailedLogin)
|
|
failedMu sync.RWMutex
|
|
|
|
maxFailedAttempts = 5
|
|
lockoutDuration = 15 * time.Minute
|
|
|
|
trustedProxyOnce sync.Once
|
|
trustedProxyCIDRs []*net.IPNet
|
|
)
|
|
|
|
func loadTrustedProxyCIDRs() {
|
|
raw := utils.GetenvTrim("PULSE_TRUSTED_PROXY_CIDRS")
|
|
if raw == "" {
|
|
return
|
|
}
|
|
|
|
for _, entry := range strings.Split(raw, ",") {
|
|
entry = strings.TrimSpace(entry)
|
|
if entry == "" {
|
|
continue
|
|
}
|
|
|
|
if strings.Contains(entry, "/") {
|
|
_, network, parseErr := net.ParseCIDR(entry)
|
|
if parseErr == nil {
|
|
network.IP = network.IP.Mask(network.Mask)
|
|
trustedProxyCIDRs = append(trustedProxyCIDRs, network)
|
|
continue
|
|
}
|
|
log.Warn().
|
|
Str("cidr", entry).
|
|
Err(parseErr).
|
|
Msg("Ignoring invalid CIDR in PULSE_TRUSTED_PROXY_CIDRS")
|
|
continue
|
|
}
|
|
|
|
ip := net.ParseIP(entry)
|
|
if ip == nil {
|
|
log.Warn().
|
|
Str("value", entry).
|
|
Msg("Ignoring invalid IP in PULSE_TRUSTED_PROXY_CIDRS")
|
|
continue
|
|
}
|
|
|
|
bits := 32
|
|
if ip.To4() == nil {
|
|
bits = 128
|
|
}
|
|
mask := net.CIDRMask(bits, bits)
|
|
network := &net.IPNet{IP: ip.Mask(mask), Mask: mask}
|
|
trustedProxyCIDRs = append(trustedProxyCIDRs, network)
|
|
}
|
|
}
|
|
|
|
// extractRemoteIP returns the host part of an address such as
// "203.0.113.7:443" or "[2001:db8::1]:443", with IPv6 brackets removed.
// Input without a port is returned with any surrounding brackets removed,
// and "" yields "". The result is not validated as an IP.
func extractRemoteIP(remoteAddr string) string {
	if remoteAddr == "" {
		return ""
	}
	host := remoteAddr
	if h, _, err := net.SplitHostPort(remoteAddr); err == nil {
		host = h
	}
	return strings.Trim(host, "[]")
}
|
|
|
|
// firstValidForwardedIP returns the first (leftmost) entry of a
// comma-separated X-Forwarded-For value that parses as an IP address, or ""
// if none does. Surrounding whitespace and IPv6 brackets are removed from
// each entry before parsing.
//
// Note: the leftmost entry is supplied by the original client and is not
// authenticated by any proxy, so it should not drive security decisions on
// its own.
func firstValidForwardedIP(header string) string {
	if header == "" {
		return ""
	}
	for _, part := range strings.Split(header, ",") {
		// Trim whitespace before brackets. Entries after a comma normally
		// start with a space, which previously shielded the "[" from Trim.
		candidate := strings.TrimSpace(strings.Trim(strings.TrimSpace(part), "[]"))
		if candidate != "" && net.ParseIP(candidate) != nil {
			return candidate
		}
	}
	return ""
}
|
|
|
|
func isTrustedProxyIP(ipStr string) bool {
|
|
ipStr = strings.TrimSpace(strings.Trim(ipStr, "[]"))
|
|
if ipStr == "" {
|
|
return false
|
|
}
|
|
ip := net.ParseIP(ipStr)
|
|
if ip == nil {
|
|
return false
|
|
}
|
|
|
|
trustedProxyOnce.Do(loadTrustedProxyCIDRs)
|
|
if len(trustedProxyCIDRs) == 0 {
|
|
return false
|
|
}
|
|
for _, network := range trustedProxyCIDRs {
|
|
if network.Contains(ip) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func isPrivateIP(ip string) bool {
|
|
host := extractRemoteIP(ip)
|
|
if host == "" {
|
|
return false
|
|
}
|
|
|
|
parsedIP := net.ParseIP(host)
|
|
if parsedIP == nil {
|
|
return false
|
|
}
|
|
|
|
if parsedIP.IsLoopback() ||
|
|
parsedIP.IsLinkLocalUnicast() ||
|
|
parsedIP.IsLinkLocalMulticast() {
|
|
return true
|
|
}
|
|
|
|
privateRanges := []string{
|
|
"10.0.0.0/8",
|
|
"172.16.0.0/12",
|
|
"192.168.0.0/16",
|
|
"127.0.0.0/8",
|
|
"::1/128",
|
|
"fc00::/7",
|
|
"fe80::/10",
|
|
}
|
|
|
|
for _, cidr := range privateRanges {
|
|
_, network, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if network.Contains(parsedIP) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func isTrustedNetwork(ip string, trustedNetworks []string) bool {
|
|
if len(trustedNetworks) == 0 {
|
|
return isPrivateIP(ip)
|
|
}
|
|
|
|
host := extractRemoteIP(ip)
|
|
if host == "" {
|
|
return false
|
|
}
|
|
|
|
parsedIP := net.ParseIP(host)
|
|
if parsedIP == nil {
|
|
return false
|
|
}
|
|
|
|
for _, cidr := range trustedNetworks {
|
|
_, network, err := net.ParseCIDR(strings.TrimSpace(cidr))
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if network.Contains(parsedIP) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// RecordFailedLogin tracks failed login attempts
|
|
func RecordFailedLogin(identifier string) {
|
|
failedMu.Lock()
|
|
defer failedMu.Unlock()
|
|
|
|
failed, exists := failedLogins[identifier]
|
|
if !exists {
|
|
failed = &FailedLogin{}
|
|
failedLogins[identifier] = failed
|
|
}
|
|
|
|
failed.Count++
|
|
failed.LastAttempt = time.Now()
|
|
|
|
if failed.Count >= maxFailedAttempts {
|
|
failed.LockedUntil = time.Now().Add(lockoutDuration)
|
|
log.Warn().
|
|
Str("identifier", identifier).
|
|
Int("attempts", failed.Count).
|
|
Time("locked_until", failed.LockedUntil).
|
|
Msg("Account locked due to failed login attempts")
|
|
}
|
|
}
|
|
|
|
// ClearFailedLogins resets failed login counter on successful login
|
|
func ClearFailedLogins(identifier string) {
|
|
failedMu.Lock()
|
|
defer failedMu.Unlock()
|
|
delete(failedLogins, identifier)
|
|
}
|
|
|
|
// IsLockedOut checks if an account is locked out
|
|
func IsLockedOut(identifier string) bool {
|
|
failedMu.RLock()
|
|
defer failedMu.RUnlock()
|
|
|
|
failed, exists := failedLogins[identifier]
|
|
if !exists {
|
|
return false
|
|
}
|
|
|
|
if time.Now().After(failed.LockedUntil) {
|
|
// Lockout expired
|
|
return false
|
|
}
|
|
|
|
return failed.Count >= maxFailedAttempts
|
|
}
|
|
|
|
// GetLockoutInfo returns lockout information for an identifier
|
|
func GetLockoutInfo(identifier string) (attempts int, lockedUntil time.Time, isLocked bool) {
|
|
failedMu.RLock()
|
|
defer failedMu.RUnlock()
|
|
|
|
failed, exists := failedLogins[identifier]
|
|
if !exists {
|
|
return 0, time.Time{}, false
|
|
}
|
|
|
|
// Check if lockout has expired
|
|
if time.Now().After(failed.LockedUntil) && failed.Count >= maxFailedAttempts {
|
|
// Lockout expired, treat as no attempts
|
|
return 0, time.Time{}, false
|
|
}
|
|
|
|
isLocked = failed.Count >= maxFailedAttempts && time.Now().Before(failed.LockedUntil)
|
|
return failed.Count, failed.LockedUntil, isLocked
|
|
}
|
|
|
|
// ResetLockout manually resets lockout for an identifier (admin function)
|
|
func ResetLockout(identifier string) {
|
|
failedMu.Lock()
|
|
defer failedMu.Unlock()
|
|
delete(failedLogins, identifier)
|
|
|
|
log.Info().
|
|
Str("identifier", identifier).
|
|
Msg("Lockout manually reset")
|
|
}
|
|
|
|
// Security Headers Middleware
|
|
func SecurityHeaders(next http.Handler) http.Handler {
|
|
return SecurityHeadersWithConfig(next, false, "")
|
|
}
|
|
|
|
// SecurityHeadersWithConfig wraps next so every response carries these
// security headers:
//   - X-Content-Type-Options, X-XSS-Protection, Content-Security-Policy,
//     Referrer-Policy and Permissions-Policy;
//   - X-Frame-Options: DENY, unless allowEmbedding is set.
//
// Framing is governed by the CSP frame-ancestors directive:
//   - allowEmbedding false: frame-ancestors 'none'
//   - allowEmbedding with empty allowedOrigins: frame-ancestors * (any origin)
//   - otherwise: 'self' plus each non-blank, whitespace-trimmed entry of the
//     comma-separated allowedOrigins.
//
// The policy depends only on the arguments, so it is rendered once when the
// middleware is built rather than on every request.
func SecurityHeadersWithConfig(next http.Handler, allowEmbedding bool, allowedOrigins string) http.Handler {
	csp := buildSecurityCSP(allowEmbedding, allowedOrigins)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		// Clickjacking protection. When embedding has been explicitly
		// enabled the header is omitted and frame-ancestors decides.
		if !allowEmbedding {
			h.Set("X-Frame-Options", "DENY")
		}

		// Prevent MIME type sniffing.
		h.Set("X-Content-Type-Options", "nosniff")
		// NOTE(review): legacy header; current guidance (e.g. OWASP) is to
		// send "0" or omit it. Kept unchanged for compatibility.
		h.Set("X-XSS-Protection", "1; mode=block")
		h.Set("Content-Security-Policy", csp)
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// Permissions Policy (formerly Feature Policy).
		h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")

		next.ServeHTTP(w, r)
	})
}

// buildSecurityCSP renders the Content-Security-Policy header value used by
// SecurityHeadersWithConfig.
func buildSecurityCSP(allowEmbedding bool, allowedOrigins string) string {
	directives := []string{
		"default-src 'self'",
		"script-src 'self' 'unsafe-inline' 'unsafe-eval'", // Needed for React
		"style-src 'self' 'unsafe-inline'",                // Needed for inline styles
		"img-src 'self' data: blob:",
		"connect-src 'self' ws: wss:", // WebSocket support
		"font-src 'self' data:",
	}

	switch {
	case !allowEmbedding:
		directives = append(directives, "frame-ancestors 'none'")
	case allowedOrigins == "":
		// Embedding enabled without an allowlist: any origin may frame us.
		directives = append(directives, "frame-ancestors *")
	default:
		ancestors := []string{"frame-ancestors 'self'"}
		for _, origin := range strings.Split(allowedOrigins, ",") {
			if origin = strings.TrimSpace(origin); origin != "" {
				ancestors = append(ancestors, origin)
			}
		}
		directives = append(directives, strings.Join(ancestors, " "))
	}

	return strings.Join(directives, "; ")
}
|
|
|
|
// Audit Logging

// AuditEvent describes one security-relevant event. The struct tags define
// its JSON field names; User, Path and Details are omitted from JSON when
// empty.
// NOTE(review): LogAuditEvent in this file logs these fields individually
// rather than building an AuditEvent — confirm where this type is consumed.
type AuditEvent struct {
	Timestamp time.Time `json:"timestamp"`
	Event     string    `json:"event"`
	User      string    `json:"user,omitempty"`
	IP        string    `json:"ip"`
	Path      string    `json:"path,omitempty"`
	Success   bool      `json:"success"`
	Details   string    `json:"details,omitempty"`
}
|
|
|
|
// LogAuditEvent logs security-relevant events
|
|
func LogAuditEvent(event string, user string, ip string, path string, success bool, details string) {
|
|
if success {
|
|
log.Info().
|
|
Str("event", event).
|
|
Str("user", user).
|
|
Str("ip", ip).
|
|
Str("path", path).
|
|
Str("details", details).
|
|
Time("timestamp", time.Now()).
|
|
Msg("Security audit event")
|
|
} else {
|
|
log.Warn().
|
|
Str("event", event).
|
|
Str("user", user).
|
|
Str("ip", ip).
|
|
Str("path", path).
|
|
Str("details", details).
|
|
Time("timestamp", time.Now()).
|
|
Msg("Security audit event - FAILED")
|
|
}
|
|
}
|
|
|
|
// Session management.

// allSessions indexes session IDs by the username that owns them
// (user -> []sessionIDs), so all of a user's sessions can be looked up or
// revoked together. Access is guarded by sessionsMu.
var (
	allSessions = map[string][]string{}
	sessionsMu  sync.RWMutex
)
|
|
|
|
// TrackUserSession tracks which sessions belong to which user
|
|
func TrackUserSession(user, sessionID string) {
|
|
sessionsMu.Lock()
|
|
defer sessionsMu.Unlock()
|
|
|
|
if allSessions[user] == nil {
|
|
allSessions[user] = []string{}
|
|
}
|
|
allSessions[user] = append(allSessions[user], sessionID)
|
|
}
|
|
|
|
// GetSessionUsername returns the username associated with a session ID
|
|
func GetSessionUsername(sessionID string) string {
|
|
sessionsMu.RLock()
|
|
defer sessionsMu.RUnlock()
|
|
|
|
for user, sessions := range allSessions {
|
|
for _, sid := range sessions {
|
|
if sid == sessionID {
|
|
return user
|
|
}
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// InvalidateUserSessions invalidates all sessions for a user (e.g., on password change)
|
|
func InvalidateUserSessions(user string) {
|
|
sessionsMu.Lock()
|
|
defer sessionsMu.Unlock()
|
|
|
|
sessionIDs := allSessions[user]
|
|
for _, sid := range sessionIDs {
|
|
// Delete from persistent session store
|
|
GetSessionStore().DeleteSession(sid)
|
|
|
|
// Delete CSRF tokens
|
|
GetCSRFStore().DeleteCSRFToken(sid)
|
|
}
|
|
|
|
delete(allSessions, user)
|
|
|
|
log.Info().
|
|
Str("user", user).
|
|
Int("sessions_invalidated", len(sessionIDs)).
|
|
Msg("Invalidated all user sessions")
|
|
}
|