diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go
index 6273be0e2..b07caa19f 100644
--- a/pkg/audit/audit_test.go
+++ b/pkg/audit/audit_test.go
@@ -80,6 +80,18 @@ func TestConsoleLogger_Close(t *testing.T) {
     }
 }
 
+func TestConsoleLogger_Webhooks(t *testing.T) {
+    logger := NewConsoleLogger()
+
+    if urls := logger.GetWebhookURLs(); len(urls) != 0 {
+        t.Fatalf("expected no webhook URLs, got %v", urls)
+    }
+
+    if err := logger.UpdateWebhookURLs([]string{"https://example.com"}); err != nil {
+        t.Fatalf("UpdateWebhookURLs returned error: %v", err)
+    }
+}
+
 func TestSetLogger_GetLogger(t *testing.T) {
     // Create a custom logger for testing
     customLogger := NewConsoleLogger()
diff --git a/pkg/audit/webhook_delivery_test.go b/pkg/audit/webhook_delivery_test.go
new file mode 100644
index 000000000..470c876d4
--- /dev/null
+++ b/pkg/audit/webhook_delivery_test.go
@@ -0,0 +1,160 @@
+package audit
+
+import (
+    "context"
+    "encoding/json"
+    "net"
+    "net/http"
+    "net/http/httptest"
+    "net/url"
+    "strings"
+    "testing"
+    "time"
+)
+
+func TestWebhookDeliveryDeliverWithRetry(t *testing.T) {
+    origResolver := resolveWebhookIPs
+    origBackoff := webhookBackoff
+    resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
+        return []net.IPAddr{{IP: net.ParseIP("8.8.8.8")}}, nil
+    }
+    webhookBackoff = []time.Duration{0, 0, 0}
+    t.Cleanup(func() {
+        resolveWebhookIPs = origResolver
+        webhookBackoff = origBackoff
+    })
+
+    var attempts int
+    event := Event{
+        ID:        "evt-1",
+        EventType: "login",
+        Timestamp: time.Unix(123, 0),
+        User:      "user",
+        IP:        "10.0.0.1",
+        Path:      "/api/login",
+        Success:   true,
+    }
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        attempts++
+        if r.Method != http.MethodPost {
+            t.Fatalf("expected POST, got %s", r.Method)
+        }
+        if ct := r.Header.Get("Content-Type"); ct != "application/json" {
+            t.Fatalf("expected application/json content-type, got %s", ct)
+        }
+        if ua := r.Header.Get("User-Agent"); ua != "Pulse-Audit-Webhook/1.0" {
+            t.Fatalf("unexpected user-agent %q", ua)
+        }
+        if r.Header.Get("X-Pulse-Event") != event.EventType {
+            t.Fatalf("expected event header %q, got %q", event.EventType, r.Header.Get("X-Pulse-Event"))
+        }
+        if r.Header.Get("X-Pulse-Event-ID") != event.ID {
+            t.Fatalf("expected event id header %q, got %q", event.ID, r.Header.Get("X-Pulse-Event-ID"))
+        }
+
+        var payload WebhookPayload
+        if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
+            t.Fatalf("failed decoding payload: %v", err)
+        }
+        if payload.Event != "audit."+event.EventType {
+            t.Fatalf("expected payload event %q, got %q", "audit."+event.EventType, payload.Event)
+        }
+        if payload.Data.ID != event.ID {
+            t.Fatalf("expected payload event id %q, got %q", event.ID, payload.Data.ID)
+        }
+
+        if attempts < 3 {
+            w.WriteHeader(http.StatusInternalServerError)
+            return
+        }
+        w.WriteHeader(http.StatusNoContent)
+    }))
+    defer server.Close()
+
+    serverURL, err := url.Parse(server.URL)
+    if err != nil {
+        t.Fatalf("failed parsing server URL: %v", err)
+    }
+    targetHost := "example.com"
+
+    transport := &http.Transport{
+        DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+            if strings.HasPrefix(addr, targetHost) {
+                return (&net.Dialer{}).DialContext(ctx, network, serverURL.Host)
+            }
+            return (&net.Dialer{}).DialContext(ctx, network, addr)
+        },
+    }
+
+    delivery := NewWebhookDelivery([]string{"http://" + targetHost + "/audit"})
+    delivery.client = &http.Client{Transport: transport}
+
+    if err := delivery.deliverWithRetry("http://"+targetHost+"/audit", event); err != nil {
+        t.Fatalf("expected delivery to succeed, got %v", err)
+    }
+    if attempts != 3 {
+        t.Fatalf("expected 3 attempts, got %d", attempts)
+    }
+}
+
+func TestWebhookDeliveryDeliverWithRetryFails(t *testing.T) {
+    origResolver := resolveWebhookIPs
+    origBackoff := webhookBackoff
+    resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
+        return []net.IPAddr{{IP: net.ParseIP("8.8.8.8")}}, nil
+    }
+    webhookBackoff = []time.Duration{0, 0, 0}
+    t.Cleanup(func() {
+        resolveWebhookIPs = origResolver
+        webhookBackoff = origBackoff
+    })
+
+    var attempts int
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        attempts++
+        w.WriteHeader(http.StatusInternalServerError)
+    }))
+    defer server.Close()
+
+    serverURL, err := url.Parse(server.URL)
+    if err != nil {
+        t.Fatalf("failed parsing server URL: %v", err)
+    }
+    targetHost := "example.com"
+
+    transport := &http.Transport{
+        DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+            if strings.HasPrefix(addr, targetHost) {
+                return (&net.Dialer{}).DialContext(ctx, network, serverURL.Host)
+            }
+            return (&net.Dialer{}).DialContext(ctx, network, addr)
+        },
+    }
+
+    delivery := NewWebhookDelivery([]string{"http://" + targetHost + "/audit"})
+    delivery.client = &http.Client{Transport: transport}
+
+    err = delivery.deliverWithRetry("http://"+targetHost+"/audit", Event{
+        ID:        "evt-2",
+        EventType: "logout",
+        Timestamp: time.Now(),
+        IP:        "10.0.0.2",
+        Success:   true,
+    })
+    if err == nil || !strings.Contains(err.Error(), "status 500") {
+        t.Fatalf("expected status error, got %v", err)
+    }
+    if attempts != webhookMaxRetries+1 {
+        t.Fatalf("expected %d attempts, got %d", webhookMaxRetries+1, attempts)
+    }
+}
+
+func TestWebhookDeliveryDeliverInvalidURL(t *testing.T) {
+    delivery := NewWebhookDelivery([]string{})
+
+    err := delivery.deliver("://bad-url", Event{ID: "evt-3", EventType: "login", Timestamp: time.Now()})
+    if err == nil || !strings.Contains(err.Error(), "webhook URL blocked") {
+        t.Fatalf("expected URL blocked error, got %v", err)
+    }
+}
diff --git a/pkg/audit/webhook_validation_test.go b/pkg/audit/webhook_validation_test.go
index d3bac6e50..36537ea9b 100644
--- a/pkg/audit/webhook_validation_test.go
+++ b/pkg/audit/webhook_validation_test.go
@@ -33,16 +33,28 @@ func TestValidateWebhookURL(t *testing.T) {
     if err := validateWebhookURL(context.Background(), "http://127.0.0.1"); err == nil {
         t.Fatalf("expected error for loopback")
     }
+    if err := validateWebhookURL(context.Background(), "http://[::1]"); err == nil {
+        t.Fatalf("expected error for ipv6 loopback")
+    }
     if err := validateWebhookURL(context.Background(), "http://192.168.1.5"); err == nil {
         t.Fatalf("expected error for private IP")
     }
     if err := validateWebhookURL(context.Background(), "http://metadata.google.internal"); err == nil {
         t.Fatalf("expected error for blocked hostname")
     }
+    if err := validateWebhookURL(context.Background(), "http://example.local"); err == nil {
+        t.Fatalf("expected error for .local hostname")
+    }
+    if err := validateWebhookURL(context.Background(), "http://internal.example.com"); err == nil {
+        t.Fatalf("expected error for internal hostname")
+    }
     if err := validateWebhookURL(context.Background(), "https://example.com"); err != nil {
         t.Fatalf("expected valid URL, got %v", err)
     }
+    if err := validateWebhookURL(nil, "https://example.com"); err != nil {
+        t.Fatalf("expected valid URL with nil context, got %v", err)
+    }
 
     resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
         return nil, context.DeadlineExceeded
@@ -51,6 +63,13 @@
         t.Fatalf("expected resolution error")
     }
 
