diff --git a/internal/api/host_agents.go b/internal/api/host_agents.go index 6587c226e..9535b8b34 100644 --- a/internal/api/host_agents.go +++ b/internal/api/host_agents.go @@ -311,7 +311,9 @@ func (h *HostAgentHandlers) resolveConfigHost(ctx context.Context, hostID string for _, candidate := range state.Hosts { if candidate.TokenID != "" && candidate.TokenID == record.ID { - return candidate, true + if candidate.ID == hostID { + return candidate, true + } } } @@ -527,6 +529,11 @@ func (h *HostAgentHandlers) HandleUninstall(w http.ResponseWriter, r *http.Reque log.Info().Str("hostId", hostID).Msg("Received unregistration request from agent uninstaller") + // Ensure the token can manage this specific host + if !h.ensureHostTokenMatch(w, r, hostID) { + return + } + // Remove the host from state _, err := h.getMonitor(r.Context()).RemoveHostAgent(hostID) if err != nil { diff --git a/internal/api/session_store.go b/internal/api/session_store.go index e9478e7c1..b45dd12a9 100644 --- a/internal/api/session_store.go +++ b/internal/api/session_store.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/rcourtman/pulse-go-rewrite/internal/crypto" "github.com/rs/zerolog/log" ) @@ -19,6 +20,7 @@ type SessionStore struct { dataPath string saveTicker *time.Ticker stopChan chan bool + crypto *crypto.CryptoManager } func sessionHash(token string) string { @@ -68,10 +70,16 @@ type SessionData struct { // NewSessionStore creates a new persistent session store func NewSessionStore(dataPath string) *SessionStore { + cm, err := crypto.NewCryptoManagerAt(dataPath) + if err != nil { + log.Error().Err(err).Msg("Failed to initialize crypto manager for session store") + } + store := &SessionStore{ sessions: make(map[string]*SessionData), dataPath: dataPath, stopChan: make(chan bool), + crypto: cm, } // Load existing sessions from disk @@ -340,6 +348,18 @@ func (s *SessionStore) saveUnsafe() { // Marshal sessions persisted := make([]sessionPersisted, 0, len(s.sessions)) for key, session := range s.sessions { + refreshToken := session.OIDCRefreshToken + // Encrypt refresh token if crypto is available and token exists + if refreshToken != "" && s.crypto != nil { + if encrypted, err := s.crypto.EncryptString(refreshToken); err == nil { + refreshToken = encrypted + } else { + log.Error().Err(err).Msg("Failed to encrypt refresh token") + // Don't persist if encryption fails to prevent leak + refreshToken = "" + } + } + persisted = append(persisted, sessionPersisted{ Key: key, Username: session.Username, @@ -348,7 +368,7 @@ func (s *SessionStore) saveUnsafe() { UserAgent: session.UserAgent, IP: session.IP, OriginalDuration: session.OriginalDuration, - OIDCRefreshToken: session.OIDCRefreshToken, + OIDCRefreshToken: refreshToken, OIDCAccessTokenExp: session.OIDCAccessTokenExp, OIDCIssuer: session.OIDCIssuer, OIDCClientID: session.OIDCClientID, @@ -401,6 +421,15 @@ func (s *SessionStore) load() { if now.After(entry.ExpiresAt) { continue } + refreshToken := entry.OIDCRefreshToken + // Decrypt refresh token if needed (handles migration from plaintext) + if refreshToken != "" && s.crypto != nil { + if decrypted, err := s.crypto.DecryptString(refreshToken); err == nil { + refreshToken = decrypted + } + // If decryption fails, assume it's legacy plaintext and leave as is + } + s.sessions[entry.Key] = &SessionData{ Username: entry.Username, ExpiresAt: entry.ExpiresAt, @@ -408,7 +437,7 @@ func (s *SessionStore) load() { UserAgent: entry.UserAgent, IP: entry.IP, OriginalDuration: entry.OriginalDuration, - OIDCRefreshToken: entry.OIDCRefreshToken, + OIDCRefreshToken: refreshToken, OIDCAccessTokenExp: entry.OIDCAccessTokenExp, OIDCIssuer: entry.OIDCIssuer, OIDCClientID: entry.OIDCClientID, diff --git a/internal/api/updates.go b/internal/api/updates.go index 683ff1fe5..2289aba70 100644 --- a/internal/api/updates.go +++ b/internal/api/updates.go @@ -316,8 +316,10 @@ func (h *UpdateHandlers) HandleListUpdateHistory(w http.ResponseWriter, r *http. } if limitStr := r.URL.Query().Get("limit"); limitStr != "" { - // Parse limit (simple implementation) - filter.Limit = 50 + var limit int + if _, err := fmt.Sscanf(limitStr, "%d", &limit); err == nil && limit > 0 { + filter.Limit = limit + } } if status := r.URL.Query().Get("status"); status != "" { diff --git a/internal/notifications/notifications.go b/internal/notifications/notifications.go index 70df51e52..1e3c5bd03 100644 --- a/internal/notifications/notifications.go +++ b/internal/notifications/notifications.go @@ -1533,7 +1533,11 @@ func (n *NotificationManager) sendGroupedWebhook(webhook WebhookConfig, alertLis return fmt.Errorf("no alerts to send") } - primaryAlert := alertList[0] + // Create a shallow copy of the primary alert to avoid mutating the original memory + // when we modify the message for grouped summaries. + originalPrimary := alertList[0] + alertCopy := *originalPrimary + primaryAlert := &alertCopy customFields := convertWebhookCustomFields(webhook.CustomFields) var templateData WebhookPayloadData diff --git a/internal/notifications/queue.go b/internal/notifications/queue.go index 4bd892863..772bc4618 100644 --- a/internal/notifications/queue.go +++ b/internal/notifications/queue.go @@ -108,6 +108,11 @@ func NewNotificationQueue(dataDir string) (*NotificationQueue, error) { return nil, fmt.Errorf("failed to initialize schema: %w", err) } + // Reset any stuck "sending" items to "pending" (crash recovery) + if _, err := nq.db.Exec(`UPDATE notification_queue SET status = 'pending' WHERE status = 'sending'`); err != nil { + log.Error().Err(err).Msg("Failed to recover stuck sending notifications") + } + // Start background processors nq.wg.Add(2) go nq.processQueue() diff --git a/internal/websocket/hub.go b/internal/websocket/hub.go index 4834df078..a3600a3c6 100644 --- a/internal/websocket/hub.go +++ b/internal/websocket/hub.go @@ -95,23 +95,13 @@ func (h *Hub) checkOrigin(r *http.Request) bool { allowedOrigins := h.allowedOrigins h.mu.RUnlock() - // Determine the actual origin based on proxy headers + // Determine the actual origin scheme := "http" - host := r.Host - - // Check if we're behind a reverse proxy - if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { - scheme = normalizeForwardedProto(forwardedProto, scheme) - } else if forwardedScheme := r.Header.Get("X-Forwarded-Scheme"); forwardedScheme != "" { - scheme = normalizeForwardedProto(forwardedScheme, scheme) - } else if r.TLS != nil { + if r.TLS != nil { scheme = "https" } - // Use X-Forwarded-Host if available (for reverse proxy scenarios) - if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { - host = forwardedHost - } + host := r.Host requestOrigin := scheme + "://" + host @@ -1104,13 +1094,14 @@ func (c *Client) writePump() { // Send any queued messages n := len(c.send) + flushLoop: for i := 0; i < n; i++ { select { case msg := <-c.send: if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil { log.Warn().Err(err).Str("client", c.id).Int("msgSize", len(msg)).Msg("Failed to flush queued message") // Don't disconnect on queued message failure, just break the flush loop - break + break flushLoop } default: // No more messages