Merge branch 'main' into pulse/rc-00-scope-freeze

This commit is contained in:
rcourtman
2026-02-10 17:17:32 +00:00
13 changed files with 1299 additions and 1 deletions

View File

@@ -80,6 +80,18 @@ func TestConsoleLogger_Close(t *testing.T) {
}
}
func TestConsoleLogger_Webhooks(t *testing.T) {
logger := NewConsoleLogger()
if urls := logger.GetWebhookURLs(); len(urls) != 0 {
t.Fatalf("expected no webhook URLs, got %v", urls)
}
if err := logger.UpdateWebhookURLs([]string{"https://example.com"}); err != nil {
t.Fatalf("UpdateWebhookURLs returned error: %v", err)
}
}
func TestSetLogger_GetLogger(t *testing.T) {
// Create a custom logger for testing
customLogger := NewConsoleLogger()

View File

@@ -0,0 +1,160 @@
package audit
import (
"context"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)
func TestWebhookDeliveryDeliverWithRetry(t *testing.T) {
origResolver := resolveWebhookIPs
origBackoff := webhookBackoff
resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("8.8.8.8")}}, nil
}
webhookBackoff = []time.Duration{0, 0, 0}
t.Cleanup(func() {
resolveWebhookIPs = origResolver
webhookBackoff = origBackoff
})
var attempts int
event := Event{
ID: "evt-1",
EventType: "login",
Timestamp: time.Unix(123, 0),
User: "user",
IP: "10.0.0.1",
Path: "/api/login",
Success: true,
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
t.Fatalf("expected application/json content-type, got %s", ct)
}
if ua := r.Header.Get("User-Agent"); ua != "Pulse-Audit-Webhook/1.0" {
t.Fatalf("unexpected user-agent %q", ua)
}
if r.Header.Get("X-Pulse-Event") != event.EventType {
t.Fatalf("expected event header %q, got %q", event.EventType, r.Header.Get("X-Pulse-Event"))
}
if r.Header.Get("X-Pulse-Event-ID") != event.ID {
t.Fatalf("expected event id header %q, got %q", event.ID, r.Header.Get("X-Pulse-Event-ID"))
}
var payload WebhookPayload
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("failed decoding payload: %v", err)
}
if payload.Event != "audit."+event.EventType {
t.Fatalf("expected payload event %q, got %q", "audit."+event.EventType, payload.Event)
}
if payload.Data.ID != event.ID {
t.Fatalf("expected payload event id %q, got %q", event.ID, payload.Data.ID)
}
if attempts < 3 {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("failed parsing server URL: %v", err)
}
targetHost := "example.com"
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if strings.HasPrefix(addr, targetHost) {
return (&net.Dialer{}).DialContext(ctx, network, serverURL.Host)
}
return (&net.Dialer{}).DialContext(ctx, network, addr)
},
}
delivery := NewWebhookDelivery([]string{"http://" + targetHost + "/audit"})
delivery.client = &http.Client{Transport: transport}
if err := delivery.deliverWithRetry("http://"+targetHost+"/audit", event); err != nil {
t.Fatalf("expected delivery to succeed, got %v", err)
}
if attempts != 3 {
t.Fatalf("expected 3 attempts, got %d", attempts)
}
}
func TestWebhookDeliveryDeliverWithRetryFails(t *testing.T) {
origResolver := resolveWebhookIPs
origBackoff := webhookBackoff
resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("8.8.8.8")}}, nil
}
webhookBackoff = []time.Duration{0, 0, 0}
t.Cleanup(func() {
resolveWebhookIPs = origResolver
webhookBackoff = origBackoff
})
var attempts int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("failed parsing server URL: %v", err)
}
targetHost := "example.com"
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if strings.HasPrefix(addr, targetHost) {
return (&net.Dialer{}).DialContext(ctx, network, serverURL.Host)
}
return (&net.Dialer{}).DialContext(ctx, network, addr)
},
}
delivery := NewWebhookDelivery([]string{"http://" + targetHost + "/audit"})
delivery.client = &http.Client{Transport: transport}
err = delivery.deliverWithRetry("http://"+targetHost+"/audit", Event{
ID: "evt-2",
EventType: "logout",
Timestamp: time.Now(),
IP: "10.0.0.2",
Success: true,
})
if err == nil || !strings.Contains(err.Error(), "status 500") {
t.Fatalf("expected status error, got %v", err)
}
if attempts != webhookMaxRetries+1 {
t.Fatalf("expected %d attempts, got %d", webhookMaxRetries+1, attempts)
}
}
func TestWebhookDeliveryDeliverInvalidURL(t *testing.T) {
delivery := NewWebhookDelivery([]string{})
err := delivery.deliver("://bad-url", Event{ID: "evt-3", EventType: "login", Timestamp: time.Now()})
if err == nil || !strings.Contains(err.Error(), "webhook URL blocked") {
t.Fatalf("expected URL blocked error, got %v", err)
}
}