+    resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
+        return []net.IPAddr{}, nil
+    }
+    if err := validateWebhookURL(context.Background(), "https://example.com"); err == nil {
+        t.Fatalf("expected empty resolution error")
+    }
+
     resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
         return []net.IPAddr{{IP: net.ParseIP("10.0.0.2")}}, nil
     }
@@ -65,6 +84,7 @@ func TestIsPrivateOrReservedIP(t *testing.T) {
         "10.0.0.1":    true,
         "169.254.1.1": true,
         "0.0.0.0":     true,
+        "::1":         true,
         "8.8.8.8":     false,
     }
     for ipStr, expected := range cases {
@@ -96,3 +116,16 @@ func TestWebhookDelivery_QueueAndURLs(t *testing.T) {
         t.Fatalf("expected URLs to be copied defensively")
     }
 }
+
+func TestWebhookDeliveryEnqueueDropsWhenFull(t *testing.T) {
+    delivery := &WebhookDelivery{
+        queue: make(chan Event, 1),
+    }
+
+    delivery.Enqueue(Event{ID: "first", EventType: "login", Timestamp: time.Now()})
+    delivery.Enqueue(Event{ID: "second", EventType: "login", Timestamp: time.Now()})
+
+    if delivery.QueueLength() != 1 {
+        t.Fatalf("expected queue to stay at capacity, got %d", delivery.QueueLength())
+    }
+}
diff --git a/pkg/auth/authorizer_test.go b/pkg/auth/authorizer_test.go
new file mode 100644
index 000000000..2f2c41902
--- /dev/null
+++ b/pkg/auth/authorizer_test.go
@@ -0,0 +1,118 @@
+package auth
+
+import (
+    "context"
+    "testing"
+)
+
+type testAuthorizer struct {
+    allowed bool
+    err     error
+    seen    struct {
+        action   string
+        resource string
+    }
+}
+
+func (t *testAuthorizer) Authorize(ctx context.Context, action string, resource string) (bool, error) {
+    t.seen.action = action
+    t.seen.resource = resource
+    return t.allowed, t.err
+}
+
+type testToken struct {
+    scopes map[string]bool
+}
+
+func (t testToken) HasScope(scope string) bool {
+    return t.scopes[scope]
+}
+
+type adminAuthorizer struct {
+    admin string
+}
+
+func (a *adminAuthorizer) Authorize(ctx context.Context, action string, resource string) (bool, error) {
+    return false, nil
+}
+
+func (a *adminAuthorizer) SetAdminUser(username string) {
+    a.admin = username
+}
+
+func TestContextUserHelpers(t *testing.T) {
+    ctx := WithUser(context.Background(), "alice")
+    if got := GetUser(ctx); got != "alice" {
+        t.Fatalf("expected user alice, got %q", got)
+    }
+
+    if got := GetUser(context.Background()); got != "" {
+        t.Fatalf("expected empty user, got %q", got)
+    }
+}
+
+func TestContextTokenHelpers(t *testing.T) {
+    token := testToken{scopes: map[string]bool{"read": true}}
+    ctx := WithAPIToken(context.Background(), token)
+
+    got := GetAPIToken(ctx)
+    if got == nil || !got.HasScope("read") {
+        t.Fatalf("expected token with read scope")
+    }
+
+    if GetAPIToken(context.Background()) != nil {
+        t.Fatalf("expected nil token")
+    }
+}
+
+func TestGetAPITokenContextKey(t *testing.T) {
+    key := GetAPITokenContextKey()
+    token := testToken{scopes: map[string]bool{"write": true}}
+    ctx := context.WithValue(context.Background(), key, token)
+
+    got := GetAPIToken(ctx)
+    if got == nil || !got.HasScope("write") {
+        t.Fatalf("expected token from context key")
+    }
+}
+
+func TestSetAuthorizerAndHasPermission(t *testing.T) {
+    orig := GetAuthorizer()
+    defer SetAuthorizer(orig)
+
+    custom := &testAuthorizer{allowed: true}
+    SetAuthorizer(custom)
+
+    if !HasPermission(context.Background(), "read", "nodes") {
+        t.Fatalf("expected permission to be allowed")
+    }
+    if custom.seen.action != "read" || custom.seen.resource != "nodes" {
+        t.Fatalf("expected authorizer to see read/nodes, got %q/%q", custom.seen.action, custom.seen.resource)
+    }
+}
+
+func TestSetAdminUser(t *testing.T) {
+    orig := GetAuthorizer()
+    defer SetAuthorizer(orig)
+
+    admin := &adminAuthorizer{}
+    SetAuthorizer(admin)
+
+    SetAdminUser("")
+    if admin.admin != "" {
+        t.Fatalf("expected empty admin, got %q", admin.admin)
+    }
+
+    SetAdminUser("root")
+    if admin.admin != "root" {
+        t.Fatalf("expected admin root, got %q", admin.admin)
+    }
+}
+
+func TestSetAdminUserNonConfigurable(t *testing.T) {
+    orig := GetAuthorizer()
+    defer SetAuthorizer(orig)
+
+    SetAuthorizer(&DefaultAuthorizer{})
+    SetAdminUser("root")
+}
diff --git a/pkg/auth/rbac_global_test.go b/pkg/auth/rbac_global_test.go
new file mode 100644
index 000000000..102cf4172
--- /dev/null
+++ b/pkg/auth/rbac_global_test.go
@@ -0,0 +1,64 @@
+package auth
+
+import "testing"
+
+type dummyManager struct{}
+
+func (d dummyManager) GetRoles() []Role                         { return nil }
+func (d dummyManager) GetRole(id string) (Role, bool)           { return Role{}, false }
+func (d dummyManager) SaveRole(role Role) error                 { return nil }
+func (d dummyManager) DeleteRole(id string) error               { return nil }
+func (d dummyManager) GetUserAssignments() []UserRoleAssignment { return nil }
+func (d dummyManager) GetUserAssignment(username string) (UserRoleAssignment, bool) {
+    return UserRoleAssignment{}, false
+}
+func (d dummyManager) AssignRole(username string, roleID string) error { return nil }
+func (d dummyManager) UpdateUserRoles(username string, roleIDs []string) error {
+    return nil
+}
+func (d dummyManager) RemoveRole(username string, roleID string) error { return nil }
+func (d dummyManager) GetUserPermissions(username string) []Permission { return nil }
+
+type dummyExtendedManager struct {
+    dummyManager
+}
+
+func (d dummyExtendedManager) GetRoleWithInheritance(id string) (Role, []Permission, bool) {
+    return Role{}, nil, false
+}
+func (d dummyExtendedManager) GetRolesWithInheritance(username string) []Role { return nil }
+func (d dummyExtendedManager) GetChangeLogs(limit int, offset int) []RBACChangeLog {
+    return nil
+}
+func (d dummyExtendedManager) GetChangeLogsForEntity(entityType, entityID string) []RBACChangeLog {
+    return nil
+}
+func (d dummyExtendedManager) SaveRoleWithContext(role Role, username string) error {
+    return nil
+}
+func (d dummyExtendedManager) DeleteRoleWithContext(id string, username string) error {
+    return nil
+}
+func (d dummyExtendedManager) UpdateUserRolesWithContext(username string, roleIDs []string, byUser string) error {
+    return nil
+}
+
+func TestGetExtendedManager(t *testing.T) {
+    orig := GetManager()
+    t.Cleanup(func() { SetManager(orig) })
+
+    base := &dummyManager{}
+    SetManager(base)
+    if GetManager() != base {
+        t.Fatalf("expected GetManager to return the set manager")
+    }
+    if GetExtendedManager() != nil {
+        t.Fatalf("expected nil extended manager")
+    }
+
+    extended := dummyExtendedManager{}
+    SetManager(extended)
+    if GetExtendedManager() == nil {
+        t.Fatalf("expected extended manager")
+    }
+}
diff --git a/pkg/discovery/generate_ips_test.go b/pkg/discovery/generate_ips_test.go
new file mode 100644
index 000000000..218922d45
--- /dev/null
+++ b/pkg/discovery/generate_ips_test.go
@@ -0,0 +1,66 @@
+package discovery
+
+import (
+    "net"
+    "testing"
+
+    "github.com/rcourtman/pulse-go-rewrite/pkg/discovery/envdetect"
+)
+
+func TestGenerateIPs(t *testing.T) {
+    scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
+
+    _, subnet30, err := net.ParseCIDR("192.168.0.0/30")
+    if err != nil {
+        t.Fatalf("failed to parse subnet: %v", err)
+    }
+    ips := scanner.generateIPs(subnet30)
+    if len(ips) != 2 || ips[0] != "192.168.0.1" || ips[1] != "192.168.0.2" {
+        t.Fatalf("unexpected /30 IPs: %v", ips)
+    }
+
+    _, subnet32, err := net.ParseCIDR("10.0.0.5/32")
+    if err != nil {
+        t.Fatalf("failed to parse /32 subnet: %v", err)
+    }
+    ips = scanner.generateIPs(subnet32)
+    if len(ips) != 1 || ips[0] != "10.0.0.5" {
+        t.Fatalf("unexpected /32 IPs: %v", ips)
+    }
+
+    _, subnet31, err := net.ParseCIDR("10.0.0.0/31")
+    if err != nil {
+        t.Fatalf("failed to parse /31 subnet: %v", err)
+    }
+    ips = scanner.generateIPs(subnet31)
+    if len(ips) != 2 || ips[0] != "10.0.0.0" || ips[1] != "10.0.0.1" {
+        t.Fatalf("unexpected /31 IPs: %v", ips)
+    }
+}
+
+func TestGenerateIPsRespectsLimit(t *testing.T) {
+    policy := envdetect.DefaultScanPolicy()
+    policy.MaxHostsPerScan = 2
+    scanner := &Scanner{policy: policy}
+
+    _, subnet29, err := net.ParseCIDR("192.168.1.0/29")
+    if err != nil {
+        t.Fatalf("failed to parse subnet: %v", err)
+    }
+    ips := scanner.generateIPs(subnet29)
+    if len(ips) != 2 || ips[0] != "192.168.1.1" || ips[1] != "192.168.1.2" {
+        t.Fatalf("unexpected limited IPs: %v", ips)
+    }
+}
+
+func TestGenerateIPsIPv6ReturnsNil(t *testing.T) {
+    scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
+
+    _, subnet6, err := net.ParseCIDR("2001:db8::/64")
+    if err != nil {
+        t.Fatalf("failed to parse ipv6 subnet: %v", err)
+    }
+    if ips := scanner.generateIPs(subnet6); ips != nil {
+        t.Fatalf("expected nil for ipv6 subnet, got %v", ips)
+    }
+}
diff --git a/pkg/discovery/scan_helpers_test.go b/pkg/discovery/scan_helpers_test.go
new file mode 100644
index 000000000..deda56cdc
--- /dev/null
+++ b/pkg/discovery/scan_helpers_test.go
@@ -0,0 +1,157 @@
+package discovery
+
+import (
+    "context"
+    "net"
+    "testing"
+    "time"
+
+    "github.com/rcourtman/pulse-go-rewrite/pkg/discovery/envdetect"
+)
+
+func TestCollectExtraTargets(t *testing.T) {
+    scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
+    seen := map[string]struct{}{
+        "10.0.0.2": {},
+    }
+
+    profile := &envdetect.EnvironmentProfile{
+        ExtraTargets: []net.IP{
+            net.ParseIP("10.0.0.1"),
+            net.ParseIP("10.0.0.2"),
+            net.ParseIP("2001:db8::1"),
+            nil,
+        },
+    }
+
+    targets := scanner.collectExtraTargets(profile, seen)
+    if len(targets) != 1 || targets[0] != "10.0.0.1" {
+        t.Fatalf("unexpected targets: %v", targets)
+    }
+    if _, ok := seen["10.0.0.1"]; !ok {
+        t.Fatalf("expected seen to include new target")
+    }
+}
+
+func TestCollectExtraTargetsNilProfile(t *testing.T) {
+    scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
+    seen := map[string]struct{}{}
+
+    if targets := scanner.collectExtraTargets(nil, seen); targets != nil {
+        t.Fatalf("expected nil targets for nil profile, got %v", targets)
+    }
+}
+
+func TestExpandPhaseIPs(t *testing.T) {
+    scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
+    seen := map[string]struct{}{
+        "192.168.1.1": {},
+    }
+
+    _, subnet30, err := net.ParseCIDR("192.168.1.0/30")
+    if err != nil {
+        t.Fatalf("failed to parse subnet: %v", err)
+    }
+    _, subnet6, err := net.ParseCIDR("2001:db8::/64")
+    if err != nil {
+        t.Fatalf("failed to parse ipv6 subnet: %v", err)
+    }
+
+    targets, count := scanner.expandPhaseIPs(envdetect.SubnetPhase{
+        Subnets: []net.IPNet{*subnet30, *subnet6},
+    }, seen)
+    if count != 2 {
+        t.Fatalf("expected 2 subnets counted, got %d", count)
+    }
+    if len(targets) != 1 || targets[0] != "192.168.1.2" {
+        t.Fatalf("unexpected targets: %v", targets)
+    }
+}
+
+func TestShouldSkipPhase(t *testing.T) {
+    policy := envdetect.DefaultScanPolicy()
+    policy.DialTimeout = time.Second
+    scanner := &Scanner{policy: policy}
+
+    ctxShort, cancel := context.WithDeadline(context.Background(), time.Now().Add(500*time.Millisecond))
+    defer cancel()
+
+    phaseLowConfidence := envdetect.SubnetPhase{Name: "low", Confidence: 0.2}
+    if !scanner.shouldSkipPhase(ctxShort, phaseLowConfidence) {
+        t.Fatalf("expected phase to be skipped with short deadline")
+    }
+
+    phaseHighConfidence := envdetect.SubnetPhase{Name: "high", Confidence: 0.8}
+    if scanner.shouldSkipPhase(ctxShort, phaseHighConfidence) {
+        t.Fatalf("expected high confidence phase to run")
+    }
+
+    if scanner.shouldSkipPhase(context.Background(), phaseLowConfidence) {
+        t.Fatalf("expected phase to run without deadline")
+    }
+}
+
+func TestShouldSkipPhaseDefaultBudget(t *testing.T) {
+    policy := envdetect.DefaultScanPolicy()
+    policy.DialTimeout = 0
+    scanner := &Scanner{policy: policy}
+
+    ctxShort, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
+    defer cancel()
+
+    phaseLowConfidence := envdetect.SubnetPhase{Name: "low", Confidence: 0.2}
+    if !scanner.shouldSkipPhase(ctxShort, phaseLowConfidence) {
+        t.Fatalf("expected phase to be skipped with default budget")
+    }
+}
+
+func TestBuildEnvironmentInfoCopiesData(t *testing.T) {
+    profile := &envdetect.EnvironmentProfile{
+        Type:       envdetect.DockerBridge,
+        Confidence: 0.75,
+        Warnings:   []string{"warning-one"},
+        Metadata: map[string]string{
+            "container_type": "docker",
+        },
+        Phases: []envdetect.SubnetPhase{
+            {
+                Name:       "phase-a",
+                Confidence: 0.9,
+                Subnets:    []net.IPNet{},
+            },
+        },
+    }
+
+    info := buildEnvironmentInfo(profile)
+    if info == nil {
+        t.Fatal("expected environment info, got nil")
+    }
+    if info.Type != "docker_bridge" || info.Confidence != 0.75 {
+        t.Fatalf("unexpected info: %+v", info)
+    }
+    if len(info.Warnings) != 1 || info.Warnings[0] != "warning-one" {
+        t.Fatalf("unexpected warnings: %v", info.Warnings)
+    }
+    if info.Metadata["container_type"] != "docker" {
+        t.Fatalf("unexpected metadata: %v", info.Metadata)
+    }
+    if len(info.Phases) != 1 || info.Phases[0].Name != "phase-a" {
+        t.Fatalf("unexpected phase info: %v", info.Phases)
+    }
+
+    info.Metadata["container_type"] = "mutated"
+    info.Warnings[0] = "mutated-warning"
+
+    if profile.Metadata["container_type"] != "docker" {
+        t.Fatalf("expected metadata copy, got %v", profile.Metadata)
+    }
+    if profile.Warnings[0] != "warning-one" {
+        t.Fatalf("expected warnings copy, got %v", profile.Warnings)
+    }
+}
+
+func TestBuildEnvironmentInfoNilProfile(t *testing.T) {
+    if info := buildEnvironmentInfo(nil); info != nil {
+        t.Fatalf("expected nil info for nil profile, got %+v", info)
+    }
+}
diff --git a/pkg/fsfilters/filters_test.go b/pkg/fsfilters/filters_test.go
index edebf9555..bf0f6e687 100644
--- a/pkg/fsfilters/filters_test.go
+++ b/pkg/fsfilters/filters_test.go
@@ -266,3 +266,57 @@ func TestMatchesUserExclude(t *testing.T) {
         })
     }
 }
