Files
Pulse/internal/websocket/hub_test.go
rcourtman a9ed380718 fix(websocket): respect X-Forwarded headers in same-origin check
- Use X-Forwarded-Proto/X-Forwarded-Scheme for scheme detection
- Use X-Forwarded-Host for host matching behind reverse proxies
- Update tests with remoteAddr for CSWSH protection validation
2026-02-03 21:45:39 +00:00

879 lines
23 KiB
Go

package websocket
import (
"math"
"net/http"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
)
// TestIsValidPrivateOrigin checks which hostnames isValidPrivateOrigin accepts.
// Loopback names/addresses, sample addresses from the 10/8, 172.16/12 and
// 192.168/16 IPv4 ranges, and .local/.lan names are expected to pass. Deeply
// nested .local names, public IPs, public domains and degenerate input are
// expected to be rejected.
func TestIsValidPrivateOrigin(t *testing.T) {
	cases := []struct {
		name string
		host string
		want bool
	}{
		// Loopback names and addresses.
		{name: "localhost", host: "localhost", want: true},
		{name: "ipv4 loopback", host: "127.0.0.1", want: true},
		{name: "ipv6 loopback", host: "::1", want: true},
		// Private IPv4 ranges, including the top address of each range.
		{name: "10.x.x.x private", host: "10.0.0.1", want: true},
		{name: "10.x.x.x edge", host: "10.255.255.255", want: true},
		{name: "172.16.x.x private", host: "172.16.0.1", want: true},
		{name: "172.31.x.x private", host: "172.31.255.255", want: true},
		{name: "192.168.x.x private", host: "192.168.1.1", want: true},
		{name: "192.168.x.x edge", host: "192.168.255.255", want: true},
		// .local / .lan names; too many labels in front of the suffix is rejected.
		{name: "hostname.local", host: "myhost.local", want: true},
		{name: "hostname.lan", host: "myhost.lan", want: true},
		{name: "subdomain.hostname.local", host: "sub.myhost.local", want: true},
		{name: "too many subdomains .local", host: "a.b.c.d.local", want: false},
		// Public IP addresses must not be treated as private.
		{name: "public IP 8.8.8.8", host: "8.8.8.8", want: false},
		{name: "public IP 1.1.1.1", host: "1.1.1.1", want: false},
		{name: "public IP 203.0.113.1", host: "203.0.113.1", want: false},
		// Public domain names must not be treated as private.
		{name: "example.com", host: "example.com", want: false},
		{name: "google.com", host: "google.com", want: false},
		{name: "malicious.attacker.com", host: "malicious.attacker.com", want: false},
		// Degenerate input.
		{name: "empty string", host: "", want: false},
		{name: "just dot", host: ".", want: false},
		{name: "numbers only", host: "12345", want: false},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := isValidPrivateOrigin(c.host); got != c.want {
				t.Errorf("isValidPrivateOrigin(%q) = %v, want %v", c.host, got, c.want)
			}
		})
	}
}
// TestNormalizeForwardedProto pins the behavior of normalizeForwardedProto:
//   - empty or blank values yield the fallback;
//   - HTTP schemes are lower-cased and ws/wss map to http/https;
//   - only the first entry of a comma-separated chain is used, with
//     surrounding whitespace trimmed;
//   - unrecognized schemes pass through unchanged.
func TestNormalizeForwardedProto(t *testing.T) {
	cases := []struct {
		name     string
		proto    string
		fallback string
		want     string
	}{
		// No value: the fallback wins.
		{name: "empty proto returns fallback", proto: "", fallback: "http", want: "http"},
		{name: "empty proto returns https fallback", proto: "", fallback: "https", want: "https"},
		// Plain HTTP schemes, any case.
		{name: "http passthrough", proto: "http", fallback: "https", want: "http"},
		{name: "https passthrough", proto: "https", fallback: "http", want: "https"},
		{name: "HTTP uppercase", proto: "HTTP", fallback: "http", want: "http"},
		{name: "HTTPS uppercase", proto: "HTTPS", fallback: "http", want: "https"},
		// WebSocket schemes map onto their HTTP equivalents.
		{name: "ws becomes http", proto: "ws", fallback: "https", want: "http"},
		{name: "wss becomes https", proto: "wss", fallback: "http", want: "https"},
		{name: "WS uppercase", proto: "WS", fallback: "https", want: "http"},
		{name: "WSS uppercase", proto: "WSS", fallback: "http", want: "https"},
		// Proxy chains: only the first hop counts.
		{name: "chain wss,https", proto: "wss,https", fallback: "http", want: "https"},
		{name: "chain https,wss", proto: "https,wss", fallback: "http", want: "https"},
		{name: "chain http,wss,https", proto: "http,wss,https", fallback: "https", want: "http"},
		// Surrounding whitespace is ignored.
		{name: "whitespace trimmed", proto: " https ", fallback: "http", want: "https"},
		{name: "whitespace in chain", proto: " wss , https ", fallback: "http", want: "https"},
		// Anything else is returned as-is, except blank input.
		{name: "unknown proto", proto: "ftp", fallback: "http", want: "ftp"},
		{name: "unknown empty after trim", proto: " ", fallback: "http", want: "http"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := normalizeForwardedProto(c.proto, c.fallback); got != c.want {
				t.Errorf("normalizeForwardedProto(%q, %q) = %q, want %q", c.proto, c.fallback, got, c.want)
			}
		})
	}
}
// TestSanitizeValue checks sanitizeValue on scalar inputs: finite numbers,
// strings, bools and nil come back unchanged, while NaN and ±Inf are replaced
// with nil.
func TestSanitizeValue(t *testing.T) {
	cases := []struct {
		name string
		in   interface{}
		want interface{}
	}{
		// Values that must survive untouched.
		{"normal float64", float64(42.5), float64(42.5)},
		{"zero float64", float64(0), float64(0)},
		{"negative float64", float64(-100.5), float64(-100.5)},
		{"string", "hello", "hello"},
		{"int via float64", float64(100), float64(100)},
		{"bool true", true, true},
		{"bool false", false, false},
		{"nil", nil, nil},
		// Non-finite floats are expected to be replaced with nil.
		{"NaN float64", math.NaN(), nil},
		{"positive Inf", math.Inf(1), nil},
		{"negative Inf", math.Inf(-1), nil},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := sanitizeValue(c.in)
			switch {
			case c.want == nil:
				// A nil expectation is checked on its own branch; it also
				// covers the NaN input, whose value never compares equal.
				if got != nil {
					t.Errorf("sanitizeValue() = %v, want nil", got)
				}
			case got != c.want:
				t.Errorf("sanitizeValue() = %v, want %v", got, c.want)
			}
		})
	}
}
// TestSanitizeValue_Map checks that sanitizeValue applied to a
// map[string]interface{} returns a map of the same type. Finite numbers,
// strings and bools are kept and NaN/±Inf entries become nil.
func TestSanitizeValue_Map(t *testing.T) {
	in := map[string]interface{}{
		"normal":   float64(42.5),
		"nan":      math.NaN(),
		"inf":      math.Inf(1),
		"negInf":   math.Inf(-1),
		"string":   "test",
		"bool":     true,
		"nilValue": nil,
	}

	raw := sanitizeValue(in)
	out, ok := raw.(map[string]interface{})
	if !ok {
		t.Fatalf("expected map[string]interface{}, got %T", raw)
	}

	// The finite number is preserved.
	if got := out["normal"]; got != float64(42.5) {
		t.Errorf("normal value = %v, want 42.5", got)
	}

	// Every non-finite float is nil-ed; the key order fixes the report order.
	for _, key := range []string{"nan", "inf", "negInf"} {
		if got := out[key]; got != nil {
			t.Errorf("%s value = %v, want nil", key, got)
		}
	}

	// Non-numeric values pass through.
	if got := out["string"]; got != "test" {
		t.Errorf("string value = %v, want 'test'", got)
	}
	if got := out["bool"]; got != true {
		t.Errorf("bool value = %v, want true", got)
	}
}
// TestSanitizeValue_Slice checks that sanitizeValue applied to a
// []interface{} keeps its length and element order. It replaces NaN/+Inf
// elements with nil and leaves other elements untouched.
func TestSanitizeValue_Slice(t *testing.T) {
	in := []interface{}{
		float64(1.0),
		math.NaN(),
		float64(2.0),
		math.Inf(1),
		"string",
	}

	raw := sanitizeValue(in)
	out, ok := raw.([]interface{})
	if !ok {
		t.Fatalf("expected []interface{}, got %T", raw)
	}
	if len(out) != 5 {
		t.Fatalf("slice length = %d, want 5", len(out))
	}

	// One expectation per input element, each with its own failure message;
	// non-finite inputs are expected to come back as nil.
	expectations := []struct {
		want   interface{}
		format string
	}{
		{float64(1.0), "resultSlice[0] = %v, want 1.0"},
		{nil, "resultSlice[1] (NaN) = %v, want nil"},
		{float64(2.0), "resultSlice[2] = %v, want 2.0"},
		{nil, "resultSlice[3] (Inf) = %v, want nil"},
		{"string", "resultSlice[4] = %v, want 'string'"},
	}
	for i, e := range expectations {
		if out[i] != e.want {
			t.Errorf(e.format, out[i])
		}
	}
}
// TestSanitizeValue_NestedMap checks that sanitizeValue recurses through
// nested map[string]interface{} values. A NaN two levels deep becomes nil and
// a finite sibling value is preserved.
func TestSanitizeValue_NestedMap(t *testing.T) {
	input := map[string]interface{}{
		"outer": map[string]interface{}{
			"inner": map[string]interface{}{
				"nan": math.NaN(),
				"ok":  float64(42),
			},
		},
	}
	result := sanitizeValue(input)

	// Checked (comma-ok) assertions: an unexpected shape must fail this test
	// with a message, not panic. A panic would abort the whole test binary.
	resultMap, ok := result.(map[string]interface{})
	if !ok {
		t.Fatalf("expected map[string]interface{}, got %T", result)
	}
	outer, ok := resultMap["outer"].(map[string]interface{})
	if !ok {
		t.Fatalf("outer: expected map[string]interface{}, got %T", resultMap["outer"])
	}
	inner, ok := outer["inner"].(map[string]interface{})
	if !ok {
		t.Fatalf("inner: expected map[string]interface{}, got %T", outer["inner"])
	}

	if inner["nan"] != nil {
		t.Errorf("nested nan = %v, want nil", inner["nan"])
	}
	if inner["ok"] != float64(42) {
		t.Errorf("nested ok = %v, want 42", inner["ok"])
	}
}
// TestCloneMetadata checks cloneMetadata on four kinds of input:
//   - nil input returns nil;
//   - empty input returns a non-nil empty map;
//   - the returned map is isolated from later writes to the source map;
//   - it is also isolated from writes to nested maps and []string values.
func TestCloneMetadata(t *testing.T) {
	t.Run("nil input returns nil", func(t *testing.T) {
		result := cloneMetadata(nil)
		if result != nil {
			t.Errorf("cloneMetadata(nil) = %v, want nil", result)
		}
	})

	t.Run("empty map returns empty map", func(t *testing.T) {
		input := map[string]interface{}{}
		result := cloneMetadata(input)
		if result == nil {
			t.Fatal("cloneMetadata(empty) returned nil")
		}
		if len(result) != 0 {
			t.Errorf("cloneMetadata(empty) has length %d, want 0", len(result))
		}
	})

	t.Run("clones simple values", func(t *testing.T) {
		input := map[string]interface{}{
			"string": "hello",
			"int":    42,
			"float":  3.14,
			"bool":   true,
		}
		result := cloneMetadata(input)

		// Verify values copied
		if result["string"] != "hello" {
			t.Errorf("string = %v, want 'hello'", result["string"])
		}
		if result["int"] != 42 {
			t.Errorf("int = %v, want 42", result["int"])
		}
		if result["float"] != 3.14 {
			t.Errorf("float = %v, want 3.14", result["float"])
		}
		if result["bool"] != true {
			t.Errorf("bool = %v, want true", result["bool"])
		}

		// Writing to the source map must not show through in the clone.
		input["string"] = "modified"
		if result["string"] == "modified" {
			t.Error("cloned map should not be affected by original modifications")
		}
	})

	t.Run("deep clones nested maps", func(t *testing.T) {
		nested := map[string]interface{}{"key": "value"}
		input := map[string]interface{}{
			"nested": nested,
		}
		result := cloneMetadata(input)

		// Modify the original nested map after cloning.
		nested["key"] = "modified"

		// Comma-ok assertion so a wrong type fails this subtest rather than
		// panicking and aborting the rest of the test binary.
		resultNested, ok := result["nested"].(map[string]interface{})
		if !ok {
			t.Fatalf("nested = %T, want map[string]interface{}", result["nested"])
		}
		if resultNested["key"] != "value" {
			t.Errorf("nested key = %v, want 'value' (should not be affected by original)", resultNested["key"])
		}
	})

	t.Run("deep clones slices", func(t *testing.T) {
		slice := []string{"a", "b", "c"}
		input := map[string]interface{}{
			"slice": slice,
		}
		result := cloneMetadata(input)

		// Modify the original slice after cloning.
		slice[0] = "modified"

		// Guard both the type and the length before indexing into the clone.
		resultSlice, ok := result["slice"].([]string)
		if !ok || len(resultSlice) != len(slice) {
			t.Fatalf("slice = %#v, want a []string of length %d", result["slice"], len(slice))
		}
		if resultSlice[0] != "a" {
			t.Errorf("slice[0] = %v, want 'a' (should not be affected by original)", resultSlice[0])
		}
	})
}
// TestCloneMetadataValue checks cloneMetadataValue on each shape it handles:
//   - maps and slices are copied, so later writes to the input do not show
//     through;
//   - map[string]string is converted to map[string]interface{};
//   - scalars and nil are returned as-is.
//
// Every type assertion uses the comma-ok form and every slice is
// length-checked before indexing. An unexpected result then fails the subtest
// instead of panicking, which would abort the whole test binary.
func TestCloneMetadataValue(t *testing.T) {
	t.Run("clones map[string]interface{}", func(t *testing.T) {
		input := map[string]interface{}{"key": "value"}
		result := cloneMetadataValue(input)
		resultMap, ok := result.(map[string]interface{})
		if !ok {
			t.Fatalf("cloneMetadataValue(map[string]interface{}) returned %T", result)
		}
		input["key"] = "modified"
		if resultMap["key"] != "value" {
			t.Error("cloned map should not be affected by original")
		}
	})

	t.Run("clones map[string]string", func(t *testing.T) {
		input := map[string]string{"key": "value"}
		result := cloneMetadataValue(input)
		// map[string]string gets converted to map[string]interface{}
		resultMap, ok := result.(map[string]interface{})
		if !ok {
			t.Fatalf("cloneMetadataValue(map[string]string) returned %T, want map[string]interface{}", result)
		}
		if resultMap["key"] != "value" {
			t.Errorf("key = %v, want 'value'", resultMap["key"])
		}
	})

	t.Run("clones []interface{}", func(t *testing.T) {
		input := []interface{}{"a", "b", "c"}
		result := cloneMetadataValue(input)
		resultSlice, ok := result.([]interface{})
		if !ok || len(resultSlice) != len(input) {
			t.Fatalf("cloneMetadataValue([]interface{}) = %#v, want a []interface{} of length %d", result, len(input))
		}
		input[0] = "modified"
		if resultSlice[0] != "a" {
			t.Error("cloned slice should not be affected by original")
		}
	})

	t.Run("clones []string", func(t *testing.T) {
		input := []string{"a", "b", "c"}
		result := cloneMetadataValue(input)
		resultSlice, ok := result.([]string)
		if !ok || len(resultSlice) != len(input) {
			t.Fatalf("cloneMetadataValue([]string) = %#v, want a []string of length %d", result, len(input))
		}
		input[0] = "modified"
		if resultSlice[0] != "a" {
			t.Error("cloned string slice should not be affected by original")
		}
	})

	t.Run("clones []int", func(t *testing.T) {
		input := []int{1, 2, 3}
		result := cloneMetadataValue(input)
		resultSlice, ok := result.([]int)
		if !ok || len(resultSlice) != len(input) {
			t.Fatalf("cloneMetadataValue([]int) = %#v, want a []int of length %d", result, len(input))
		}
		input[0] = 999
		if resultSlice[0] != 1 {
			t.Error("cloned int slice should not be affected by original")
		}
	})

	t.Run("clones []float64", func(t *testing.T) {
		input := []float64{1.1, 2.2, 3.3}
		result := cloneMetadataValue(input)
		resultSlice, ok := result.([]float64)
		if !ok || len(resultSlice) != len(input) {
			t.Fatalf("cloneMetadataValue([]float64) = %#v, want a []float64 of length %d", result, len(input))
		}
		input[0] = 999.9
		if resultSlice[0] != 1.1 {
			t.Error("cloned float64 slice should not be affected by original")
		}
	})

	t.Run("primitives returned as-is", func(t *testing.T) {
		// Primitives are immutable, so returning as-is is fine
		if cloneMetadataValue("string") != "string" {
			t.Error("string should pass through")
		}
		if cloneMetadataValue(42) != 42 {
			t.Error("int should pass through")
		}
		if cloneMetadataValue(3.14) != 3.14 {
			t.Error("float should pass through")
		}
		if cloneMetadataValue(true) != true {
			t.Error("bool should pass through")
		}
		if cloneMetadataValue(nil) != nil {
			t.Error("nil should pass through")
		}
	})
}
// TestCloneAlert checks cloneAlert on several inputs:
//   - nil input returns an Alert with an empty ID;
//   - scalar fields (including StartTime) are copied;
//   - the clone gets its own AckTime pointer;
//   - later writes to the original's EscalationTimes and Metadata do not show
//     through in the clone.
func TestCloneAlert(t *testing.T) {
	t.Run("nil alert returns empty alert", func(t *testing.T) {
		result := cloneAlert(nil)
		if result.ID != "" {
			t.Error("cloneAlert(nil) should return empty alert")
		}
	})

	t.Run("clones basic fields", func(t *testing.T) {
		now := time.Now()
		original := &alerts.Alert{
			ID:         "alert-123",
			Type:       "cpu",
			Level:      alerts.AlertLevelWarning,
			ResourceID: "vm/100",
			Message:    "CPU high",
			Value:      85.5,
			StartTime:  now,
		}
		result := cloneAlert(original)

		if result.ID != "alert-123" {
			t.Errorf("ID = %v, want alert-123", result.ID)
		}
		if result.Type != "cpu" {
			t.Errorf("Type = %v, want cpu", result.Type)
		}
		if result.Level != alerts.AlertLevelWarning {
			t.Errorf("Level = %v, want AlertLevelWarning", result.Level)
		}
		if result.ResourceID != "vm/100" {
			t.Errorf("ResourceID = %v, want vm/100", result.ResourceID)
		}
		if result.Message != "CPU high" {
			t.Errorf("Message = %v, want 'CPU high'", result.Message)
		}
		if result.Value != 85.5 {
			t.Errorf("Value = %v, want 85.5", result.Value)
		}
		// StartTime was populated above but previously never verified.
		if !result.StartTime.Equal(now) {
			t.Errorf("StartTime = %v, want %v", result.StartTime, now)
		}
	})

	t.Run("deep clones AckTime", func(t *testing.T) {
		ackTime := time.Now()
		original := &alerts.Alert{
			ID:      "alert-123",
			AckTime: &ackTime,
		}
		result := cloneAlert(original)

		// The clone must hold its own pointer carrying the same instant.
		if result.AckTime == nil {
			t.Fatal("AckTime should not be nil")
		}
		if result.AckTime == original.AckTime {
			t.Error("AckTime should be a different pointer")
		}
		if !result.AckTime.Equal(*original.AckTime) {
			t.Error("AckTime values should be equal")
		}
	})

	t.Run("deep clones EscalationTimes", func(t *testing.T) {
		original := &alerts.Alert{
			ID:              "alert-123",
			EscalationTimes: []time.Time{time.Now(), time.Now().Add(time.Hour)},
		}
		result := cloneAlert(original)

		if len(result.EscalationTimes) != 2 {
			t.Fatalf("EscalationTimes length = %d, want 2", len(result.EscalationTimes))
		}

		// Save the value the clone should retain, then overwrite the
		// original's element. The clone is compared with the saved value, not
		// with the mutated original, so a clone holding some other wrong
		// timestamp is also caught.
		want := original.EscalationTimes[0]
		original.EscalationTimes[0] = want.Add(24 * time.Hour)
		if !result.EscalationTimes[0].Equal(want) {
			t.Error("cloned EscalationTimes should not be affected by original modifications")
		}
	})

	t.Run("deep clones Metadata", func(t *testing.T) {
		original := &alerts.Alert{
			ID: "alert-123",
			Metadata: map[string]interface{}{
				"key": "value",
				"nested": map[string]interface{}{
					"inner": "data",
				},
			},
		}
		result := cloneAlert(original)

		// Modify original metadata
		original.Metadata["key"] = "modified"

		// Clone should be unaffected
		if result.Metadata["key"] != "value" {
			t.Error("cloned Metadata should not be affected by original modifications")
		}
	})
}
func TestCloneAlertData(t *testing.T) {
t.Run("handles *alerts.Alert", func(t *testing.T) {
original := &alerts.Alert{ID: "alert-123", Message: "test"}
result := cloneAlertData(original)
cloned, ok := result.(alerts.Alert)
if !ok {
t.Fatalf("expected alerts.Alert, got %T", result)
}
if cloned.ID != "alert-123" {
t.Errorf("ID = %v, want alert-123", cloned.ID)
}
})
t.Run("handles alerts.Alert value", func(t *testing.T) {
original := alerts.Alert{ID: "alert-456", Message: "test"}
result := cloneAlertData(original)
cloned, ok := result.(alerts.Alert)
if !ok {
t.Fatalf("expected alerts.Alert, got %T", result)
}
if cloned.ID != "alert-456" {
t.Errorf("ID = %v, want alert-456", cloned.ID)
}
})
t.Run("returns other types unchanged", func(t *testing.T) {
// Strings, maps, etc. that aren't alerts should pass through
result := cloneAlertData("not an alert")
if result != "not an alert" {
t.Errorf("non-alert data should pass through unchanged")
}
})
}
// TestNewHub checks the state of a freshly constructed Hub:
//   - its client map and channels are non-nil;
//   - the supplied state getter is stored;
//   - no allowed origins are configured;
//   - coalesceWindow is 100ms.
func TestNewHub(t *testing.T) {
	hub := NewHub(func() interface{} {
		return map[string]string{"status": "ok"}
	})
	if hub == nil {
		t.Fatal("NewHub returned nil")
	}

	// Each entry pairs an invariant with the message reported when it fails;
	// the order matches the order the failures are reported in.
	invariants := []struct {
		ok  bool
		msg string
	}{
		{hub.clients != nil, "clients map should be initialized"},
		{hub.broadcast != nil, "broadcast channel should be initialized"},
		{hub.broadcastSeq != nil, "broadcastSeq channel should be initialized"},
		{hub.register != nil, "register channel should be initialized"},
		{hub.unregister != nil, "unregister channel should be initialized"},
		{hub.stopChan != nil, "stopChan should be initialized"},
		{hub.getState != nil, "getState should be set"},
		{len(hub.allowedOrigins) == 0, "allowedOrigins should be empty by default"},
	}
	for _, inv := range invariants {
		if !inv.ok {
			t.Error(inv.msg)
		}
	}

	if got := hub.coalesceWindow; got != 100*time.Millisecond {
		t.Errorf("coalesceWindow = %v, want 100ms", got)
	}
}
// TestHub_SetAllowedOrigins checks that SetAllowedOrigins stores the given
// origins on the hub. The stored list is read under hub.mu.
func TestHub_SetAllowedOrigins(t *testing.T) {
	hub := NewHub(nil)
	origins := []string{"http://localhost:3000", "https://example.com"}
	hub.SetAllowedOrigins(origins)

	hub.mu.RLock()
	defer hub.mu.RUnlock()

	// Fatal, not Error: the per-element checks below index into the slice and
	// would panic if it were shorter than expected. The deferred RUnlock still
	// runs because FailNow exits via runtime.Goexit.
	if len(hub.allowedOrigins) != len(origins) {
		t.Fatalf("allowedOrigins length = %d, want %d", len(hub.allowedOrigins), len(origins))
	}
	for i, want := range origins {
		if hub.allowedOrigins[i] != want {
			t.Errorf("allowedOrigins[%d] = %v, want %v", i, hub.allowedOrigins[i], want)
		}
	}
}
// TestHub_SetStateGetter checks two things: a hub built with a nil getter has
// no getState, and SetStateGetter installs one. The installed value is read
// under hub.mu.
func TestHub_SetStateGetter(t *testing.T) {
	hub := NewHub(nil)
	if initiallySet := hub.getState != nil; initiallySet {
		t.Error("getState should be nil initially when passed nil")
	}

	hub.SetStateGetter(func() interface{} { return "new state" })

	hub.mu.RLock()
	defer hub.mu.RUnlock()
	if installed := hub.getState != nil; !installed {
		t.Error("getState should be set after SetStateGetter")
	}
}
// TestHub_GetClientCount checks that GetClientCount reports zero for a new hub
// and reflects entries added to the clients map.
func TestHub_GetClientCount(t *testing.T) {
	hub := NewHub(nil)
	if n := hub.GetClientCount(); n != 0 {
		t.Error("client count should be 0 initially")
	}

	// Insert clients straight into the map under the hub lock, bypassing the
	// register channel.
	hub.mu.Lock()
	for _, id := range []string{"client-1", "client-2"} {
		hub.clients[&Client{id: id}] = true
	}
	hub.mu.Unlock()

	if n := hub.GetClientCount(); n != 2 {
		t.Errorf("client count = %d, want 2", n)
	}
}
// TestMessage_Fields checks that a Message literal keeps the Type and
// Timestamp it was built with.
func TestMessage_Fields(t *testing.T) {
	const (
		wantType      = "test"
		wantTimestamp = "2024-01-01T00:00:00Z"
	)

	msg := Message{
		Type:      wantType,
		Data:      map[string]string{"key": "value"},
		Timestamp: wantTimestamp,
	}

	if msg.Type != wantType {
		t.Errorf("Type = %v, want test", msg.Type)
	}
	if msg.Timestamp != wantTimestamp {
		t.Errorf("Timestamp = %v, want 2024-01-01T00:00:00Z", msg.Timestamp)
	}
}
// TestHub_CheckOrigin drives Hub.checkOrigin through a table of requests:
//   - requests with no Origin header;
//   - same-origin requests, including ones adjusted by X-Forwarded-Proto or
//     X-Forwarded-Host;
//   - wildcard and explicit allow-lists;
//   - the private-network fallback used when no allow-list is configured.
func TestHub_CheckOrigin(t *testing.T) {
	tests := []struct {
		name           string
		origin         string
		host           string
		allowedOrigins []string
		forwardedProto string
		forwardedHost  string
		remoteAddr     string // simulated peer address for the CSWSH checks
		want           bool
	}{
		// No Origin header: allowed, since non-browser clients omit it.
		{
			name:   "no origin header",
			origin: "",
			host:   "localhost:8080",
			want:   true,
		},
		// Same-origin requests.
		{
			name:   "same origin http",
			origin: "http://localhost:8080",
			host:   "localhost:8080",
			want:   true,
		},
		{
			name:           "same origin with forwarded proto https",
			origin:         "https://example.com",
			host:           "example.com",
			forwardedProto: "https",
			want:           true,
		},
		{
			name:          "same origin with forwarded host",
			origin:        "http://proxy.example.com",
			host:          "backend:8080",
			forwardedHost: "proxy.example.com",
			want:          true,
		},
		// A "*" allow-list accepts any origin.
		{
			name:           "wildcard allows any origin",
			origin:         "https://evil.com",
			host:           "localhost:8080",
			allowedOrigins: []string{"*"},
			want:           true,
		},
		// Explicit allow-lists.
		{
			name:           "explicit allowed origin matches",
			origin:         "https://app.example.com",
			host:           "localhost:8080",
			allowedOrigins: []string{"https://app.example.com"},
			want:           true,
		},
		{
			name:           "explicit allowed origin no match",
			origin:         "https://other.example.com",
			host:           "localhost:8080",
			allowedOrigins: []string{"https://app.example.com"},
			want:           false,
		},
		{
			name:           "multiple allowed origins - match second",
			origin:         "https://second.example.com",
			host:           "localhost:8080",
			allowedOrigins: []string{"https://first.example.com", "https://second.example.com"},
			want:           true,
		},
		// Private-network fallback when no allow-list is configured. The peer
		// (remoteAddr) must itself be private for the CSWSH guard to allow it.
		{
			name:       "private IP 192.168.x.x allowed when no origins configured",
			origin:     "http://192.168.1.100:3000",
			host:       "localhost:8080",
			remoteAddr: "192.168.1.100:54321",
			want:       true,
		},
		{
			name:       "private IP 10.x.x.x allowed when no origins configured",
			origin:     "http://10.0.0.50:3000",
			host:       "localhost:8080",
			remoteAddr: "10.0.0.50:54321",
			want:       true,
		},
		{
			name:       "localhost allowed when no origins configured",
			origin:     "http://localhost:3000",
			host:       "localhost:8080",
			remoteAddr: "127.0.0.1:54321",
			want:       true,
		},
		{
			name:       "127.0.0.1 allowed when no origins configured",
			origin:     "http://127.0.0.1:3000",
			host:       "localhost:8080",
			remoteAddr: "127.0.0.1:54321",
			want:       true,
		},
		{
			name:       ".local domain allowed when no origins configured",
			origin:     "http://myserver.local:3000",
			host:       "localhost:8080",
			remoteAddr: "192.168.1.50:54321",
			want:       true,
		},
		{
			name:       ".lan domain allowed when no origins configured",
			origin:     "http://myserver.lan:3000",
			host:       "localhost:8080",
			remoteAddr: "192.168.1.50:54321",
			want:       true,
		},
		{
			name:   "public IP rejected when no origins configured",
			origin: "http://8.8.8.8:3000",
			host:   "localhost:8080",
			want:   false,
		},
		{
			name:   "public domain rejected when no origins configured",
			origin: "https://evil.example.com",
			host:   "localhost:8080",
			want:   false,
		},
		// HTTPS origin stripping.
		{
			name:       "https origin with private IP",
			origin:     "https://192.168.1.50:443",
			host:       "localhost:8080",
			remoteAddr: "192.168.1.50:54321",
			want:       true,
		},
		// ws/wss forwarded protos are normalized to http/https.
		{
			name:           "wss forwarded proto normalized to https",
			origin:         "https://example.com",
			host:           "example.com",
			forwardedProto: "wss",
			want:           true,
		},
		{
			name:           "ws forwarded proto normalized to http",
			origin:         "http://example.com",
			host:           "example.com",
			forwardedProto: "ws",
			want:           true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			hub := NewHub(nil)
			if len(tc.allowedOrigins) > 0 {
				hub.SetAllowedOrigins(tc.allowedOrigins)
			}

			req := &http.Request{
				Host:       tc.host,
				Header:     http.Header{},
				RemoteAddr: tc.remoteAddr,
			}
			// Send only the headers a case actually specifies; fields left
			// empty in the table are not set at all.
			for header, value := range map[string]string{
				"Origin":            tc.origin,
				"X-Forwarded-Proto": tc.forwardedProto,
				"X-Forwarded-Host":  tc.forwardedHost,
			} {
				if value != "" {
					req.Header.Set(header, value)
				}
			}

			if got := hub.checkOrigin(req); got != tc.want {
				t.Errorf("checkOrigin() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestHub_CheckOrigin_XForwardedScheme checks that checkOrigin accepts an
// https Origin matching the request host when the scheme arrives via
// X-Forwarded-Scheme rather than X-Forwarded-Proto.
func TestHub_CheckOrigin_XForwardedScheme(t *testing.T) {
	req := &http.Request{
		Host:   "example.com",
		Header: http.Header{},
	}
	req.Header.Set("Origin", "https://example.com")
	req.Header.Set("X-Forwarded-Scheme", "https")

	if allowed := NewHub(nil).checkOrigin(req); !allowed {
		t.Error("checkOrigin should allow same-origin with X-Forwarded-Scheme")
	}
}