View File

@@ -33,16 +33,28 @@ func TestValidateWebhookURL(t *testing.T) {
if err := validateWebhookURL(context.Background(), "http://127.0.0.1"); err == nil {
t.Fatalf("expected error for loopback")
}
if err := validateWebhookURL(context.Background(), "http://[::1]"); err == nil {
t.Fatalf("expected error for ipv6 loopback")
}
if err := validateWebhookURL(context.Background(), "http://192.168.1.5"); err == nil {
t.Fatalf("expected error for private IP")
}
if err := validateWebhookURL(context.Background(), "http://metadata.google.internal"); err == nil {
t.Fatalf("expected error for blocked hostname")
}
if err := validateWebhookURL(context.Background(), "http://example.local"); err == nil {
t.Fatalf("expected error for .local hostname")
}
if err := validateWebhookURL(context.Background(), "http://internal.example.com"); err == nil {
t.Fatalf("expected error for internal hostname")
}
if err := validateWebhookURL(context.Background(), "https://example.com"); err != nil {
t.Fatalf("expected valid URL, got %v", err)
}
if err := validateWebhookURL(nil, "https://example.com"); err != nil {
t.Fatalf("expected valid URL with nil context, got %v", err)
}
resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return nil, context.DeadlineExceeded
@@ -51,6 +63,13 @@ func TestValidateWebhookURL(t *testing.T) {
t.Fatalf("expected resolution error")
}
resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{}, nil
}
if err := validateWebhookURL(context.Background(), "https://example.com"); err == nil {
t.Fatalf("expected empty resolution error")
}
resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("10.0.0.2")}}, nil
}
@@ -65,6 +84,7 @@ func TestIsPrivateOrReservedIP(t *testing.T) {
"10.0.0.1": true,
"169.254.1.1": true,
"0.0.0.0": true,
"::1": true,
"8.8.8.8": false,
}
for ipStr, expected := range cases {
@@ -96,3 +116,16 @@ func TestWebhookDelivery_QueueAndURLs(t *testing.T) {
t.Fatalf("expected URLs to be copied defensively")
}
}
func TestWebhookDeliveryEnqueueDropsWhenFull(t *testing.T) {
delivery := &WebhookDelivery{
queue: make(chan Event, 1),
}
delivery.Enqueue(Event{ID: "first", EventType: "login", Timestamp: time.Now()})
delivery.Enqueue(Event{ID: "second", EventType: "login", Timestamp: time.Now()})
if delivery.QueueLength() != 1 {
t.Fatalf("expected queue to stay at capacity, got %d", delivery.QueueLength())
}
}

118
pkg/auth/authorizer_test.go Normal file
View File

@@ -0,0 +1,118 @@
package auth
import (
"context"
"testing"
)
type testAuthorizer struct {
allowed bool
err error
seen struct {
action string
resource string
}
}
func (t *testAuthorizer) Authorize(ctx context.Context, action string, resource string) (bool, error) {
t.seen.action = action
t.seen.resource = resource
return t.allowed, t.err
}
type testToken struct {
scopes map[string]bool
}
func (t testToken) HasScope(scope string) bool {
return t.scopes[scope]
}
type adminAuthorizer struct {
admin string
}
func (a *adminAuthorizer) Authorize(ctx context.Context, action string, resource string) (bool, error) {
return false, nil
}
func (a *adminAuthorizer) SetAdminUser(username string) {
a.admin = username
}
func TestContextUserHelpers(t *testing.T) {
ctx := WithUser(context.Background(), "alice")
if got := GetUser(ctx); got != "alice" {
t.Fatalf("expected user alice, got %q", got)
}
if got := GetUser(context.Background()); got != "" {
t.Fatalf("expected empty user, got %q", got)
}
}
func TestContextTokenHelpers(t *testing.T) {
token := testToken{scopes: map[string]bool{"read": true}}
ctx := WithAPIToken(context.Background(), token)
got := GetAPIToken(ctx)
if got == nil || !got.HasScope("read") {
t.Fatalf("expected token with read scope")
}
if GetAPIToken(context.Background()) != nil {
t.Fatalf("expected nil token")
}
}
func TestGetAPITokenContextKey(t *testing.T) {
key := GetAPITokenContextKey()
token := testToken{scopes: map[string]bool{"write": true}}
ctx := context.WithValue(context.Background(), key, token)
got := GetAPIToken(ctx)
if got == nil || !got.HasScope("write") {
t.Fatalf("expected token from context key")
}
}
func TestSetAuthorizerAndHasPermission(t *testing.T) {
orig := GetAuthorizer()
defer SetAuthorizer(orig)
custom := &testAuthorizer{allowed: true}
SetAuthorizer(custom)
if !HasPermission(context.Background(), "read", "nodes") {
t.Fatalf("expected permission to be allowed")
}
if custom.seen.action != "read" || custom.seen.resource != "nodes" {
t.Fatalf("expected authorizer to see read/nodes, got %q/%q", custom.seen.action, custom.seen.resource)
}
}
func TestSetAdminUser(t *testing.T) {
orig := GetAuthorizer()
defer SetAuthorizer(orig)
admin := &adminAuthorizer{}
SetAuthorizer(admin)
SetAdminUser("")
if admin.admin != "" {
t.Fatalf("expected empty admin, got %q", admin.admin)
}
SetAdminUser("root")
if admin.admin != "root" {
t.Fatalf("expected admin root, got %q", admin.admin)
}
}
func TestSetAdminUserNonConfigurable(t *testing.T) {
orig := GetAuthorizer()
defer SetAuthorizer(orig)
SetAuthorizer(&DefaultAuthorizer{})
SetAdminUser("root")
}