+
+func TestMatchesDiskExclude(t *testing.T) {
+    tests := []struct {
+        name       string
+        device     string
+        mountpoint string
+        patterns   []string
+        expected   bool
+    }{
+        {"empty patterns", "/dev/sda", "/mnt/data", nil, false},
+        {"mountpoint exact match", "/dev/sdb", "/mnt/backup", []string{"/mnt/backup"}, true},
+        {"mountpoint prefix match", "/dev/sdb", "/mnt/external-drive", []string{"/mnt/ext*"}, true},
+        {"device exact match", "/dev/sda", "/mnt/data", []string{"/dev/sda"}, true},
+        {"device name match", "/dev/nvme0n1", "/mnt/fast", []string{"nvme0n1"}, true},
+        {"device contains match", "/dev/nvme1n1", "/mnt/fast", []string{"*nvme*"}, true},
+        {"no match", "/dev/sdc", "/mnt/data", []string{"/mnt/backup", "/dev/sda"}, false},
+        {"device without prefix match", "sdd", "/mnt/data", []string{"sdd"}, true},
+    }
+
+    for _, tc := range tests {
+        t.Run(tc.name, func(t *testing.T) {
+            result := MatchesDiskExclude(tc.device, tc.mountpoint, tc.patterns)
+            if result != tc.expected {
+                t.Errorf("MatchesDiskExclude(%q, %q, %v) = %t, want %t", tc.device, tc.mountpoint, tc.patterns, result, tc.expected)
+            }
+        })
+    }
+}
+
+func TestMatchesDeviceExclude(t *testing.T) {
+    tests := []struct {
+        name     string
+        device   string
+        patterns []string
+        expected bool
+    }{
+        {"empty patterns", "/dev/sda", nil, false},
+        {"exact path match", "/dev/sda", []string{"/dev/sda"}, true},
+        {"exact name match", "/dev/sda", []string{"sda"}, true},
+        {"prefix added match", "sdb", []string{"/dev/sdb"}, true},
+        {"contains match", "/dev/nvme0n1", []string{"*nvme*"}, true},
+        {"whitespace pattern", "/dev/sdc", []string{" /dev/sdc "}, true},
+        {"no match", "/dev/sdd", []string{"/dev/sde"}, false},
+    }
+
+    for _, tc := range tests {
+        t.Run(tc.name, func(t *testing.T) {
+            result := MatchesDeviceExclude(tc.device, tc.patterns)
+            if result != tc.expected {
+                t.Errorf("MatchesDeviceExclude(%q, %v) = %t, want %t", tc.device, tc.patterns, result, tc.expected)
+            }
+        })
+    }
+}
diff --git a/pkg/metrics/alert_metrics_test.go b/pkg/metrics/alert_metrics_test.go
index bd5f8b550..a6a28f99d 100644
--- a/pkg/metrics/alert_metrics_test.go
+++ b/pkg/metrics/alert_metrics_test.go
@@ -1,6 +1,50 @@
 package metrics
 
