Files
Pulse/cmd/pulse-sensor-proxy/validation_test.go
rcourtman 97672b4701 Add unit tests for validation utility functions (pulse-sensor-proxy)
Add 42 test cases for security-critical validation utility functions:

- TestStripNodeDelimiters (10 cases): IPv6 bracket handling, edge cases
- TestParseNodeIP (10 cases): IPv4/IPv6 parsing with bracket support
- TestNormalizeAllowlistEntry (11 cases): case normalization, whitespace
  handling, IPv6 full form compression
- TestIPAllowed (11 cases): CIDR matching, hosts map lookup, nil handling

These functions are used for node allowlist validation to prevent SSRF
attacks in the sensor proxy.
2025-11-30 17:04:43 +00:00

521 lines
12 KiB
Go

package main
import (
"context"
"net"
"strings"
"testing"
"time"
)
// TestSanitizeCorrelationID expects a well-formed UUID to be returned
// unchanged, and expects both a malformed ID and an empty ID to be replaced
// by values that differ from each other.
func TestSanitizeCorrelationID(t *testing.T) {
	const wellFormed = "550e8400-e29b-41d4-a716-446655440000"

	if got := sanitizeCorrelationID(wellFormed); got != wellFormed {
		t.Fatalf("expected valid UUID to pass through, got %s", got)
	}

	fromMalformed := sanitizeCorrelationID("not-a-uuid")
	if fromMalformed == "not-a-uuid" {
		t.Fatal("expected invalid UUID to be replaced")
	}

	fromEmpty := sanitizeCorrelationID("")
	if fromEmpty == "" {
		t.Fatal("expected empty string to be replaced")
	}

	// Two independent replacements must not collide.
	if fromMalformed == fromEmpty {
		t.Fatal("expected regenerated UUIDs to differ")
	}
}
// TestValidateNodeName runs validateNodeName over a table of node
// identifiers: hostnames, IPv4/IPv6 literals (bare and bracketed), names at
// and just past 64 characters, and inputs carrying whitespace, path, or shell
// metacharacters. Each row states whether an error is expected.
func TestValidateNodeName(t *testing.T) {
	table := []struct {
		label  string
		input  string
		reject bool
	}{
		{label: "alphanumeric", input: "node-1"},
		{label: "dns hostname", input: "example.com"},
		{label: "ipv4", input: "1.2.3.4"},
		{label: "ipv6 compressed", input: "2001:db8::1"},
		{label: "ipv6 bracketed", input: "[2001:db8::10]"},
		{label: "ipv6 loopback", input: "::1"},
		{label: "ipv6 unspecified", input: "::"},
		{label: "ipv4-mapped ipv6 dual stack", input: "::ffff:192.0.2.1"},
		{label: "ipv6 loopback bracketed", input: "[::1]"},
		{label: "ipv6 zone identifier", input: "fe80::1%eth0", reject: true},
		{label: "ipv6 zone identifier bracketed", input: "[fe80::1%eth0]", reject: true},
		{label: "ipv6 with port suffix", input: "[2001:db8::1]:22", reject: true},
		{label: "missing closing bracket", input: "[2001:db8::1", reject: true},
		{label: "missing opening bracket", input: "2001:db8::1]", reject: true},
		{label: "whitespace disallowed", input: "bad host", reject: true},
		{label: "leading hyphen disallowed", input: "-leadinghyphen", reject: true},
		{label: "dns name with port", input: "example.com:22", reject: true},
		{label: "empty string", input: "", reject: true},
		{label: "underscore", input: "example_com"},
		{label: "uppercase", input: "NODE123"},
		{label: "64 chars", input: strings.Repeat("a", 64)},
		{label: "65 chars", input: strings.Repeat("a", 65), reject: true},
		{label: "zero-width space", input: "senso\u200Brs", reject: true},
		{label: "backslash", input: "node\\name", reject: true},
		{label: "absolute path", input: "/etc/passwd", reject: true},
		{label: "null byte", input: "node\x00", reject: true},
		{label: "semicolon", input: "example.com;rm", reject: true},
		{label: "subshell", input: "node$(rm)", reject: true},
	}

	for _, tc := range table {
		tc := tc // per-iteration copy for the subtest closure

		// Fall back to the raw input when a row carries no label.
		subtest := tc.label
		if subtest == "" {
			subtest = tc.input
		}

		t.Run(subtest, func(t *testing.T) {
			switch err := validateNodeName(tc.input); {
			case tc.reject && err == nil:
				t.Fatalf("expected error validating %q", tc.input)
			case !tc.reject && err != nil:
				t.Fatalf("unexpected error for %q: %v", tc.input, err)
			}
		})
	}
}
// stubResolver is a canned resolver for tests: it hands back a fixed set of
// addresses or a fixed error, regardless of the host being looked up.
type stubResolver struct {
	ips []net.IP // addresses returned when err is nil
	err error    // when non-nil, returned instead of ips
}

