Fix security vulnerabilities and critical bugs

- Fix WebSocket CORS bypass by strictly verifying origin
- Fix OIDC refresh token persistence by encrypting at rest
- Fix grouped webhook data mutation by cloning alerts
- Fix host agent uninstall authorization and config fetch logic
- Fix notification queue recovery for stuck sending items
- Fix ignored update history limit parameter
- Fix ineffective break statement in WebSocket write pump
This commit is contained in:
rcourtman
2026-02-03 17:16:27 +00:00
parent c7f4030c29
commit b2639ed5a5
6 changed files with 58 additions and 20 deletions

View File

@@ -311,7 +311,9 @@ func (h *HostAgentHandlers) resolveConfigHost(ctx context.Context, hostID string
for _, candidate := range state.Hosts {
if candidate.TokenID != "" && candidate.TokenID == record.ID {
return candidate, true
if candidate.ID == hostID {
return candidate, true
}
}
}
@@ -527,6 +529,11 @@ func (h *HostAgentHandlers) HandleUninstall(w http.ResponseWriter, r *http.Reque
log.Info().Str("hostId", hostID).Msg("Received unregistration request from agent uninstaller")
// Ensure the token can manage this specific host
if !h.ensureHostTokenMatch(w, r, hostID) {
return
}
// Remove the host from state
_, err := h.getMonitor(r.Context()).RemoveHostAgent(hostID)
if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"sync"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/crypto"
"github.com/rs/zerolog/log"
)
@@ -19,6 +20,7 @@ type SessionStore struct {
dataPath string
saveTicker *time.Ticker
stopChan chan bool
crypto *crypto.CryptoManager
}
func sessionHash(token string) string {
@@ -68,10 +70,16 @@ type SessionData struct {
// NewSessionStore creates a new persistent session store
func NewSessionStore(dataPath string) *SessionStore {
cm, err := crypto.NewCryptoManagerAt(dataPath)
if err != nil {
log.Error().Err(err).Msg("Failed to initialize crypto manager for session store")
}
store := &SessionStore{
sessions: make(map[string]*SessionData),
dataPath: dataPath,
stopChan: make(chan bool),
crypto: cm,
}
// Load existing sessions from disk
@@ -340,6 +348,18 @@ func (s *SessionStore) saveUnsafe() {
// Marshal sessions
persisted := make([]sessionPersisted, 0, len(s.sessions))
for key, session := range s.sessions {
refreshToken := session.OIDCRefreshToken
// Encrypt refresh token if crypto is available and token exists
if refreshToken != "" && s.crypto != nil {
if encrypted, err := s.crypto.EncryptString(refreshToken); err == nil {
refreshToken = encrypted
} else {
log.Error().Err(err).Msg("Failed to encrypt refresh token")
// Don't persist if encryption fails to prevent leak
refreshToken = ""
}
}
persisted = append(persisted, sessionPersisted{
Key: key,
Username: session.Username,
@@ -348,7 +368,7 @@ func (s *SessionStore) saveUnsafe() {
UserAgent: session.UserAgent,
IP: session.IP,
OriginalDuration: session.OriginalDuration,
OIDCRefreshToken: session.OIDCRefreshToken,
OIDCRefreshToken: refreshToken,
OIDCAccessTokenExp: session.OIDCAccessTokenExp,
OIDCIssuer: session.OIDCIssuer,
OIDCClientID: session.OIDCClientID,
@@ -401,6 +421,15 @@ func (s *SessionStore) load() {
if now.After(entry.ExpiresAt) {
continue
}
refreshToken := entry.OIDCRefreshToken
// Decrypt refresh token if needed (handles migration from plaintext)
if refreshToken != "" && s.crypto != nil {
if decrypted, err := s.crypto.DecryptString(refreshToken); err == nil {
refreshToken = decrypted
}
// If decryption fails, assume it's legacy plaintext and leave as is
}
s.sessions[entry.Key] = &SessionData{
Username: entry.Username,
ExpiresAt: entry.ExpiresAt,
@@ -408,7 +437,7 @@ func (s *SessionStore) load() {
UserAgent: entry.UserAgent,
IP: entry.IP,
OriginalDuration: entry.OriginalDuration,
OIDCRefreshToken: entry.OIDCRefreshToken,
OIDCRefreshToken: refreshToken,
OIDCAccessTokenExp: entry.OIDCAccessTokenExp,
OIDCIssuer: entry.OIDCIssuer,
OIDCClientID: entry.OIDCClientID,

View File

@@ -316,8 +316,10 @@ func (h *UpdateHandlers) HandleListUpdateHistory(w http.ResponseWriter, r *http.
}
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
// Parse limit (simple implementation)
filter.Limit = 50
var limit int
if _, err := fmt.Sscanf(limitStr, "%d", &limit); err == nil && limit > 0 {
filter.Limit = limit
}
}
if status := r.URL.Query().Get("status"); status != "" {

View File

@@ -1533,7 +1533,11 @@ func (n *NotificationManager) sendGroupedWebhook(webhook WebhookConfig, alertLis
return fmt.Errorf("no alerts to send")
}
primaryAlert := alertList[0]
// Create a shallow copy of the primary alert to avoid mutating the original memory
// when we modify the message for grouped summaries.
originalPrimary := alertList[0]
alertCopy := *originalPrimary
primaryAlert := &alertCopy
customFields := convertWebhookCustomFields(webhook.CustomFields)
var templateData WebhookPayloadData

View File

@@ -108,6 +108,11 @@ func NewNotificationQueue(dataDir string) (*NotificationQueue, error) {
return nil, fmt.Errorf("failed to initialize schema: %w", err)
}
// Reset any stuck "sending" items to "pending" (crash recovery)
if _, err := nq.db.Exec(`UPDATE notification_queue SET status = 'pending' WHERE status = 'sending'`); err != nil {
log.Error().Err(err).Msg("Failed to recover stuck sending notifications")
}
// Start background processors
nq.wg.Add(2)
go nq.processQueue()

View File

@@ -95,23 +95,13 @@ func (h *Hub) checkOrigin(r *http.Request) bool {
allowedOrigins := h.allowedOrigins
h.mu.RUnlock()
// Determine the actual origin based on proxy headers
// Determine the actual origin
scheme := "http"
host := r.Host
// Check if we're behind a reverse proxy
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
scheme = normalizeForwardedProto(forwardedProto, scheme)
} else if forwardedScheme := r.Header.Get("X-Forwarded-Scheme"); forwardedScheme != "" {
scheme = normalizeForwardedProto(forwardedScheme, scheme)
} else if r.TLS != nil {
if r.TLS != nil {
scheme = "https"
}
// Use X-Forwarded-Host if available (for reverse proxy scenarios)
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
host := r.Host
requestOrigin := scheme + "://" + host
@@ -1104,13 +1094,14 @@ func (c *Client) writePump() {
// Send any queued messages
n := len(c.send)
flushLoop:
for i := 0; i < n; i++ {
select {
case msg := <-c.send:
if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
log.Warn().Err(err).Str("client", c.id).Int("msgSize", len(msg)).Msg("Failed to flush queued message")
// Don't disconnect on queued message failure, just break the flush loop
break
break flushLoop
}
default:
// No more messages