View File

@@ -0,0 +1,64 @@
package auth
import "testing"
type dummyManager struct{}
func (d dummyManager) GetRoles() []Role { return nil }
func (d dummyManager) GetRole(id string) (Role, bool) { return Role{}, false }
func (d dummyManager) SaveRole(role Role) error { return nil }
func (d dummyManager) DeleteRole(id string) error { return nil }
func (d dummyManager) GetUserAssignments() []UserRoleAssignment { return nil }
func (d dummyManager) GetUserAssignment(username string) (UserRoleAssignment, bool) {
return UserRoleAssignment{}, false
}
func (d dummyManager) AssignRole(username string, roleID string) error { return nil }
func (d dummyManager) UpdateUserRoles(username string, roleIDs []string) error {
return nil
}
func (d dummyManager) RemoveRole(username string, roleID string) error { return nil }
func (d dummyManager) GetUserPermissions(username string) []Permission { return nil }
type dummyExtendedManager struct {
dummyManager
}
func (d dummyExtendedManager) GetRoleWithInheritance(id string) (Role, []Permission, bool) {
return Role{}, nil, false
}
func (d dummyExtendedManager) GetRolesWithInheritance(username string) []Role { return nil }
func (d dummyExtendedManager) GetChangeLogs(limit int, offset int) []RBACChangeLog {
return nil
}
func (d dummyExtendedManager) GetChangeLogsForEntity(entityType, entityID string) []RBACChangeLog {
return nil
}
func (d dummyExtendedManager) SaveRoleWithContext(role Role, username string) error {
return nil
}
func (d dummyExtendedManager) DeleteRoleWithContext(id string, username string) error {
return nil
}
func (d dummyExtendedManager) UpdateUserRolesWithContext(username string, roleIDs []string, byUser string) error {
return nil
}
func TestGetExtendedManager(t *testing.T) {
orig := GetManager()
t.Cleanup(func() { SetManager(orig) })
base := &dummyManager{}
SetManager(base)
if GetManager() != base {
t.Fatalf("expected GetManager to return the set manager")
}
if GetExtendedManager() != nil {
t.Fatalf("expected nil extended manager")
}
extended := dummyExtendedManager{}
SetManager(extended)
if GetExtendedManager() == nil {
t.Fatalf("expected extended manager")
}
}

View File

@@ -0,0 +1,66 @@
package discovery
import (
"net"
"testing"
"github.com/rcourtman/pulse-go-rewrite/pkg/discovery/envdetect"
)
func TestGenerateIPs(t *testing.T) {
scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
_, subnet30, err := net.ParseCIDR("192.168.0.0/30")
if err != nil {
t.Fatalf("failed to parse subnet: %v", err)
}
ips := scanner.generateIPs(subnet30)
if len(ips) != 2 || ips[0] != "192.168.0.1" || ips[1] != "192.168.0.2" {
t.Fatalf("unexpected /30 IPs: %v", ips)
}
_, subnet32, err := net.ParseCIDR("10.0.0.5/32")
if err != nil {
t.Fatalf("failed to parse /32 subnet: %v", err)
}
ips = scanner.generateIPs(subnet32)
if len(ips) != 1 || ips[0] != "10.0.0.5" {
t.Fatalf("unexpected /32 IPs: %v", ips)
}
_, subnet31, err := net.ParseCIDR("10.0.0.0/31")
if err != nil {
t.Fatalf("failed to parse /31 subnet: %v", err)
}
ips = scanner.generateIPs(subnet31)
if len(ips) != 2 || ips[0] != "10.0.0.0" || ips[1] != "10.0.0.1" {
t.Fatalf("unexpected /31 IPs: %v", ips)
}
}
func TestGenerateIPsRespectsLimit(t *testing.T) {
policy := envdetect.DefaultScanPolicy()
policy.MaxHostsPerScan = 2
scanner := &Scanner{policy: policy}
_, subnet29, err := net.ParseCIDR("192.168.1.0/29")
if err != nil {
t.Fatalf("failed to parse subnet: %v", err)
}
ips := scanner.generateIPs(subnet29)
if len(ips) != 2 || ips[0] != "192.168.1.1" || ips[1] != "192.168.1.2" {
t.Fatalf("unexpected limited IPs: %v", ips)
}
}
func TestGenerateIPsIPv6ReturnsNil(t *testing.T) {
scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
_, subnet6, err := net.ParseCIDR("2001:db8::/64")
if err != nil {
t.Fatalf("failed to parse ipv6 subnet: %v", err)
}
if ips := scanner.generateIPs(subnet6); ips != nil {
t.Fatalf("expected nil for ipv6 subnet, got %v", ips)
}
}

