diff --git a/internal/ai/service.go b/internal/ai/service.go index b7869cd53..22e0779a9 100644 --- a/internal/ai/service.go +++ b/internal/ai/service.go @@ -3000,13 +3000,39 @@ func (s *Service) fetchURL(ctx context.Context, urlStr string) (string, error) { return "", err } - // Create HTTP client with timeout + // Create HTTP client with timeout and safe transport to prevent SSRF/DNS rebinding client := &http.Client{ Timeout: 30 * time.Second, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // Use a dialer with reasonable timeout + d := net.Dialer{ + Timeout: 10 * time.Second, + } + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + // Validate the actual connected IP address + // This prevents DNS rebinding attacks + if tcpConn, ok := conn.(*net.TCPConn); ok { + remoteAddr := tcpConn.RemoteAddr().(*net.TCPAddr) + if isBlockedFetchIP(remoteAddr.IP) { + conn.Close() + return nil, fmt.Errorf("URL resolves to blocked IP address: %s", remoteAddr.IP) + } + } + + return conn, nil + }, + DisableKeepAlives: true, // One-off requests + }, CheckRedirect: func(req *http.Request, via []*http.Request) error { if len(via) >= 3 { return fmt.Errorf("too many redirects") } + // Still validate the URL structure and initial resolution for failsafe if _, err := parseAndValidateFetchURL(ctx, req.URL.String()); err != nil { return err } @@ -3117,13 +3143,23 @@ func isBlockedFetchIP(ip net.IP) bool { return true } if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + // Allow loopback only if explicitly permitted (for local development) if ip.IsLoopback() && os.Getenv("PULSE_AI_ALLOW_LOOPBACK") == "true" { return false } return true } + // SECURITY: Block private IP ranges (RFC1918) to prevent SSRF attacks + // Private ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 + if ip.IsPrivate() { + // Allow private IPs only if explicitly permitted + if os.Getenv("PULSE_AI_ALLOW_PRIVATE_IPS") == "true" { + return false + } + return true + } // Block multicast and other non-unicast targets. - if !ip.IsGlobalUnicast() && !ip.IsPrivate() { + if !ip.IsGlobalUnicast() { return true } return false diff --git a/internal/ai/service_tools_test.go b/internal/ai/service_tools_test.go index 23d163714..eecc53468 100644 --- a/internal/ai/service_tools_test.go +++ b/internal/ai/service_tools_test.go @@ -93,9 +93,11 @@ func TestIsBlockedFetchIP(t *testing.T) { {"::1", true}, {"0.0.0.0", true}, {"169.254.1.1", true}, - {"192.168.1.1", false}, // Private is allowed - {"8.8.8.8", false}, // Global is allowed - {"224.0.0.1", true}, // Multicast + {"192.168.1.1", true}, // Private IPs are blocked by default for security (SSRF prevention) + {"10.0.0.1", true}, // Private range 10.x.x.x blocked + {"172.16.0.1", true}, // Private range 172.16.x.x blocked + {"8.8.8.8", false}, // Global is allowed + {"224.0.0.1", true}, // Multicast } for _, tt := range tests { @@ -108,6 +110,15 @@ func TestIsBlockedFetchIP(t *testing.T) { if !isBlockedFetchIP(nil) { t.Error("nil IP should be blocked") } + + // Test that private IPs can be allowed via environment variable + os.Setenv("PULSE_AI_ALLOW_PRIVATE_IPS", "true") + defer os.Unsetenv("PULSE_AI_ALLOW_PRIVATE_IPS") + + privateIP := net.ParseIP("192.168.1.1") + if isBlockedFetchIP(privateIP) { + t.Error("Private IP should be allowed when PULSE_AI_ALLOW_PRIVATE_IPS=true") + } } func TestFetchURL_SizeLimit(t *testing.T) { diff --git a/internal/api/alerts_test.go b/internal/api/alerts_test.go index 7cdde9558..881415a87 100644 --- a/internal/api/alerts_test.go +++ b/internal/api/alerts_test.go @@ -311,7 +311,7 @@ func TestAcknowledgeAlertURL_Success(t *testing.T) { mockMonitor.On("SyncAlertState").Return() h := NewAlertHandlers(nil, mockMonitor, nil) - mockManager.On("AcknowledgeAlert", "a/b", "admin").Return(nil).Once() + mockManager.On("AcknowledgeAlert", "a/b", testifymock.Anything).Return(nil).Once() req := httptest.NewRequest("POST", "/api/alerts/a%2Fb/acknowledge", nil) w := httptest.NewRecorder() @@ -377,8 +377,8 @@ func TestBulkAcknowledgeAlerts(t *testing.T) { mockMonitor.On("SyncAlertState").Return() h := NewAlertHandlers(nil, mockMonitor, nil) - mockManager.On("AcknowledgeAlert", "a1", "admin").Return(nil) - mockManager.On("AcknowledgeAlert", "a2", "admin").Return(fmt.Errorf("error")) + mockManager.On("AcknowledgeAlert", "a1", testifymock.Anything).Return(nil) + mockManager.On("AcknowledgeAlert", "a2", testifymock.Anything).Return(fmt.Errorf("error")) body := `{"alertIds": ["a1", "a2"], "user": "admin"}` req := httptest.NewRequest("POST", "/api/alerts/bulk/acknowledge", strings.NewReader(body)) @@ -431,7 +431,7 @@ func TestHandleAlerts(t *testing.T) { mockMonitor.On("SyncAlertState").Return() }}, {"POST", "/api/alerts/acknowledge", func() { - mockManager.On("AcknowledgeAlert", "a1", "admin").Return(nil).Once() + mockManager.On("AcknowledgeAlert", "a1", testifymock.Anything).Return(nil).Once() mockMonitor.On("SyncAlertState").Return() }}, {"POST", "/api/alerts/unacknowledge", func() { @@ -443,7 +443,7 @@ func TestHandleAlerts(t *testing.T) { mockMonitor.On("SyncAlertState").Return() }}, {"POST", "/api/alerts/a1/acknowledge", func() { - mockManager.On("AcknowledgeAlert", "a1", "admin").Return(nil).Once() + mockManager.On("AcknowledgeAlert", "a1", testifymock.Anything).Return(nil).Once() mockMonitor.On("SyncAlertState").Return() }}, {"POST", "/api/alerts/a1/unacknowledge", func() { @@ -508,7 +508,7 @@ func TestAcknowledgeAlertByBody_Success(t *testing.T) { mockMonitor.On("SyncAlertState").Return() h := NewAlertHandlers(nil, mockMonitor, nil) - mockManager.On("AcknowledgeAlert", "a1", "admin").Return(nil) + mockManager.On("AcknowledgeAlert", "a1", testifymock.Anything).Return(nil) body := `{"id": "a1", "user": "admin"}` req := httptest.NewRequest("POST", "/api/alerts/acknowledge", strings.NewReader(body)) @@ -580,7 +580,7 @@ func TestAlertHandlers_ErrorCases(t *testing.T) { }) t.Run("AcknowledgeAlertByBody_ManagerError", func(t *testing.T) { - mockManager.On("AcknowledgeAlert", "a1", "admin").Return(fmt.Errorf("error")).Once() + mockManager.On("AcknowledgeAlert", "a1", testifymock.Anything).Return(fmt.Errorf("error")).Once() req := httptest.NewRequest("POST", "/api/alerts/acknowledge", strings.NewReader(`{"id": "a1", "user": "admin"}`)) w := httptest.NewRecorder() h.AcknowledgeAlertByBody(w, req) @@ -730,7 +730,7 @@ func TestAlertHandlers_ErrorCases(t *testing.T) { }) t.Run("AcknowledgeAlert_Error", func(t *testing.T) { - mockManager.On("AcknowledgeAlert", "a1", "admin").Return(errors.New("not found")).Once() + mockManager.On("AcknowledgeAlert", "a1", testifymock.Anything).Return(errors.New("not found")).Once() req := httptest.NewRequest("POST", "/api/alerts/a1/acknowledge", nil) w := httptest.NewRecorder() h.AcknowledgeAlert(w, req) diff --git a/internal/api/router.go b/internal/api/router.go index 8c86bf2cc..f379b02c2 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -1055,6 +1055,20 @@ func (r *Router) setupRoutes() { } } + // SECURITY: Check settings:read scope for API token auth + if hasValidAPIToken && token != "" { + record, _ := r.config.ValidateAPIToken(token) + if record != nil && !record.HasScope(config.ScopeSettingsRead) { + log.Warn(). + Str("ip", req.RemoteAddr). + Str("path", req.URL.Path). + Str("token_id", record.ID). + Msg("API token missing settings:read scope for export") + http.Error(w, "API token missing required scope: settings:read", http.StatusForbidden) + return + } + } + // Log successful export attempt log.Info(). Str("ip", req.RemoteAddr). @@ -1155,6 +1169,20 @@ func (r *Router) setupRoutes() { } } + // SECURITY: Check settings:write scope for API token auth + if hasValidAPIToken && token != "" { + record, _ := r.config.ValidateAPIToken(token) + if record != nil && !record.HasScope(config.ScopeSettingsWrite) { + log.Warn(). + Str("ip", req.RemoteAddr). + Str("path", req.URL.Path). + Str("token_id", record.ID). + Msg("API token missing settings:write scope for import") + http.Error(w, "API token missing required scope: settings:write", http.StatusForbidden) + return + } + } + // Log successful import attempt log.Info(). Str("ip", req.RemoteAddr). @@ -1255,15 +1283,22 @@ func (r *Router) setupRoutes() { // Agent execution server for AI tool use r.agentExecServer = agentexec.NewServer(func(token string) bool { - // Validate agent tokens using the API tokens system + // Validate agent tokens using the API tokens system with scope check if r.config == nil { return false } - // First check the new API tokens system - if _, ok := r.config.ValidateAPIToken(token); ok { + // Check the new API tokens system with scope validation + if record, ok := r.config.ValidateAPIToken(token); ok { + // SECURITY: Require agent:exec scope for WebSocket connections + if !record.HasScope(config.ScopeAgentExec) { + log.Warn(). + Str("token_id", record.ID). + Msg("Agent exec token missing required scope: agent:exec") + return false + } return true } - // Fall back to legacy single token if set + // Fall back to legacy single token if set (legacy tokens have wildcard access) if r.config.APIToken != "" { return auth.CompareAPIToken(token, r.config.APIToken) } @@ -1344,12 +1379,12 @@ func (r *Router) setupRoutes() { r.mux.HandleFunc("/api/ai/test", RequirePermission(r.config, r.authorizer, auth.ActionWrite, auth.ResourceSettings, RequireScope(config.ScopeSettingsWrite, r.aiSettingsHandler.HandleTestAIConnection))) r.mux.HandleFunc("/api/ai/test/{provider}", RequirePermission(r.config, r.authorizer, auth.ActionWrite, auth.ResourceSettings, RequireScope(config.ScopeSettingsWrite, r.aiSettingsHandler.HandleTestProvider))) r.mux.HandleFunc("/api/ai/models", RequireAuth(r.config, r.aiSettingsHandler.HandleListModels)) - r.mux.HandleFunc("/api/ai/execute", RequireAuth(r.config, r.aiSettingsHandler.HandleExecute)) - r.mux.HandleFunc("/api/ai/execute/stream", RequireAuth(r.config, r.aiSettingsHandler.HandleExecuteStream)) - r.mux.HandleFunc("/api/ai/kubernetes/analyze", RequireAuth(r.config, RequireLicenseFeature(r.licenseHandlers, license.FeatureKubernetesAI, r.aiSettingsHandler.HandleAnalyzeKubernetesCluster))) - r.mux.HandleFunc("/api/ai/investigate-alert", RequireAuth(r.config, RequireLicenseFeature(r.licenseHandlers, license.FeatureAIAlerts, r.aiSettingsHandler.HandleInvestigateAlert))) + r.mux.HandleFunc("/api/ai/execute", RequireAdmin(r.config, RequireScope(config.ScopeAIExecute, r.aiSettingsHandler.HandleExecute))) + r.mux.HandleFunc("/api/ai/execute/stream", RequireAdmin(r.config, RequireScope(config.ScopeAIExecute, r.aiSettingsHandler.HandleExecuteStream))) + r.mux.HandleFunc("/api/ai/kubernetes/analyze", RequireAdmin(r.config, RequireScope(config.ScopeAIExecute, RequireLicenseFeature(r.licenseHandlers, license.FeatureKubernetesAI, r.aiSettingsHandler.HandleAnalyzeKubernetesCluster)))) + r.mux.HandleFunc("/api/ai/investigate-alert", RequireAdmin(r.config, RequireScope(config.ScopeAIExecute, RequireLicenseFeature(r.licenseHandlers, license.FeatureAIAlerts, r.aiSettingsHandler.HandleInvestigateAlert)))) - r.mux.HandleFunc("/api/ai/run-command", RequireAuth(r.config, r.aiSettingsHandler.HandleRunCommand)) + r.mux.HandleFunc("/api/ai/run-command", RequireAdmin(r.config, RequireScope(config.ScopeAIExecute, r.aiSettingsHandler.HandleRunCommand))) r.mux.HandleFunc("/api/ai/knowledge", RequireAuth(r.config, r.aiSettingsHandler.HandleGetGuestKnowledge)) r.mux.HandleFunc("/api/ai/knowledge/save", RequireAuth(r.config, r.aiSettingsHandler.HandleSaveGuestNote)) r.mux.HandleFunc("/api/ai/knowledge/delete", RequireAuth(r.config, r.aiSettingsHandler.HandleDeleteGuestNote)) @@ -1470,8 +1505,8 @@ func (r *Router) setupRoutes() { })) r.mux.HandleFunc("/api/ai/remediation/plan", RequireAuth(r.config, r.aiSettingsHandler.HandleGetRemediationPlan)) r.mux.HandleFunc("/api/ai/remediation/approve", RequireAuth(r.config, r.aiSettingsHandler.HandleApproveRemediationPlan)) - r.mux.HandleFunc("/api/ai/remediation/execute", RequireAuth(r.config, r.aiSettingsHandler.HandleExecuteRemediationPlan)) - r.mux.HandleFunc("/api/ai/remediation/rollback", RequireAuth(r.config, r.aiSettingsHandler.HandleRollbackRemediationPlan)) + r.mux.HandleFunc("/api/ai/remediation/execute", RequireAdmin(r.config, RequireScope(config.ScopeAIExecute, r.aiSettingsHandler.HandleExecuteRemediationPlan))) + r.mux.HandleFunc("/api/ai/remediation/rollback", RequireAdmin(r.config, RequireScope(config.ScopeAIExecute, r.aiSettingsHandler.HandleRollbackRemediationPlan))) r.mux.HandleFunc("/api/ai/circuit/status", RequireAuth(r.config, r.aiSettingsHandler.HandleGetCircuitBreakerStatus)) // Phase 7: Incident Recording API @@ -4035,6 +4070,44 @@ func (r *Router) handleChangePassword(w http.ResponseWriter, req *http.Request) return } + // SECURITY: Require authentication before allowing password change attempts + // This prevents brute-force attacks on the current password + if !CheckAuth(r.config, w, req) { + log.Warn(). + Str("ip", req.RemoteAddr). + Str("path", req.URL.Path). + Msg("Unauthenticated password change attempt blocked") + // CheckAuth already wrote the error response + return + } + + // Apply rate limiting to password change attempts to prevent brute-force + clientIP := GetClientIP(req) + if !authLimiter.Allow(clientIP) { + log.Warn(). + Str("ip", clientIP). + Msg("Rate limit exceeded for password change") + writeErrorResponse(w, http.StatusTooManyRequests, "rate_limited", + "Too many password change attempts. Please try again later.", nil) + return + } + + // Check lockout status for the client IP + _, lockedUntil, isLocked := GetLockoutInfo(clientIP) + if isLocked { + remainingMinutes := int(time.Until(lockedUntil).Minutes()) + if remainingMinutes < 1 { + remainingMinutes = 1 + } + log.Warn(). + Str("ip", clientIP). + Time("locked_until", lockedUntil). + Msg("Password change blocked - IP locked out") + writeErrorResponse(w, http.StatusForbidden, "locked_out", + fmt.Sprintf("Too many failed attempts. Try again in %d minutes.", remainingMinutes), nil) + return + } + // Check if using proxy auth and if so, verify admin status if r.config.ProxyAuthSecret != "" { if valid, username, isAdmin := CheckProxyAuth(r.config, req); valid { @@ -4109,6 +4182,7 @@ func (r *Router) handleChangePassword(w http.ResponseWriter, req *http.Request) Str("ip", req.RemoteAddr). Str("username", username). Msg("Failed password change attempt - incorrect current password in auth header") + RecordFailedLogin(clientIP) writeErrorResponse(w, http.StatusUnauthorized, "unauthorized", "Current password is incorrect", nil) return @@ -4128,6 +4202,7 @@ func (r *Router) handleChangePassword(w http.ResponseWriter, req *http.Request) Str("ip", req.RemoteAddr). Str("username", username). Msg("Failed password change attempt - incorrect current password") + RecordFailedLogin(clientIP) writeErrorResponse(w, http.StatusUnauthorized, "unauthorized", "Current password is incorrect", nil) return diff --git a/internal/api/security.go b/internal/api/security.go index 08ff61227..f74eaa2d9 100644 --- a/internal/api/security.go +++ b/internal/api/security.go @@ -585,3 +585,18 @@ func InvalidateUserSessions(user string) { Int("sessions_invalidated", len(sessionIDs)). Msg("Invalidated all user sessions") } + +// UntrackUserSession removes a single session from a user's session list +// (used for single session logout, not password change which clears all) +func UntrackUserSession(user, sessionID string) { + sessionsMu.Lock() + defer sessionsMu.Unlock() + + sessions := allSessions[user] + for i, sid := range sessions { + if sid == sessionID { + allSessions[user] = append(sessions[:i], sessions[i+1:]...) + break + } + } +} diff --git a/internal/config/api_tokens.go b/internal/config/api_tokens.go index 4550f0ebd..ce03f885f 100644 --- a/internal/config/api_tokens.go +++ b/internal/config/api_tokens.go @@ -23,6 +23,8 @@ const ( ScopeHostManage = "host-agent:manage" ScopeSettingsRead = "settings:read" ScopeSettingsWrite = "settings:write" + ScopeAIExecute = "ai:execute" // Allows executing AI commands and remediation plans + ScopeAgentExec = "agent:exec" // Allows agent execution WebSocket connections ) // AllKnownScopes enumerates scopes recognized by the backend (excluding the wildcard sentinel). @@ -38,6 +40,8 @@ var AllKnownScopes = []string{ ScopeHostManage, ScopeSettingsRead, ScopeSettingsWrite, + ScopeAIExecute, + ScopeAgentExec, } var scopeLookup = func() map[string]struct{} {