Files
Pulse/internal/api/rate_limit_config_test.go
rcourtman 7599915b8f refactor(api): remove sensor proxy config from API handlers
- config_handlers.go: remove proxy configuration endpoints
- system_settings.go: remove proxy-related settings
- rate_limit_config.go: update rate limit configuration
- Update related test files
2026-01-21 12:02:46 +00:00

527 lines
15 KiB
Go

package api
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
)
// TestGetRateLimiterForEndpoint checks that GetRateLimiterForEndpoint routes
// each path/method pair to the expected limiter in globalRateLimitConfig.
func TestGetRateLimiterForEndpoint(t *testing.T) {
	// The limiters must exist before the field addresses below are taken.
	InitializeRateLimiters()

	// limiterRef pairs the address of a globalRateLimitConfig field with the
	// name printed in failure messages. The address (not the value) is stored
	// so the expected limiter is read when the subtest runs.
	type limiterRef struct {
		field **RateLimiter
		label string
	}
	type endpointCase struct {
		name   string
		path   string
		method string
		want   limiterRef
	}

	auth := limiterRef{&globalRateLimitConfig.AuthEndpoints, "AuthEndpoints"}
	recovery := limiterRef{&globalRateLimitConfig.RecoveryEndpoints, "RecoveryEndpoints"}
	export := limiterRef{&globalRateLimitConfig.ExportEndpoints, "ExportEndpoints"}
	config := limiterRef{&globalRateLimitConfig.ConfigEndpoints, "ConfigEndpoints"}
	public := limiterRef{&globalRateLimitConfig.PublicEndpoints, "PublicEndpoints"}
	update := limiterRef{&globalRateLimitConfig.UpdateEndpoints, "UpdateEndpoints"}
	websocket := limiterRef{&globalRateLimitConfig.WebSocketEndpoints, "WebSocketEndpoints"}
	general := limiterRef{&globalRateLimitConfig.GeneralAPI, "GeneralAPI"}

	cases := []endpointCase{
		// Authentication endpoints
		{"login endpoint POST", "/api/login", http.MethodPost, auth},
		{"login endpoint GET", "/api/login", http.MethodGet, auth},
		{"logout endpoint", "/api/logout", http.MethodPost, auth},
		{"change password endpoint", "/api/security/change-password", http.MethodPost, auth},
		{"generic auth endpoint", "/api/auth/something", http.MethodPost, auth},
		{"login with uppercase path still matches (lowercased)", "/API/LOGIN", http.MethodPost, auth},

		// Recovery endpoints
		{"recovery endpoint POST", "/api/security/recovery", http.MethodPost, recovery},
		{"recovery endpoint GET", "/api/security/recovery/status", http.MethodGet, recovery},

		// Export/Import endpoints
		{"config export endpoint", "/api/config/export", http.MethodGet, export},
		{"config import endpoint", "/api/config/import", http.MethodPost, export},

		// Configuration endpoints (write operations)
		{"config nodes POST", "/api/config/nodes", http.MethodPost, config},
		{"config nodes PUT", "/api/config/nodes/123", http.MethodPut, config},
		{"config nodes DELETE", "/api/config/nodes/123", http.MethodDelete, config},
		{"config system POST", "/api/config/system", http.MethodPost, config},
		{"config webhooks POST", "/api/config/webhooks", http.MethodPost, config},
		{"config alerts POST", "/api/config/alerts", http.MethodPost, config},

		// Configuration endpoints (read operations - higher limits)
		{"config nodes GET", "/api/config/nodes", http.MethodGet, public},
		{"config system GET", "/api/config/system", http.MethodGet, public},
		{"discover GET", "/api/discover", http.MethodGet, public},
		{"security status GET", "/api/security/status", http.MethodGet, public},

		// Update endpoints
		{"updates check GET", "/api/updates", http.MethodGet, update},
		{"updates status GET", "/api/updates/status", http.MethodGet, update},

		// WebSocket endpoints
		{"ws endpoint", "/ws", http.MethodGet, websocket},
		{"ws subpath", "/ws/events", http.MethodGet, websocket},

		// Public endpoints
		{"health endpoint", "/api/health", http.MethodGet, public},
		{"version endpoint", "/api/version", http.MethodGet, public},
		{"validate bootstrap token", "/api/security/validate-bootstrap-token", http.MethodPost, public},
		{"metrics endpoint", "/metrics", http.MethodGet, public},

		// General API (default)
		{"guests endpoint", "/api/guests", http.MethodGet, general},
		{"nodes endpoint", "/api/nodes", http.MethodGet, general},
		{"unknown api endpoint", "/api/unknown/endpoint", http.MethodGet, general},
		{"containers endpoint POST", "/api/containers/start", http.MethodPost, general},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := GetRateLimiterForEndpoint(tc.path, tc.method)
			// Compare limiter identity (pointer equality), not contents.
			if got == *tc.want.field {
				return
			}
			t.Errorf("GetRateLimiterForEndpoint(%q, %q) returned %s limiter, want %s",
				tc.path, tc.method, identifyLimiter(got), tc.want.label)
		})
	}
}
// identifyLimiter maps a limiter pointer to the name of the
// globalRateLimitConfig field that holds it, for readable test output.
// It returns "nil" for a nil limiter, "unknown (config nil)" when the global
// config is unset, and "unknown" when no field matches.
func identifyLimiter(rl *RateLimiter) string {
	switch {
	case rl == nil:
		return "nil"
	case globalRateLimitConfig == nil:
		return "unknown (config nil)"
	}

	cfg := globalRateLimitConfig
	// Checked in order; the first field holding the same pointer wins.
	known := []struct {
		limiter *RateLimiter
		name    string
	}{
		{cfg.AuthEndpoints, "AuthEndpoints"},
		{cfg.ConfigEndpoints, "ConfigEndpoints"},
		{cfg.ExportEndpoints, "ExportEndpoints"},
		{cfg.RecoveryEndpoints, "RecoveryEndpoints"},
		{cfg.UpdateEndpoints, "UpdateEndpoints"},
		{cfg.WebSocketEndpoints, "WebSocketEndpoints"},
		{cfg.GeneralAPI, "GeneralAPI"},
		{cfg.PublicEndpoints, "PublicEndpoints"},
	}
	for _, k := range known {
		if rl == k.limiter {
			return k.name
		}
	}
	return "unknown"
}
// TestGetRateLimiterForEndpoint_PriorityOrder checks that paths which could
// plausibly fall under a broader category are assigned the more specific
// limiter.
func TestGetRateLimiterForEndpoint_PriorityOrder(t *testing.T) {
	InitializeRateLimiters()

	type priorityCase struct {
		name   string
		path   string
		method string
		want   string // limiter name as reported by identifyLimiter
		why    string // rationale appended to the failure message
	}

	cases := []priorityCase{
		{
			name:   "recovery takes priority over security status",
			path:   "/api/security/recovery",
			method: http.MethodPost,
			want:   "RecoveryEndpoints",
			why:    "recovery paths should match before generic security paths",
		},
		{
			name:   "export takes priority over config read",
			path:   "/api/config/export",
			method: http.MethodGet,
			want:   "ExportEndpoints",
			why:    "export paths should match before generic config read paths",
		},
		{
			name:   "auth path takes priority over other paths",
			path:   "/api/auth/callback",
			method: http.MethodGet,
			want:   "AuthEndpoints",
			why:    "auth paths should be rate limited strictly",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			name := identifyLimiter(GetRateLimiterForEndpoint(tc.path, tc.method))
			if name == tc.want {
				return
			}
			t.Errorf("GetRateLimiterForEndpoint(%q, %q) = %s, want %s; %s",
				tc.path, tc.method, name, tc.want, tc.why)
		})
	}
}
// TestGetRateLimiterForEndpoint_InitializesIfNeeded checks that looking up a
// limiter while globalRateLimitConfig is nil yields a non-nil limiter and
// leaves the global config populated.
func TestGetRateLimiterForEndpoint_InitializesIfNeeded(t *testing.T) {
	// Clear the global config for this test and put it back afterwards.
	original := globalRateLimitConfig
	t.Cleanup(func() {
		globalRateLimitConfig = original
	})
	globalRateLimitConfig = nil

	if limiter := GetRateLimiterForEndpoint("/api/login", http.MethodPost); limiter == nil {
		t.Fatal("GetRateLimiterForEndpoint returned nil after initialization")
	}
	if globalRateLimitConfig == nil {
		t.Fatal("globalRateLimitConfig should have been initialized")
	}
}
// TestUniversalRateLimitMiddleware_HeaderFormat drives the middleware past the
// /api/login limit for a single client IP and checks that the rejected
// response carries X-RateLimit-Limit as the decimal form of the limiter's
// limit.
func TestUniversalRateLimitMiddleware_HeaderFormat(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	middleware := UniversalRateLimitMiddleware(handler)

	limiter := GetRateLimiterForEndpoint("/api/login", http.MethodPost)
	if limiter == nil {
		// Fail cleanly instead of panicking on limiter.limit below.
		t.Fatal("GetRateLimiterForEndpoint returned nil for /api/login")
	}

	testIP := "203.0.113.55"
	// Drop any attempts already recorded for this IP so the "under limit"
	// loop starts from zero, and clean up again once the test is done.
	ResetRateLimitForIP(testIP)
	t.Cleanup(func() {
		ResetRateLimitForIP(testIP)
	})

	makeRequest := func() *http.Request {
		req := httptest.NewRequest(http.MethodPost, "/api/login", nil)
		req.RemoteAddr = testIP + ":12345"
		return req
	}

	// Make requests up to the limit - all should succeed
	for i := 0; i < limiter.limit; i++ {
		rr := httptest.NewRecorder()
		middleware.ServeHTTP(rr, makeRequest())
		if rr.Code != http.StatusOK {
			t.Fatalf("expected OK while under limit, got %d on request %d", rr.Code, i+1)
		}
	}

	// Next request should trigger rate limit
	rr := httptest.NewRecorder()
	middleware.ServeHTTP(rr, makeRequest())
	if rr.Code != http.StatusTooManyRequests {
		t.Fatalf("expected rate-limit status, got %d", rr.Code)
	}

	// Read the headers as a client would see them, and close the body of the
	// generated response.
	res := rr.Result()
	defer res.Body.Close()

	// Verify X-RateLimit-Limit header is a proper decimal string
	limitHeader := res.Header.Get("X-RateLimit-Limit")
	expected := strconv.Itoa(limiter.limit)
	if limitHeader != expected {
		t.Fatalf("expected X-RateLimit-Limit %q, got %q", expected, limitHeader)
	}

	// Verify the header parses as a valid integer (not a control character)
	if _, err := strconv.Atoi(limitHeader); err != nil {
		t.Fatalf("header %q should parse as decimal: %v", limitHeader, err)
	}
}
// TestUniversalRateLimitMiddleware_InitializesIfNeeded checks that building
// the middleware while globalRateLimitConfig is nil populates the config, and
// that the resulting middleware still serves a request.
func TestUniversalRateLimitMiddleware_InitializesIfNeeded(t *testing.T) {
	// Clear the global config for this test and put it back afterwards.
	original := globalRateLimitConfig
	t.Cleanup(func() {
		globalRateLimitConfig = original
	})
	globalRateLimitConfig = nil

	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// Constructing the middleware is expected to set up the config.
	wrapped := UniversalRateLimitMiddleware(okHandler)
	if globalRateLimitConfig == nil {
		t.Fatal("globalRateLimitConfig should have been initialized by middleware creation")
	}

	// A single request through the freshly built middleware should pass.
	recorder := httptest.NewRecorder()
	wrapped.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/api/health", nil))
	if status := recorder.Code; status != http.StatusOK {
		t.Fatalf("expected status 200, got %d", status)
	}
}
// TestUniversalRateLimitMiddleware_StaticAssetBypass sends one request per
// static asset path through the middleware and checks that the wrapped handler
// ran and the response status is 200.
func TestUniversalRateLimitMiddleware_StaticAssetBypass(t *testing.T) {
	InitializeRateLimiters()

	invoked := false
	wrapped := UniversalRateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		invoked = true
		w.WriteHeader(http.StatusOK)
	}))

	// Static asset paths (not /api or /ws prefixed) should bypass rate limiting
	assetPaths := []string{
		"/index.html",
		"/static/js/app.js",
		"/favicon.ico",
		"/assets/style.css",
	}

	for _, assetPath := range assetPaths {
		t.Run(assetPath, func(t *testing.T) {
			// Reset the flag so each subtest observes only its own request.
			invoked = false

			recorder := httptest.NewRecorder()
			wrapped.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, assetPath, nil))

			if !invoked {
				t.Errorf("handler should have been called for static asset path %s", assetPath)
			}
			if recorder.Code != http.StatusOK {
				t.Errorf("expected status 200 for static asset, got %d", recorder.Code)
			}
		})
	}
}
// TestResetRateLimitForIP covers two behaviors of ResetRateLimitForIP: it must
// not panic when globalRateLimitConfig is nil, and it must drop the attempts
// recorded for the given IP on the auth limiter.
func TestResetRateLimitForIP(t *testing.T) {
	t.Run("nil globalRateLimitConfig does not panic", func(t *testing.T) {
		// Clear the global config for this subtest and restore it afterwards.
		savedConfig := globalRateLimitConfig
		globalRateLimitConfig = nil
		t.Cleanup(func() {
			globalRateLimitConfig = savedConfig
		})

		// This should not panic
		ResetRateLimitForIP("192.168.1.1")
	})

	t.Run("resets rate limit for specific IP", func(t *testing.T) {
		InitializeRateLimiters()
		testIP := "10.0.0.50"
		limiter := globalRateLimitConfig.AuthEndpoints

		// Make sure attempts recorded here do not leak into later tests, even
		// if this subtest fails before reaching the explicit reset below.
		t.Cleanup(func() {
			ResetRateLimitForIP(testIP)
		})

		// hasAttempts reports whether the auth limiter currently tracks
		// testIP. The lookup is done under the limiter's mutex.
		hasAttempts := func() bool {
			limiter.mu.Lock()
			defer limiter.mu.Unlock()
			_, exists := limiter.attempts[testIP]
			return exists
		}

		// Make some requests to add this IP to the limiter
		for i := 0; i < 3; i++ {
			limiter.Allow(testIP)
		}

		// Verify IP has attempts recorded
		if !hasAttempts() {
			t.Fatal("expected IP to have attempts recorded before reset")
		}

		// Reset this specific IP
		ResetRateLimitForIP(testIP)

		// Verify the IP is no longer tracked by the auth limiter (the only
		// limiter this test inspects).
		if hasAttempts() {
			t.Fatal("expected IP to be cleared after reset")
		}
	})
}