View File

@@ -0,0 +1,157 @@
package discovery
import (
"context"
"net"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/pkg/discovery/envdetect"
)
func TestCollectExtraTargets(t *testing.T) {
scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
seen := map[string]struct{}{
"10.0.0.2": {},
}
profile := &envdetect.EnvironmentProfile{
ExtraTargets: []net.IP{
net.ParseIP("10.0.0.1"),
net.ParseIP("10.0.0.2"),
net.ParseIP("2001:db8::1"),
nil,
},
}
targets := scanner.collectExtraTargets(profile, seen)
if len(targets) != 1 || targets[0] != "10.0.0.1" {
t.Fatalf("unexpected targets: %v", targets)
}
if _, ok := seen["10.0.0.1"]; !ok {
t.Fatalf("expected seen to include new target")
}
}
func TestCollectExtraTargetsNilProfile(t *testing.T) {
scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
seen := map[string]struct{}{}
if targets := scanner.collectExtraTargets(nil, seen); targets != nil {
t.Fatalf("expected nil targets for nil profile, got %v", targets)
}
}
func TestExpandPhaseIPs(t *testing.T) {
scanner := &Scanner{policy: envdetect.DefaultScanPolicy()}
seen := map[string]struct{}{
"192.168.1.1": {},
}
_, subnet30, err := net.ParseCIDR("192.168.1.0/30")
if err != nil {
t.Fatalf("failed to parse subnet: %v", err)
}
_, subnet6, err := net.ParseCIDR("2001:db8::/64")
if err != nil {
t.Fatalf("failed to parse ipv6 subnet: %v", err)
}
targets, count := scanner.expandPhaseIPs(envdetect.SubnetPhase{
Subnets: []net.IPNet{*subnet30, *subnet6},
}, seen)
if count != 2 {
t.Fatalf("expected 2 subnets counted, got %d", count)
}
if len(targets) != 1 || targets[0] != "192.168.1.2" {
t.Fatalf("unexpected targets: %v", targets)
}
}
func TestShouldSkipPhase(t *testing.T) {
policy := envdetect.DefaultScanPolicy()
policy.DialTimeout = time.Second
scanner := &Scanner{policy: policy}
ctxShort, cancel := context.WithDeadline(context.Background(), time.Now().Add(500*time.Millisecond))
defer cancel()
phaseLowConfidence := envdetect.SubnetPhase{Name: "low", Confidence: 0.2}
if !scanner.shouldSkipPhase(ctxShort, phaseLowConfidence) {
t.Fatalf("expected phase to be skipped with short deadline")
}
phaseHighConfidence := envdetect.SubnetPhase{Name: "high", Confidence: 0.8}
if scanner.shouldSkipPhase(ctxShort, phaseHighConfidence) {
t.Fatalf("expected high confidence phase to run")
}
if scanner.shouldSkipPhase(context.Background(), phaseLowConfidence) {
t.Fatalf("expected phase to run without deadline")
}
}
func TestShouldSkipPhaseDefaultBudget(t *testing.T) {
policy := envdetect.DefaultScanPolicy()
policy.DialTimeout = 0
scanner := &Scanner{policy: policy}
ctxShort, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
defer cancel()
phaseLowConfidence := envdetect.SubnetPhase{Name: "low", Confidence: 0.2}
if !scanner.shouldSkipPhase(ctxShort, phaseLowConfidence) {
t.Fatalf("expected phase to be skipped with default budget")
}
}
func TestBuildEnvironmentInfoCopiesData(t *testing.T) {
profile := &envdetect.EnvironmentProfile{
Type: envdetect.DockerBridge,
Confidence: 0.75,
Warnings: []string{"warning-one"},
Metadata: map[string]string{
"container_type": "docker",
},
Phases: []envdetect.SubnetPhase{
{
Name: "phase-a",
Confidence: 0.9,
Subnets: []net.IPNet{},
},
},
}
info := buildEnvironmentInfo(profile)
if info == nil {
t.Fatal("expected environment info, got nil")
}
if info.Type != "docker_bridge" || info.Confidence != 0.75 {
t.Fatalf("unexpected info: %+v", info)
}
if len(info.Warnings) != 1 || info.Warnings[0] != "warning-one" {
t.Fatalf("unexpected warnings: %v", info.Warnings)
}
if info.Metadata["container_type"] != "docker" {
t.Fatalf("unexpected metadata: %v", info.Metadata)
}
if len(info.Phases) != 1 || info.Phases[0].Name != "phase-a" {
t.Fatalf("unexpected phase info: %v", info.Phases)
}
info.Metadata["container_type"] = "mutated"
info.Warnings[0] = "mutated-warning"
if profile.Metadata["container_type"] != "docker" {
t.Fatalf("expected metadata copy, got %v", profile.Metadata)
}
if profile.Warnings[0] != "warning-one" {
t.Fatalf("expected warnings copy, got %v", profile.Warnings)
}
}
func TestBuildEnvironmentInfoNilProfile(t *testing.T) {
if info := buildEnvironmentInfo(nil); info != nil {
t.Fatalf("expected nil info for nil profile, got %+v", info)
}
}