-import "testing"
+import (
+    "testing"
+    "time"
+
+    "github.com/prometheus/client_golang/prometheus/testutil"
+    "github.com/rcourtman/pulse-go-rewrite/internal/alerts"
+)
+
+func TestRecordAlertFired(t *testing.T) {
+    alert := &alerts.Alert{
+        ID:        "test-alert-1",
+        Level:     alerts.AlertLevelWarning,
+        Type:      "container_cpu",
+        StartTime: time.Now(),
+    }
+
+    // Should not panic
+    RecordAlertFired(alert)
+}
+
+func TestRecordAlertResolved(t *testing.T) {
+    now := time.Now()
+    alert := &alerts.Alert{
+        ID:        "test-alert-2",
+        Level:     alerts.AlertLevelWarning,
+        Type:      "container_memory",
+        StartTime: now.Add(-5 * time.Minute),
+        LastSeen:  now,
+    }
+
+    // Should not panic
+    RecordAlertResolved(alert)
+}
+
+func TestRecordAlertAcknowledged(t *testing.T) {
+    // Should not panic
+    RecordAlertAcknowledged()
+}
+
+func TestRecordAlertSuppressed(t *testing.T) {
+    // Should not panic with various reasons
+    RecordAlertSuppressed("quiet_hours")
+    RecordAlertSuppressed("rate_limit")
+    RecordAlertSuppressed("duplicate")
+}
 
 func TestMetricVectors_NotNil(t *testing.T) {
     // Verify that metric vectors are properly initialized
@@ -26,3 +70,52 @@ func TestMetricVectors_NotNil(t *testing.T) {
         t.Error("AlertsRateLimitedTotal should not be nil")
     }
 }