// LookupIP ignores ctx and host. It returns the configured addresses, or
// (nil, err) when an error has been configured.
func (s stubResolver) LookupIP(ctx context.Context, host string) ([]net.IP, error) {
	if s.err == nil {
		return s.ips, nil
	}
	return nil, s.err
}
// TestNodeValidatorAllowlistHost expects a validator with an explicit host
// allow-list to accept a listed name and reject an unlisted one.
func TestNodeValidatorAllowlistHost(t *testing.T) {
	ctx := context.Background()
	validator := &nodeValidator{
		hasAllowlist: true,
		allowHosts:   map[string]struct{}{"node-1": {}},
		resolver:     stubResolver{},
	}

	if err := validator.Validate(ctx, "node-1"); err != nil {
		t.Fatalf("expected node-1 to be permitted, got error: %v", err)
	}

	if validator.Validate(ctx, "node-2") == nil {
		t.Fatal("expected node-2 to be rejected without allow-list entry")
	}
}
// TestNodeValidatorAllowlistCIDRWithLookup expects a hostname to be accepted
// when the resolver returns an address that falls inside an allow-listed
// CIDR range.
func TestNodeValidatorAllowlistCIDRWithLookup(t *testing.T) {
	_, network, err := net.ParseCIDR("10.0.0.0/24")
	if err != nil {
		// Fail during setup rather than handing a nil *net.IPNet to the
		// validator and getting a misleading result (or a panic) later.
		t.Fatalf("parsing test CIDR: %v", err)
	}

	v := &nodeValidator{
		allowHosts:   make(map[string]struct{}),
		allowCIDRs:   []*net.IPNet{network},
		hasAllowlist: true,
		resolver: stubResolver{
			// 10.0.0.5 lies within 10.0.0.0/24.
			ips: []net.IP{net.ParseIP("10.0.0.5")},
		},
	}

	if err := v.Validate(context.Background(), "worker.local"); err != nil {
		t.Fatalf("expected worker.local to resolve into allowed CIDR: %v", err)
	}
}
// TestNodeValidatorClusterCaching drives the validator with a fake clock and
// a counting cluster fetcher. It expects one fetch on the first validation,
// no additional fetch 30 seconds later, and a second fetch once the clock has
// moved a further nodeValidatorCacheTTL plus one second.
func TestNodeValidatorClusterCaching(t *testing.T) {
	const member = "10.0.0.9"

	var (
		now        = time.Now()
		fetchCount int
	)
	validator := &nodeValidator{
		clusterEnabled: true,
		clusterFetcher: func() ([]string, error) {
			fetchCount++
			return []string{member}, nil
		},
		cacheTTL: nodeValidatorCacheTTL,
		// The closure reads now on every call, so reassigning it below
		// moves the validator's clock forward.
		clock: func() time.Time { return now },
	}
	ctx := context.Background()

	// First validation: membership has to be fetched.
	if err := validator.Validate(ctx, member); err != nil {
		t.Fatalf("expected node to be allowed via cluster membership: %v", err)
	}
	if fetchCount != 1 {
		t.Fatalf("expected initial cluster fetch, got %d", fetchCount)
	}

	// 30 seconds later: the test expects no new fetch.
	now = now.Add(30 * time.Second)
	if err := validator.Validate(ctx, member); err != nil {
		t.Fatalf("expected cached cluster membership to allow node: %v", err)
	}
	if fetchCount != 1 {
		t.Fatalf("expected cache hit to avoid new fetch, got %d fetches", fetchCount)
	}

	// Past the TTL: membership has to be fetched again.
	now = now.Add(nodeValidatorCacheTTL + time.Second)
	if err := validator.Validate(ctx, member); err != nil {
		t.Fatalf("expected refreshed cluster membership to allow node: %v", err)
	}
	if fetchCount != 2 {
		t.Fatalf("expected cache expiry to trigger new fetch, got %d", fetchCount)
	}
}
// TestNodeValidatorClusterResolvesHostIPs expects an IP address to be
// accepted when the cluster fetcher reports a hostname and the resolver
// returns that same IP for lookups.
func TestNodeValidatorClusterResolvesHostIPs(t *testing.T) {
	resolved := net.ParseIP("10.0.0.5")
	validator := &nodeValidator{
		clusterEnabled: true,
		clusterFetcher: func() ([]string, error) {
			return []string{"worker.local"}, nil
		},
		resolver: stubResolver{ips: []net.IP{resolved}},
	}

	err := validator.Validate(context.Background(), "10.0.0.5")
	if err != nil {
		t.Fatalf("expected cluster hostname resolution to permit node: %v", err)
	}
}
// TestNodeValidatorStrictNoSources expects a strict validator that has no
// allow-list and no cluster source configured to reject a node.
func TestNodeValidatorStrictNoSources(t *testing.T) {
	validator := &nodeValidator{strict: true}

	err := validator.Validate(context.Background(), "node-1")
	if err == nil {
		t.Fatal("expected strict mode without sources to reject nodes")
	}
}
// TestStripNodeDelimiters checks stripNodeDelimiters against bracketed and
// unbracketed inputs. The table expects a matched surrounding "[...]" pair
// with content to be removed, and unbalanced or empty brackets to be left
// untouched.
func TestStripNodeDelimiters(t *testing.T) {
	rows := []struct {
		label string
		in    string
		want  string
	}{
		{label: "ipv6 with brackets", in: "[2001:db8::1]", want: "2001:db8::1"},
		{label: "ipv6 loopback bracketed", in: "[::1]", want: "::1"},
		{label: "ipv4 no brackets", in: "192.168.1.1", want: "192.168.1.1"},
		{label: "hostname no brackets", in: "node-1.example.com", want: "node-1.example.com"},
		{label: "empty string", in: "", want: ""},
		{label: "only opening bracket", in: "[2001:db8::1", want: "[2001:db8::1"},
		{label: "only closing bracket", in: "2001:db8::1]", want: "2001:db8::1]"},
		{label: "brackets with single char", in: "[a]", want: "a"},
		{label: "empty brackets", in: "[]", want: "[]"},
		{label: "single bracket char", in: "[", want: "["},
	}

	for _, row := range rows {
		t.Run(row.label, func(t *testing.T) {
			if got := stripNodeDelimiters(row.in); got != row.want {
				t.Errorf("stripNodeDelimiters(%q) = %q, want %q", row.in, got, row.want)
			}
		})
	}
}
// TestParseNodeIP checks parseNodeIP against IPv4 and IPv6 literals (with
// surrounding whitespace or brackets) and against inputs the table expects
// to yield nil: hostnames, the empty string, and malformed addresses.
func TestParseNodeIP(t *testing.T) {
	rows := []struct {
		label   string
		in      string
		wantNil bool
		wantIP  string
	}{
		{label: "ipv4", in: "192.168.1.1", wantIP: "192.168.1.1"},
		{label: "ipv4 with whitespace", in: " 10.0.0.1 ", wantIP: "10.0.0.1"},
		{label: "ipv6", in: "2001:db8::1", wantIP: "2001:db8::1"},
		{label: "ipv6 bracketed", in: "[2001:db8::1]", wantIP: "2001:db8::1"},
		{label: "ipv6 loopback", in: "::1", wantIP: "::1"},
		{label: "ipv6 loopback bracketed", in: "[::1]", wantIP: "::1"},
		{label: "hostname", in: "node-1.example.com", wantNil: true},
		{label: "empty string", in: "", wantNil: true},
		{label: "invalid ip", in: "999.999.999.999", wantNil: true},
		{label: "partial ipv4", in: "192.168.1", wantNil: true},
	}

	for _, row := range rows {
		t.Run(row.label, func(t *testing.T) {
			got := parseNodeIP(row.in)

			switch {
			case row.wantNil:
				if got != nil {
					t.Errorf("parseNodeIP(%q) = %v, want nil", row.in, got)
				}
			case got == nil:
				t.Errorf("parseNodeIP(%q) = nil, want %s", row.in, row.wantIP)
			case got.String() != row.wantIP:
				t.Errorf("parseNodeIP(%q) = %v, want %s", row.in, got, row.wantIP)
			}
		})
	}
}
// TestNormalizeAllowlistEntry checks normalizeAllowlistEntry against a table
// that expects whitespace to be trimmed, IPv6 brackets removed, hostnames
// lower-cased, and a fully expanded IPv6 address rendered in compressed form.
func TestNormalizeAllowlistEntry(t *testing.T) {
	rows := []struct {
		label string
		in    string
		want  string
	}{
		{label: "ipv4", in: "192.168.1.1", want: "192.168.1.1"},
		{label: "ipv4 with whitespace", in: " 10.0.0.1 ", want: "10.0.0.1"},
		{label: "ipv6", in: "2001:db8::1", want: "2001:db8::1"},
		{label: "ipv6 bracketed", in: "[2001:db8::1]", want: "2001:db8::1"},
		{label: "hostname lowercase", in: "node-1.example.com", want: "node-1.example.com"},
		{label: "hostname uppercase normalized", in: "NODE-1.EXAMPLE.COM", want: "node-1.example.com"},
		{label: "hostname mixed case", in: "Node-1.Example.Com", want: "node-1.example.com"},
		{label: "empty string", in: "", want: ""},
		{label: "whitespace only", in: " ", want: ""},
		{label: "ipv6 loopback", in: "::1", want: "::1"},
		{label: "ipv6 full form normalized", in: "2001:0db8:0000:0000:0000:0000:0000:0001", want: "2001:db8::1"},
	}

	for _, row := range rows {
		t.Run(row.label, func(t *testing.T) {
			if got := normalizeAllowlistEntry(row.in); got != row.want {
				t.Errorf("normalizeAllowlistEntry(%q) = %q, want %q", row.in, got, row.want)
			}
		})
	}
}
// TestIPAllowed checks ipAllowed against combinations of an exact-match host
// set and a list of CIDR ranges. The table expects an IP to be allowed when
// it matches either source, and a nil IP or empty sources to be rejected.
func TestIPAllowed(t *testing.T) {
	// mustCIDR parses a fixture CIDR and aborts the test on failure, so a
	// typo in a fixture cannot silently put a nil *net.IPNet into the table.
	mustCIDR := func(s string) *net.IPNet {
		t.Helper()
		_, network, err := net.ParseCIDR(s)
		if err != nil {
			t.Fatalf("parsing test CIDR %q: %v", s, err)
		}
		return network
	}

	// Setup test CIDRs
	cidr10 := mustCIDR("10.0.0.0/8")
	cidr192 := mustCIDR("192.168.1.0/24")
	cidr172 := mustCIDR("172.16.0.0/12")

	tests := []struct {
		name     string
		ip       net.IP
		hosts    map[string]struct{}
		cidrs    []*net.IPNet
		expected bool
	}{
		{
			name:     "nil ip",
			ip:       nil,
			hosts:    map[string]struct{}{"10.0.0.1": {}},
			cidrs:    nil,
			expected: false,
		},
		{
			name:     "ip in hosts map",
			ip:       net.ParseIP("192.168.1.100"),
			hosts:    map[string]struct{}{"192.168.1.100": {}},
			cidrs:    nil,
			expected: true,
		},
		{
			name:     "ip not in hosts map",
			ip:       net.ParseIP("192.168.1.100"),
			hosts:    map[string]struct{}{"192.168.1.200": {}},
			cidrs:    nil,
			expected: false,
		},
		{
			name:     "ip in cidr range",
			ip:       net.ParseIP("10.1.2.3"),
			hosts:    nil,
			cidrs:    []*net.IPNet{cidr10},
			expected: true,
		},
		{
			name:     "ip not in cidr range",
			ip:       net.ParseIP("11.0.0.1"),
			hosts:    nil,
			cidrs:    []*net.IPNet{cidr10},
			expected: false,
		},
		{
			name:     "ip in second cidr",
			ip:       net.ParseIP("192.168.1.50"),
			hosts:    nil,
			cidrs:    []*net.IPNet{cidr10, cidr192},
			expected: true,
		},
		{
			name:     "ip matches hosts not cidrs",
			ip:       net.ParseIP("8.8.8.8"),
			hosts:    map[string]struct{}{"8.8.8.8": {}},
			cidrs:    []*net.IPNet{cidr10},
			expected: true,
		},
		{
			name:     "ip matches cidrs not hosts",
			ip:       net.ParseIP("172.20.0.1"),
			hosts:    map[string]struct{}{"8.8.8.8": {}},
			cidrs:    []*net.IPNet{cidr172},
			expected: true,
		},
		{
			name:     "empty hosts and cidrs",
			ip:       net.ParseIP("192.168.1.1"),
			hosts:    nil,
			cidrs:    nil,
			expected: false,
		},
		{
			name:     "ipv6 in hosts",
			ip:       net.ParseIP("2001:db8::1"),
			hosts:    map[string]struct{}{"2001:db8::1": {}},
			cidrs:    nil,
			expected: true,
		},
		{
			name:     "nil hosts map with ip match in cidr",
			ip:       net.ParseIP("10.255.255.255"),
			hosts:    nil,
			cidrs:    []*net.IPNet{cidr10},
			expected: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result := ipAllowed(tc.ip, tc.hosts, tc.cidrs)
			if result != tc.expected {
				t.Errorf("ipAllowed(%v, hosts, cidrs) = %v, want %v", tc.ip, result, tc.expected)
			}
		})
	}
}