View File

@@ -266,3 +266,57 @@ func TestMatchesUserExclude(t *testing.T) {
})
}
}
func TestMatchesDiskExclude(t *testing.T) {
tests := []struct {
name string
device string
mountpoint string
patterns []string
expected bool
}{
{"empty patterns", "/dev/sda", "/mnt/data", nil, false},
{"mountpoint exact match", "/dev/sdb", "/mnt/backup", []string{"/mnt/backup"}, true},
{"mountpoint prefix match", "/dev/sdb", "/mnt/external-drive", []string{"/mnt/ext*"}, true},
{"device exact match", "/dev/sda", "/mnt/data", []string{"/dev/sda"}, true},
{"device name match", "/dev/nvme0n1", "/mnt/fast", []string{"nvme0n1"}, true},
{"device contains match", "/dev/nvme1n1", "/mnt/fast", []string{"*nvme*"}, true},
{"no match", "/dev/sdc", "/mnt/data", []string{"/mnt/backup", "/dev/sda"}, false},
{"device without prefix match", "sdd", "/mnt/data", []string{"sdd"}, true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := MatchesDiskExclude(tc.device, tc.mountpoint, tc.patterns)
if result != tc.expected {
t.Errorf("MatchesDiskExclude(%q, %q, %v) = %t, want %t", tc.device, tc.mountpoint, tc.patterns, result, tc.expected)
}
})
}
}
func TestMatchesDeviceExclude(t *testing.T) {
tests := []struct {
name string
device string
patterns []string
expected bool
}{
{"empty patterns", "/dev/sda", nil, false},
{"exact path match", "/dev/sda", []string{"/dev/sda"}, true},
{"exact name match", "/dev/sda", []string{"sda"}, true},
{"prefix added match", "sdb", []string{"/dev/sdb"}, true},
{"contains match", "/dev/nvme0n1", []string{"*nvme*"}, true},
{"whitespace pattern", "/dev/sdc", []string{" /dev/sdc "}, true},
{"no match", "/dev/sdd", []string{"/dev/sde"}, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := MatchesDeviceExclude(tc.device, tc.patterns)
if result != tc.expected {
t.Errorf("MatchesDeviceExclude(%q, %v) = %t, want %t", tc.device, tc.patterns, result, tc.expected)
}
})
}
}

View File