+
+func TestAlertMetricsIncrements(t *testing.T) {
+    alert := &alerts.Alert{
+        ID:        "metrics-test",
+        Level:     alerts.AlertLevelCritical,
+        Type:      "unit_test_alert",
+        StartTime: time.Now().Add(-time.Minute),
+        LastSeen:  time.Now(),
+    }
+
+    fired := AlertsFiredTotal.WithLabelValues(string(alert.Level), alert.Type)
+    active := AlertsActive.WithLabelValues(string(alert.Level), alert.Type)
+    resolved := AlertsResolvedTotal.WithLabelValues(alert.Type)
+
+    firedBefore := testutil.ToFloat64(fired)
+    activeBefore := testutil.ToFloat64(active)
+    resolvedBefore := testutil.ToFloat64(resolved)
+
+    RecordAlertFired(alert)
+
+    if testutil.ToFloat64(fired) != firedBefore+1 {
+        t.Fatalf("expected fired counter increment")
+    }
+    if testutil.ToFloat64(active) != activeBefore+1 {
+        t.Fatalf("expected active gauge increment")
+    }
+
+    RecordAlertResolved(alert)
+
+    if testutil.ToFloat64(resolved) != resolvedBefore+1 {
+        t.Fatalf("expected resolved counter increment")
+    }
+    if testutil.ToFloat64(active) != activeBefore {
+        t.Fatalf("expected active gauge to return to baseline")
+    }
+
+    ackBefore := testutil.ToFloat64(AlertsAcknowledgedTotal)
+    RecordAlertAcknowledged()
+    if testutil.ToFloat64(AlertsAcknowledgedTotal) != ackBefore+1 {
+        t.Fatalf("expected acknowledged counter increment")
+    }
+
+    suppressed := AlertsSuppressedTotal.WithLabelValues("unit_test_reason")
+    suppressedBefore := testutil.ToFloat64(suppressed)
+    RecordAlertSuppressed("unit_test_reason")
+    if testutil.ToFloat64(suppressed) != suppressedBefore+1 {
+        t.Fatalf("expected suppressed counter increment")
+    }
+}
diff --git a/pkg/metrics/store_downsample_test.go b/pkg/metrics/store_downsample_test.go
new file mode 100644
index 000000000..360a81825
--- /dev/null
+++ b/pkg/metrics/store_downsample_test.go
@@ -0,0 +1,178 @@
+package metrics
+
+import (
+    "math"
+    "testing"
+    "time"
+)
+
+func TestStoreQueryAllDownsampling(t *testing.T) {
+    dir := t.TempDir()
+    store, err := NewStore(DefaultConfig(dir))
+    if err != nil {
+        t.Fatalf("NewStore returned error: %v", err)
+    }
+    defer store.Close()
+
+    start := time.Unix(1000, 0)
+    batch := make([]bufferedMetric, 0, 20)
+    for i := 0; i < 10; i++ {
+        ts := start.Add(time.Duration(i) * time.Minute)
+        batch = append(batch,
+            bufferedMetric{resourceType: "vm", resourceID: "v1", metricType: "cpu", value: float64(i), timestamp: ts, tier: TierRaw},
+            bufferedMetric{resourceType: "vm", resourceID: "v1", metricType: "mem", value: float64(100 + i), timestamp: ts, tier: TierRaw},
+        )
+    }
+    store.writeBatch(batch)
+
+    result, err := store.QueryAll("vm", "v1", start.Add(-time.Hour), start.Add(time.Hour), 300)
+    if err != nil {
+        t.Fatalf("QueryAll downsampled failed: %v", err)
+    }
+
+    cpu := result["cpu"]
+    mem := result["mem"]
+    if len(cpu) != 3 || len(mem) != 3 {
+        t.Fatalf("expected 3 bucketed points per metric, got cpu=%d mem=%d", len(cpu), len(mem))
+    }
+
+    assertPoint := func(point MetricPoint, ts int64, value, min, max float64) {
+        t.Helper()
+        if point.Timestamp.Unix() != ts {
+            t.Fatalf("expected bucket timestamp %d, got %d", ts, point.Timestamp.Unix())
+        }
+        if math.Abs(point.Value-value) > 0.0001 {
+            t.Fatalf("expected value %v, got %v", value, point.Value)
+        }
+        if math.Abs(point.Min-min) > 0.0001 {
+            t.Fatalf("expected min %v, got %v", min, point.Min)
+        }
+        if math.Abs(point.Max-max) > 0.0001 {
+            t.Fatalf("expected max %v, got %v", max, point.Max)
+        }
+    }
+
+    assertPoint(cpu[0], 1050, 1.5, 0, 3)
+    assertPoint(cpu[1], 1350, 6, 4, 8)
+    assertPoint(cpu[2], 1650, 9, 9, 9)
+
+    assertPoint(mem[0], 1050, 101.5, 100, 103)
+    assertPoint(mem[1], 1350, 106, 104, 108)
+    assertPoint(mem[2], 1650, 109, 109, 109)
+}
+
+func TestStoreTierFallbacks(t *testing.T) {
+    dir := t.TempDir()
+    store, err := NewStore(DefaultConfig(dir))
+    if err != nil {
+        t.Fatalf("NewStore returned error: %v", err)
+    }
+    defer store.Close()
+
+    tests := []struct {
+        name     string
+        duration time.Duration
+        expected []Tier
+    }{
+        {"raw", 30 * time.Minute, []Tier{TierRaw, TierMinute, TierHourly}},
+        {"minute", 3 * time.Hour, []Tier{TierMinute, TierRaw, TierHourly}},
+        {"hourly", 2 * 24 * time.Hour, []Tier{TierHourly, TierMinute, TierRaw}},
+        {"daily", 30 * 24 * time.Hour, []Tier{TierDaily, TierHourly, TierMinute, TierRaw}},
+    }
+
+    for _, tc := range tests {
+        t.Run(tc.name, func(t *testing.T) {
+            got := store.tierFallbacks(tc.duration)
+            if len(got) != len(tc.expected) {
+                t.Fatalf("expected %d tiers, got %d (%v)", len(tc.expected), len(got), got)
+            }
+            for i := range got {
+                if got[i] != tc.expected[i] {
+                    t.Fatalf("expected %v, got %v", tc.expected, got)
+                }
+            }
+        })
+    }
+}
+
+func TestStoreMetadataHelpers(t *testing.T) {
+    dir := t.TempDir()
+    store, err := NewStore(DefaultConfig(dir))
+    if err != nil {
+        t.Fatalf("NewStore returned error: %v", err)
+    }
+    defer store.Close()
+
+    if value, ok := store.getMetaInt("missing"); ok {
+        t.Fatalf("expected missing meta to return ok=false, got %d", value)
+    }
+
+    if ts, ok := store.getMaxTimestampForTier(TierRaw); ok || ts != 0 {
+        t.Fatalf("expected no max timestamp, got %d (ok=%t)", ts, ok)
+    }
+
+    _, err = store.db.Exec(
+        `INSERT INTO metrics (resource_type, resource_id, metric_type, value, timestamp, tier) VALUES
+        ('vm','vm-1','cpu',1.0,?, 'raw'),
+        ('vm','vm-1','cpu',2.0,?, 'raw')`,
+        100, 200,
+    )
+    if err != nil {
+        t.Fatalf("insert metrics returned error: %v", err)
+    }
+
+    if ts, ok := store.getMaxTimestampForTier(TierRaw); !ok || ts != 200 {
+        t.Fatalf("expected max timestamp 200, got %d (ok=%t)", ts, ok)
+    }
+}
+
+func TestStoreQueryDownsamplingStats(t *testing.T) {
+    dir := t.TempDir()
+    store, err := NewStore(DefaultConfig(dir))
+    if err != nil {
+        t.Fatalf("NewStore returned error: %v", err)
+    }
+    defer store.Close()
+
+    start := time.Unix(1000, 0)
+    store.writeBatch([]bufferedMetric{
+        {resourceType: "vm", resourceID: "v2", metricType: "cpu", value: 10, timestamp: start, tier: TierRaw},
+        {resourceType: "vm", resourceID: "v2", metricType: "cpu", value: 30, timestamp: start.Add(20 * time.Second), tier: TierRaw},
+        {resourceType: "vm", resourceID: "v2", metricType: "cpu", value: 20, timestamp: start.Add(50 * time.Second), tier: TierRaw},
+    })
+
+    points, err := store.Query("vm", "v2", "cpu", start.Add(-time.Minute), start.Add(time.Minute), 120)
+    if err != nil {
+        t.Fatalf("Query downsampled failed: %v", err)
+    }
+    if len(points) != 1 {
+        t.Fatalf("expected 1 bucketed point, got %d", len(points))
+    }
+
+    point := points[0]
+    if point.Timestamp.Unix() != 1020 {
+        t.Fatalf("expected bucket timestamp 1020, got %d", point.Timestamp.Unix())
+    }
+    if point.Value != 20 || point.Min != 10 || point.Max != 30 {
+        t.Fatalf("unexpected stats: value=%v min=%v max=%v", point.Value, point.Min, point.Max)
+    }
+}
+
+func TestStoreFlushLockedDropsWhenChannelFull(t *testing.T) {
+    store := &Store{
+        config:  StoreConfig{WriteBufferSize: 1},
+        buffer:  []bufferedMetric{{resourceType: "vm", resourceID: "v3", metricType: "cpu", value: 1, timestamp: time.Now(), tier: TierRaw}},
+        writeCh: make(chan []bufferedMetric),
+    }
+
+    store.bufferMu.Lock()
+    store.flushLocked()
+    store.bufferMu.Unlock()
+
+    if len(store.buffer) != 0 {
+        t.Fatalf("expected buffer to be cleared, got %d", len(store.buffer))
+    }
+    if len(store.writeCh) != 0 {
+        t.Fatalf("expected write channel to remain empty, got %d", len(store.writeCh))
+    }
+}
diff --git a/pkg/pmg/client_test.go b/pkg/pmg/client_test.go
index c97b641aa..915b8e315 100644
--- a/pkg/pmg/client_test.go
+++ b/pkg/pmg/client_test.go
@@ -417,3 +417,321 @@ func TestMailEndpointsHandleNullAndStringValues(t *testing.T) {
         t.Fatalf("expected oldest age 600, got %d", queue.OldestAge.Int64())
     }
 }