@@ -1,6 +1,50 @@
package metrics
import "testing"
import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
)
func TestRecordAlertFired(t *testing.T) {
alert := &alerts.Alert{
ID: "test-alert-1",
Level: alerts.AlertLevelWarning,
Type: "container_cpu",
StartTime: time.Now(),
}
// Should not panic
RecordAlertFired(alert)
}
func TestRecordAlertResolved(t *testing.T) {
now := time.Now()
alert := &alerts.Alert{
ID: "test-alert-2",
Level: alerts.AlertLevelWarning,
Type: "container_memory",
StartTime: now.Add(-5 * time.Minute),
LastSeen: now,
}
// Should not panic
RecordAlertResolved(alert)
}
func TestRecordAlertAcknowledged(t *testing.T) {
// Should not panic
RecordAlertAcknowledged()
}
func TestRecordAlertSuppressed(t *testing.T) {
// Should not panic with various reasons
RecordAlertSuppressed("quiet_hours")
RecordAlertSuppressed("rate_limit")
RecordAlertSuppressed("duplicate")
}
func TestMetricVectors_NotNil(t *testing.T) {
// Verify that metric vectors are properly initialized
@@ -26,3 +70,52 @@ func TestMetricVectors_NotNil(t *testing.T) {
t.Error("AlertsRateLimitedTotal should not be nil")
}
}
func TestAlertMetricsIncrements(t *testing.T) {
alert := &alerts.Alert{
ID: "metrics-test",
Level: alerts.AlertLevelCritical,
Type: "unit_test_alert",
StartTime: time.Now().Add(-time.Minute),
LastSeen: time.Now(),
}
fired := AlertsFiredTotal.WithLabelValues(string(alert.Level), alert.Type)
active := AlertsActive.WithLabelValues(string(alert.Level), alert.Type)
resolved := AlertsResolvedTotal.WithLabelValues(alert.Type)
firedBefore := testutil.ToFloat64(fired)
activeBefore := testutil.ToFloat64(active)
resolvedBefore := testutil.ToFloat64(resolved)
RecordAlertFired(alert)
if testutil.ToFloat64(fired) != firedBefore+1 {
t.Fatalf("expected fired counter increment")
}
if testutil.ToFloat64(active) != activeBefore+1 {
t.Fatalf("expected active gauge increment")
}
RecordAlertResolved(alert)
if testutil.ToFloat64(resolved) != resolvedBefore+1 {
t.Fatalf("expected resolved counter increment")
}
if testutil.ToFloat64(active) != activeBefore {
t.Fatalf("expected active gauge to return to baseline")
}
ackBefore := testutil.ToFloat64(AlertsAcknowledgedTotal)
RecordAlertAcknowledged()
if testutil.ToFloat64(AlertsAcknowledgedTotal) != ackBefore+1 {
t.Fatalf("expected acknowledged counter increment")
}
suppressed := AlertsSuppressedTotal.WithLabelValues("unit_test_reason")
suppressedBefore := testutil.ToFloat64(suppressed)
RecordAlertSuppressed("unit_test_reason")
if testutil.ToFloat64(suppressed) != suppressedBefore+1 {
t.Fatalf("expected suppressed counter increment")
}
}

View File