+
+func TestClientTokenNameIncludesUserAndRealm(t *testing.T) {
+    t.Parallel()
+
+    var authHeader string
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        switch r.URL.Path {
+        case "/api2/json/statistics/mail":
+            authHeader = r.Header.Get("Authorization")
+            w.Header().Set("Content-Type", "application/json")
+            fmt.Fprint(w, `{"data":{"count":1}}`)
+        default:
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apiuser@custom!apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    stats, err := client.GetMailStatistics(context.Background(), "")
+    if err != nil {
+        t.Fatalf("get mail statistics failed: %v", err)
+    }
+    if stats == nil || stats.Count.Float64() != 1 {
+        t.Fatalf("expected statistics count 1, got %+v", stats)
+    }
+
+    expected := "PMGAPIToken=apiuser@custom!apitoken:secret"
+    if authHeader != expected {
+        t.Fatalf("expected authorization header %q, got %q", expected, authHeader)
+    }
+}
+
+func TestClientRequestAuthError(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        switch r.URL.Path {
+        case "/api2/json/statistics/mail":
+            w.WriteHeader(http.StatusUnauthorized)
+            fmt.Fprint(w, "unauthorized")
+        default:
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    _, err = client.GetMailStatistics(context.Background(), "")
+    if err == nil || !strings.Contains(err.Error(), "authentication error") {
+        t.Fatalf("expected authentication error, got %v", err)
+    }
+}
+
+func TestClientRequestNonAuthError(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        switch r.URL.Path {
+        case "/api2/json/statistics/mail":
+            w.WriteHeader(http.StatusInternalServerError)
+            fmt.Fprint(w, "boom")
+        default:
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    _, err = client.GetMailStatistics(context.Background(), "")
+    if err == nil {
+        t.Fatal("expected error")
+    }
+    msg := err.Error()
+    if !strings.Contains(msg, "API error 500") {
+        t.Fatalf("expected API error 500, got %q", msg)
+    }
+    if strings.Contains(strings.ToLower(msg), "authentication error") {
+        t.Fatalf("did not expect authentication error, got %q", msg)
+    }
+}
+
+func TestClientGetVersionInvalidJSON(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        switch r.URL.Path {
+        case "/api2/json/version":
+            w.Header().Set("Content-Type", "application/json")
+            fmt.Fprint(w, `{"data":`)
+        default:
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    if _, err := client.GetVersion(context.Background()); err == nil || !strings.Contains(err.Error(), "failed to decode response") {
+        t.Fatalf("expected decode error, got %v", err)
+    }
+}
+
+func TestClientMailCountTimespanParam(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path != "/api2/json/statistics/mailcount" {
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+        if got := r.URL.Query().Get("timespan"); got != "3600" {
+            t.Fatalf("expected timespan=3600, got %q", got)
+        }
+        w.Header().Set("Content-Type", "application/json")
+        fmt.Fprint(w, `{"data":[]}`)
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    if _, err := client.GetMailCount(context.Background(), 3600); err != nil {
+        t.Fatalf("GetMailCount failed: %v", err)
+    }
+}
+
+func TestClientClusterStatusListSingle(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path != "/api2/json/config/cluster/status" {
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+        if got := r.URL.Query().Get("list_single_node"); got != "1" {
+            t.Fatalf("expected list_single_node=1, got %q", got)
+        }
+        w.Header().Set("Content-Type", "application/json")
+        fmt.Fprint(w, `{"data":[]}`)
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    if _, err := client.GetClusterStatus(context.Background(), true); err != nil {
+        t.Fatalf("GetClusterStatus failed: %v", err)
+    }
+}
+
+func TestClientListBackupsEscapesNode(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path != "/api2/json/nodes/node/1/backup" {
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+        if r.URL.EscapedPath() != "/api2/json/nodes/node%2F1/backup" {
+            t.Fatalf("expected escaped path, got %s", r.URL.EscapedPath())
+        }
+        w.Header().Set("Content-Type", "application/json")
+        fmt.Fprint(w, `{"data":[]}`)
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    if _, err := client.ListBackups(context.Background(), "node/1"); err != nil {
+        t.Fatalf("ListBackups failed: %v", err)
+    }
+}
+
+func TestClientGetSpamScores(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path != "/api2/json/statistics/spamscores" {
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+        w.Header().Set("Content-Type", "application/json")
+        fmt.Fprint(w, `{"data":[{"level":"high","count":"2","ratio":"0.5"}]}`)
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    scores, err := client.GetSpamScores(context.Background())
+    if err != nil {
+        t.Fatalf("GetSpamScores failed: %v", err)
+    }
+    if len(scores) != 1 || scores[0].Level != "high" || scores[0].Count.Int() != 2 {
+        t.Fatalf("unexpected spam scores: %+v", scores)
+    }
+}
+
+func TestClientMailCountNoTimespanParam(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path != "/api2/json/statistics/mailcount" {
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+        if got := r.URL.Query().Get("timespan"); got != "" {
+            t.Fatalf("expected no timespan param, got %q", got)
+        }
+        w.Header().Set("Content-Type", "application/json")
+        fmt.Fprint(w, `{"data":[]}`)
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    if _, err := client.GetMailCount(context.Background(), 0); err != nil {
+        t.Fatalf("GetMailCount failed: %v", err)
+    }
+}
+
+func TestClientClusterStatusNoParam(t *testing.T) {
+    t.Parallel()
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path != "/api2/json/config/cluster/status" {
+            t.Fatalf("unexpected request path: %s", r.URL.Path)
+        }
+        if len(r.URL.Query()) != 0 {
+            t.Fatalf("expected no query params, got %v", r.URL.Query())
+        }
+        w.Header().Set("Content-Type", "application/json")
+        fmt.Fprint(w, `{"data":[]}`)
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "apitoken",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("unexpected error creating client: %v", err)
+    }
+
+    if _, err := client.GetClusterStatus(context.Background(), false); err != nil {
+        t.Fatalf("GetClusterStatus failed: %v", err)
+    }
+}
diff --git a/pkg/proxmox/client_request_test.go b/pkg/proxmox/client_request_test.go
index 308d10d28..f9abf2649 100644
--- a/pkg/proxmox/client_request_test.go
+++ b/pkg/proxmox/client_request_test.go
@@ -118,3 +118,33 @@ func TestClientRequest_401Unauthorized(t *testing.T) {
         t.Fatalf("unexpected error: %v", err)
     }
 }
+
+func TestClientRequest_500NonAuth(t *testing.T) {
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        w.WriteHeader(http.StatusInternalServerError)
+        w.Write([]byte("boom"))
+    }))
+    defer server.Close()
+
+    client, err := NewClient(ClientConfig{
+        Host:       server.URL,
+        TokenName:  "user@pve!token",
+        TokenValue: "secret",
+        VerifySSL:  false,
+    })
+    if err != nil {
+        t.Fatalf("NewClient failed: %v", err)
+    }
+
+    _, err = client.get(context.Background(), "/nodes")
+    if err == nil {
+        t.Fatal("expected error")
+    }
+    msg := err.Error()
+    if !strings.Contains(msg, "API error 500") {
+        t.Fatalf("expected api error 500, got %q", msg)
+    }
+    if strings.Contains(strings.ToLower(msg), "authentication error") {
+        t.Fatalf("did not expect authentication error for 500, got %q", msg)
+    }
+}
diff --git a/pkg/tlsutil/extra_test.go b/pkg/tlsutil/extra_test.go
index 4aaa531d9..abd5906e8 100644
--- a/pkg/tlsutil/extra_test.go
+++ b/pkg/tlsutil/extra_test.go
@@ -7,6 +7,7 @@ import (
     "net"
     "net/http"
     "net/http/httptest"
+    "strings"
     "testing"
     "time"
 )
@@ -65,3 +66,17 @@ func TestFetchFingerprint(t *testing.T) {
         t.Fatalf("unexpected fingerprint: %s", fingerprint)
     }
 }
+
+func TestFetchFingerprintInvalidURL(t *testing.T) {
+    _, err := FetchFingerprint("http://[::1")
+    if err == nil || !strings.Contains(err.Error(), "failed to parse host URL") {
+        t.Fatalf("expected parse error, got %v", err)
+    }
+}
+
+func TestFetchFingerprintConnectionError(t *testing.T) {
+    _, err := FetchFingerprint("https://127.0.0.1:1")
+    if err == nil || !strings.Contains(err.Error(), "failed to connect") {
+        t.Fatalf("expected connection error, got %v", err)
+    }
+}