@@ -0,0 +1,178 @@
package metrics
import (
"math"
"testing"
"time"
)
func TestStoreQueryAllDownsampling(t *testing.T) {
dir := t.TempDir()
store, err := NewStore(DefaultConfig(dir))
if err != nil {
t.Fatalf("NewStore returned error: %v", err)
}
defer store.Close()
start := time.Unix(1000, 0)
batch := make([]bufferedMetric, 0, 20)
for i := 0; i < 10; i++ {
ts := start.Add(time.Duration(i) * time.Minute)
batch = append(batch,
bufferedMetric{resourceType: "vm", resourceID: "v1", metricType: "cpu", value: float64(i), timestamp: ts, tier: TierRaw},
bufferedMetric{resourceType: "vm", resourceID: "v1", metricType: "mem", value: float64(100 + i), timestamp: ts, tier: TierRaw},
)
}
store.writeBatch(batch)
result, err := store.QueryAll("vm", "v1", start.Add(-time.Hour), start.Add(time.Hour), 300)
if err != nil {
t.Fatalf("QueryAll downsampled failed: %v", err)
}
cpu := result["cpu"]
mem := result["mem"]
if len(cpu) != 3 || len(mem) != 3 {
t.Fatalf("expected 3 bucketed points per metric, got cpu=%d mem=%d", len(cpu), len(mem))
}
assertPoint := func(point MetricPoint, ts int64, value, min, max float64) {
t.Helper()
if point.Timestamp.Unix() != ts {
t.Fatalf("expected bucket timestamp %d, got %d", ts, point.Timestamp.Unix())
}
if math.Abs(point.Value-value) > 0.0001 {
t.Fatalf("expected value %v, got %v", value, point.Value)
}
if math.Abs(point.Min-min) > 0.0001 {
t.Fatalf("expected min %v, got %v", min, point.Min)
}
if math.Abs(point.Max-max) > 0.0001 {
t.Fatalf("expected max %v, got %v", max, point.Max)
}
}
assertPoint(cpu[0], 1050, 1.5, 0, 3)
assertPoint(cpu[1], 1350, 6, 4, 8)
assertPoint(cpu[2], 1650, 9, 9, 9)
assertPoint(mem[0], 1050, 101.5, 100, 103)
assertPoint(mem[1], 1350, 106, 104, 108)
assertPoint(mem[2], 1650, 109, 109, 109)
}
func TestStoreTierFallbacks(t *testing.T) {
dir := t.TempDir()
store, err := NewStore(DefaultConfig(dir))
if err != nil {
t.Fatalf("NewStore returned error: %v", err)
}
defer store.Close()
tests := []struct {
name string
duration time.Duration
expected []Tier
}{
{"raw", 30 * time.Minute, []Tier{TierRaw, TierMinute, TierHourly}},
{"minute", 3 * time.Hour, []Tier{TierMinute, TierRaw, TierHourly}},
{"hourly", 2 * 24 * time.Hour, []Tier{TierHourly, TierMinute, TierRaw}},
{"daily", 30 * 24 * time.Hour, []Tier{TierDaily, TierHourly, TierMinute, TierRaw}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := store.tierFallbacks(tc.duration)
if len(got) != len(tc.expected) {
t.Fatalf("expected %d tiers, got %d (%v)", len(tc.expected), len(got), got)
}
for i := range got {
if got[i] != tc.expected[i] {
t.Fatalf("expected %v, got %v", tc.expected, got)
}
}
})
}
}
func TestStoreMetadataHelpers(t *testing.T) {
dir := t.TempDir()
store, err := NewStore(DefaultConfig(dir))
if err != nil {
t.Fatalf("NewStore returned error: %v", err)
}
defer store.Close()
if value, ok := store.getMetaInt("missing"); ok {
t.Fatalf("expected missing meta to return ok=false, got %d", value)
}
if ts, ok := store.getMaxTimestampForTier(TierRaw); ok || ts != 0 {
t.Fatalf("expected no max timestamp, got %d (ok=%t)", ts, ok)
}
_, err = store.db.Exec(
`INSERT INTO metrics (resource_type, resource_id, metric_type, value, timestamp, tier) VALUES
('vm','vm-1','cpu',1.0,?, 'raw'),
('vm','vm-1','cpu',2.0,?, 'raw')`,
100, 200,
)
if err != nil {
t.Fatalf("insert metrics returned error: %v", err)
}
if ts, ok := store.getMaxTimestampForTier(TierRaw); !ok || ts != 200 {
t.Fatalf("expected max timestamp 200, got %d (ok=%t)", ts, ok)
}
}
func TestStoreQueryDownsamplingStats(t *testing.T) {
dir := t.TempDir()
store, err := NewStore(DefaultConfig(dir))
if err != nil {
t.Fatalf("NewStore returned error: %v", err)
}
defer store.Close()
start := time.Unix(1000, 0)
store.writeBatch([]bufferedMetric{
{resourceType: "vm", resourceID: "v2", metricType: "cpu", value: 10, timestamp: start, tier: TierRaw},
{resourceType: "vm", resourceID: "v2", metricType: "cpu", value: 30, timestamp: start.Add(20 * time.Second), tier: TierRaw},
{resourceType: "vm", resourceID: "v2", metricType: "cpu", value: 20, timestamp: start.Add(50 * time.Second), tier: TierRaw},
})
points, err := store.Query("vm", "v2", "cpu", start.Add(-time.Minute), start.Add(time.Minute), 120)
if err != nil {
t.Fatalf("Query downsampled failed: %v", err)
}
if len(points) != 1 {
t.Fatalf("expected 1 bucketed point, got %d", len(points))
}
point := points[0]
if point.Timestamp.Unix() != 1020 {
t.Fatalf("expected bucket timestamp 1020, got %d", point.Timestamp.Unix())
}
if point.Value != 20 || point.Min != 10 || point.Max != 30 {
t.Fatalf("unexpected stats: value=%v min=%v max=%v", point.Value, point.Min, point.Max)
}
}
func TestStoreFlushLockedDropsWhenChannelFull(t *testing.T) {
store := &Store{
config: StoreConfig{WriteBufferSize: 1},
buffer: []bufferedMetric{{resourceType: "vm", resourceID: "v3", metricType: "cpu", value: 1, timestamp: time.Now(), tier: TierRaw}},
writeCh: make(chan []bufferedMetric),
}
store.bufferMu.Lock()
store.flushLocked()
store.bufferMu.Unlock()
if len(store.buffer) != 0 {
t.Fatalf("expected buffer to be cleared, got %d", len(store.buffer))
}
if len(store.writeCh) != 0 {
t.Fatalf("expected write channel to remain empty, got %d", len(store.writeCh))
}
}

View File

@@ -417,3 +417,321 @@ func TestMailEndpointsHandleNullAndStringValues(t *testing.T) {
t.Fatalf("expected oldest age 600, got %d", queue.OldestAge.Int64())
}
}
func TestClientTokenNameIncludesUserAndRealm(t *testing.T) {
t.Parallel()
var authHeader string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api2/json/statistics/mail":
authHeader = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"data":{"count":1}}`)
default:
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apiuser@custom!apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
stats, err := client.GetMailStatistics(context.Background(), "")
if err != nil {
t.Fatalf("get mail statistics failed: %v", err)
}
if stats == nil || stats.Count.Float64() != 1 {
t.Fatalf("expected statistics count 1, got %+v", stats)
}
expected := "PMGAPIToken=apiuser@custom!apitoken:secret"
if authHeader != expected {
t.Fatalf("expected authorization header %q, got %q", expected, authHeader)
}
}
func TestClientRequestAuthError(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api2/json/statistics/mail":
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(w, "unauthorized")
default:
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
_, err = client.GetMailStatistics(context.Background(), "")
if err == nil || !strings.Contains(err.Error(), "authentication error") {
t.Fatalf("expected authentication error, got %v", err)
}
}
func TestClientRequestNonAuthError(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api2/json/statistics/mail":
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, "boom")
default:
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
_, err = client.GetMailStatistics(context.Background(), "")
if err == nil {
t.Fatal("expected error")
}
msg := err.Error()
if !strings.Contains(msg, "API error 500") {
t.Fatalf("expected API error 500, got %q", msg)
}
if strings.Contains(strings.ToLower(msg), "authentication error") {
t.Fatalf("did not expect authentication error, got %q", msg)
}
}
func TestClientGetVersionInvalidJSON(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api2/json/version":
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"data":`)
default:
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
if _, err := client.GetVersion(context.Background()); err == nil || !strings.Contains(err.Error(), "failed to decode response") {
t.Fatalf("expected decode error, got %v", err)
}
}
func TestClientMailCountTimespanParam(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api2/json/statistics/mailcount" {
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
if got := r.URL.Query().Get("timespan"); got != "3600" {
t.Fatalf("expected timespan=3600, got %q", got)
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"data":[]}`)
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
if _, err := client.GetMailCount(context.Background(), 3600); err != nil {
t.Fatalf("GetMailCount failed: %v", err)
}
}
func TestClientClusterStatusListSingle(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api2/json/config/cluster/status" {
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
if got := r.URL.Query().Get("list_single_node"); got != "1" {
t.Fatalf("expected list_single_node=1, got %q", got)
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"data":[]}`)
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
if _, err := client.GetClusterStatus(context.Background(), true); err != nil {
t.Fatalf("GetClusterStatus failed: %v", err)
}
}
func TestClientListBackupsEscapesNode(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api2/json/nodes/node/1/backup" {
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
if r.URL.EscapedPath() != "/api2/json/nodes/node%2F1/backup" {
t.Fatalf("expected escaped path, got %s", r.URL.EscapedPath())
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"data":[]}`)
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
if _, err := client.ListBackups(context.Background(), "node/1"); err != nil {
t.Fatalf("ListBackups failed: %v", err)
}
}
func TestClientGetSpamScores(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api2/json/statistics/spamscores" {
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"data":[{"level":"high","count":"2","ratio":"0.5"}]}`)
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
scores, err := client.GetSpamScores(context.Background())
if err != nil {
t.Fatalf("GetSpamScores failed: %v", err)
}
if len(scores) != 1 || scores[0].Level != "high" || scores[0].Count.Int() != 2 {
t.Fatalf("unexpected spam scores: %+v", scores)
}
}
func TestClientMailCountNoTimespanParam(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api2/json/statistics/mailcount" {
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
if got := r.URL.Query().Get("timespan"); got != "" {
t.Fatalf("expected no timespan param, got %q", got)
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"data":[]}`)
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
if _, err := client.GetMailCount(context.Background(), 0); err != nil {
t.Fatalf("GetMailCount failed: %v", err)
}
}
func TestClientClusterStatusNoParam(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api2/json/config/cluster/status" {
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
if len(r.URL.Query()) != 0 {
t.Fatalf("expected no query params, got %v", r.URL.Query())
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"data":[]}`)
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "apitoken",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("unexpected error creating client: %v", err)
}
if _, err := client.GetClusterStatus(context.Background(), false); err != nil {
t.Fatalf("GetClusterStatus failed: %v", err)
}
}

View File

@@ -118,3 +118,33 @@ func TestClientRequest_401Unauthorized(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
}
func TestClientRequest_500NonAuth(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("boom"))
}))
defer server.Close()
client, err := NewClient(ClientConfig{
Host: server.URL,
TokenName: "user@pve!token",
TokenValue: "secret",
VerifySSL: false,
})
if err != nil {
t.Fatalf("NewClient failed: %v", err)
}
_, err = client.get(context.Background(), "/nodes")
if err == nil {
t.Fatal("expected error")
}
msg := err.Error()
if !strings.Contains(msg, "API error 500") {
t.Fatalf("expected api error 500, got %q", msg)
}
if strings.Contains(strings.ToLower(msg), "authentication error") {
t.Fatalf("did not expect authentication error for 500, got %q", msg)
}
}

View File

@@ -7,6 +7,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
@@ -65,3 +66,17 @@ func TestFetchFingerprint(t *testing.T) {
t.Fatalf("unexpected fingerprint: %s", fingerprint)
}
}
func TestFetchFingerprintInvalidURL(t *testing.T) {
_, err := FetchFingerprint("http://[::1")
if err == nil || !strings.Contains(err.Error(), "failed to parse host URL") {
t.Fatalf("expected parse error, got %v", err)
}
}
func TestFetchFingerprintConnectionError(t *testing.T) {
_, err := FetchFingerprint("https://127.0.0.1:1")
if err == nil || !strings.Contains(err.Error(), "failed to connect") {
t.Fatalf("expected connection error, got %v", err)
}
}