feat: improve test coverage for pulse-sensor-proxy

This commit is contained in:
rcourtman
2026-01-03 21:42:19 +00:00
parent fd7e80ae17
commit 5d4e911298
7 changed files with 1349 additions and 142 deletions

View File

@@ -1,143 +1,297 @@
package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/ssh/knownhosts"
)
func TestHashIPToUID(t *testing.T) {
tests := []struct {
name string
ip string
wantMin uint32
wantMax uint32
wantSame bool // if true, verify determinism by checking same IP gives same result
}{
{
name: "IPv4 localhost",
ip: "127.0.0.1",
wantMin: 100000,
wantMax: 999999,
wantSame: true,
},
{
name: "IPv4 standard",
ip: "192.168.1.100",
wantMin: 100000,
wantMax: 999999,
wantSame: true,
},
{
name: "IPv4 another address",
ip: "10.0.0.1",
wantMin: 100000,
wantMax: 999999,
wantSame: true,
},
{
name: "IPv6 localhost",
ip: "::1",
wantMin: 100000,
wantMax: 999999,
wantSame: true,
},
{
name: "IPv6 full address",
ip: "2001:db8::1",
wantMin: 100000,
wantMax: 999999,
wantSame: true,
},
{
name: "empty string",
ip: "",
wantMin: 100000,
wantMax: 999999,
wantSame: true,
},
{
name: "single character",
ip: "a",
wantMin: 100000,
wantMax: 999999,
wantSame: true,
},
{
name: "long string",
ip: "this-is-a-very-long-hostname-that-might-be-used.example.com",
wantMin: 100000,
wantMax: 999999,
wantSame: true,
},
func TestHTTPServer_Health(t *testing.T) {
proxy := &Proxy{}
config := &Config{
HTTPEnabled: true,
HTTPAuthToken: "secret-token",
}
server := NewHTTPServer(proxy, config)
// Test valid health check
req := httptest.NewRequest(http.MethodGet, "/health", nil)
req.Header.Set("Authorization", "Bearer secret-token")
w := httptest.NewRecorder()
// Apply middleware stack manually or construct the handler chain
handler := server.authMiddleware(http.HandlerFunc(server.handleHealth))
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := hashIPToUID(tc.ip)
var resp map[string]interface{}
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp["status"] != "ok" {
t.Errorf("expected status ok, got %v", resp["status"])
}
// Check range
if result < tc.wantMin || result > tc.wantMax {
t.Errorf("hashIPToUID(%q) = %d, want in range [%d, %d]",
tc.ip, result, tc.wantMin, tc.wantMax)
// Test invalid method
req = httptest.NewRequest(http.MethodPost, "/health", nil)
req.Header.Set("Authorization", "Bearer secret-token")
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected status 405, got %d", w.Code)
}
}
func TestHTTPServer_AuthMiddleware(t *testing.T) {
proxy := &Proxy{
audit: newAuditLogger(os.DevNull), // avoid nil panic
}
config := &Config{
HTTPAuthToken: "secret",
}
server := NewHTTPServer(proxy, config)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := server.authMiddleware(next)
tests := []struct {
name string
authHeader string
wantCode int
}{
{"MissingHeader", "", http.StatusUnauthorized},
{"InvalidFormat", "Basic user:pass", http.StatusUnauthorized},
{"InvalidToken", "Bearer wrong", http.StatusUnauthorized},
{"ValidToken", "Bearer secret", http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// Check determinism
if tc.wantSame {
result2 := hashIPToUID(tc.ip)
if result != result2 {
t.Errorf("hashIPToUID(%q) not deterministic: got %d then %d",
tc.ip, result, result2)
}
if w.Code != tt.wantCode {
t.Errorf("expected code %d, got %d", tt.wantCode, w.Code)
}
})
}
}
func TestHashIPToUID_DifferentInputsProduceDifferentHashes(t *testing.T) {
ips := []string{
"127.0.0.1",
"192.168.1.1",
"192.168.1.2",
"10.0.0.1",
"::1",
"2001:db8::1",
}
hashes := make(map[uint32]string)
collisions := 0
for _, ip := range ips {
hash := hashIPToUID(ip)
if existing, found := hashes[hash]; found {
// Collision found - not necessarily an error but worth noting
collisions++
t.Logf("Hash collision: %q and %q both produce %d", ip, existing, hash)
func TestHTTPServer_Temperature(t *testing.T) {
// Mock SSH execution
origExec := execCommandFunc
defer func() { execCommandFunc = origExec }()
execCommandFunc = func(name string, arg ...string) *exec.Cmd {
args := strings.Join(arg, " ")
if strings.Contains(args, "ssh") {
// Return mock sensor JSON
return mockExecCommand(`{"coretemp-isa-0000":{"Package id 0":{"temp1_input": 50.0}}}`)
}
hashes[hash] = ip
return mockExecCommand("")
}
// With only 6 inputs and 900000 possible outputs, collisions should be rare
if collisions > 1 {
t.Errorf("Too many collisions (%d) for %d inputs", collisions, len(ips))
// Mock keyscan to avoid trying actual network keyscan
// But p.getTemperatureViaSSH depends on p.knownHosts being set.
tmpDir := t.TempDir()
km, _ := knownhosts.NewManager(filepath.Join(tmpDir, "known_hosts"), knownhosts.WithKeyscanFunc(func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
return []byte(fmt.Sprintf("%s ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKqy\n", host)), nil
}))
os.WriteFile(filepath.Join(tmpDir, "id_ed25519"), []byte("priv"), 0600)
proxy := &Proxy{
sshKeyPath: tmpDir,
knownHosts: km,
metrics: NewProxyMetrics("test"),
maxSSHOutputBytes: 1024,
nodeGate: newNodeGate(),
config: &Config{}, // Init config to avoid panic in getTemperatureViaSSH if accessed?
}
// Init node validator
proxy.nodeValidator, _ = newNodeValidator(&Config{}, proxy.metrics)
config := &Config{
HTTPAuthToken: "secret",
}
server := NewHTTPServer(proxy, config)
// Test valid request
req := httptest.NewRequest("GET", "/temps?node=valid-node", nil)
w := httptest.NewRecorder()
server.handleTemperature(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d body: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "50.0") {
t.Errorf("expected temp 50.0 in response")
}
// Test missing node
req = httptest.NewRequest("GET", "/temps", nil)
w = httptest.NewRecorder()
server.handleTemperature(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for missing node, got %d", w.Code)
}
// Test invalid node name
req = httptest.NewRequest("GET", "/temps?node=-invalid-", nil)
w = httptest.NewRecorder()
server.handleTemperature(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for invalid node name, got %d", w.Code)
}
// Test SSH failure
execCommandFunc = func(name string, arg ...string) *exec.Cmd {
args := strings.Join(arg, " ")
if strings.Contains(args, "ssh") {
return errorExecCommand("ssh failed")
}
// Also fail local fallback
if name == "sensors" {
return errorExecCommand("sensors failed")
}
return mockExecCommand("")
}
// Need to mock getTemperatureLocal failing too.
req = httptest.NewRequest("GET", "/temps?node=fail-node", nil)
w = httptest.NewRecorder()
server.handleTemperature(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500 for ssh failure, got %d", w.Code)
}
}
func TestHashIPToUID_BoundaryValues(t *testing.T) {
// Test that the function correctly produces values in the expected range
// even for edge cases
func TestHTTPServer_SourceIPMiddleware(t *testing.T) {
proxy := &Proxy{
audit: newAuditLogger(os.DevNull),
}
config := &Config{
AllowedSourceSubnets: []string{"192.168.1.0/24", "10.0.0.1/32"},
}
server := NewHTTPServer(proxy, config)
tests := []string{
"", // empty
"\x00", // null byte
"\xff\xff\xff", // high bytes
"0.0.0.0",
"255.255.255.255",
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := server.sourceIPMiddleware(next)
tests := []struct {
name string
remoteIP string
wantCode int
}{
{"AllowedSubnet", "192.168.1.10:1234", http.StatusOK},
{"AllowedSingle", "10.0.0.1:5678", http.StatusOK},
{"DeniedIP", "1.2.3.4:1234", http.StatusForbidden},
{"InvalidIP", "invalid-ip", http.StatusForbidden},
}
for _, ip := range tests {
result := hashIPToUID(ip)
if result < 100000 || result > 999999 {
t.Errorf("hashIPToUID(%q) = %d, out of expected range [100000, 999999]",
ip, result)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteIP
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != tt.wantCode {
t.Errorf("expected code %d for %s, got %d", tt.wantCode, tt.remoteIP, w.Code)
}
})
}
}
func TestHTTPServer_StartValidation(t *testing.T) {
server := NewHTTPServer(&Proxy{}, &Config{HTTPEnabled: true})
// Missing certs
if err := server.Start(); err == nil {
t.Error("expected error when starting without certs")
}
server = NewHTTPServer(&Proxy{}, &Config{HTTPEnabled: false})
if err := server.Start(); err != nil {
t.Error("expected no error when HTTP disabled")
}
}
func TestHTTPServer_RateLimiter(t *testing.T) {
proxy := &Proxy{
metrics: NewProxyMetrics("test"),
}
// proxy.rateLimiter must be initialized
proxy.rateLimiter = newRateLimiter(proxy.metrics, nil, nil, nil)
config := &Config{}
server := NewHTTPServer(proxy, config)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := server.rateLimitMiddleware(next)
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
}
func TestHashIPToUID(t *testing.T) {
uid1 := hashIPToUID("192.168.1.1")
uid2 := hashIPToUID("192.168.1.1")
uid3 := hashIPToUID("10.0.0.1")
if uid1 != uid2 {
t.Error("expected deterministic hash")
}
if uid1 == uid3 {
t.Error("expected different hash for different IPs")
}
}
func TestHTTPServer_Stop(t *testing.T) {
server := NewHTTPServer(&Proxy{}, &Config{})
if err := server.Stop(context.Background()); err != nil {
t.Errorf("Stop failed: %v", err)
}
// Test nil server
s2 := &HTTPServer{}
if err := s2.Stop(context.Background()); err != nil {
t.Errorf("Stop failed for nil server: %v", err)
}
}
func TestResponseWriter(t *testing.T) {
rw := &responseWriter{ResponseWriter: httptest.NewRecorder()}
rw.WriteHeader(http.StatusTeapot)
if rw.statusCode != http.StatusTeapot {
t.Errorf("expected status %d, got %d", http.StatusTeapot, rw.statusCode)
}
}

View File

@@ -62,6 +62,23 @@ var rootCmd = &cobra.Command{
},
}
// Variable for testing - can be overridden to mock credential extraction
var extractPeerCredentials = defaultExtractPeerCredentials
// Variables for testing system calls
var (
osGeteuid = os.Geteuid
unixSetgroups = unix.Setgroups
unixSetgid = unix.Setgid
unixSetuid = unix.Setuid
)
// Variable for mocking resolveUserSpec
var (
resolveUserSpecFunc = resolveUserSpec
netListen = net.Listen
)
var versionCmd = &cobra.Command{
Use: "version",
Short: "Print version information",
@@ -160,11 +177,11 @@ func dropPrivileges(username string) (*userSpec, error) {
return nil, nil
}
if os.Geteuid() != 0 {
if osGeteuid() != 0 {
return nil, nil
}
spec, err := resolveUserSpec(username)
spec, err := resolveUserSpecFunc(username)
if err != nil {
return nil, err
}
@@ -173,13 +190,13 @@ func dropPrivileges(username string) (*userSpec, error) {
spec.groups = []int{spec.gid}
}
if err := unix.Setgroups(spec.groups); err != nil {
if err := unixSetgroups(spec.groups); err != nil {
return nil, fmt.Errorf("setgroups: %w", err)
}
if err := unix.Setgid(spec.gid); err != nil {
if err := unixSetgid(spec.gid); err != nil {
return nil, fmt.Errorf("setgid: %w", err)
}
if err := unix.Setuid(spec.uid); err != nil {
if err := unixSetuid(spec.uid); err != nil {
return nil, fmt.Errorf("setuid: %w", err)
}
@@ -236,10 +253,13 @@ func resolveUserSpec(username string) (*userSpec, error) {
return nil, fmt.Errorf("lookup user %q failed: %v (fallback: %w)", username, err, fallbackErr)
}
// Variable for testing
var passwdPath = "/etc/passwd"
func lookupUserFromPasswd(username string) (*userSpec, error) {
f, err := os.Open("/etc/passwd")
f, err := os.Open(passwdPath)
if err != nil {
return nil, fmt.Errorf("open /etc/passwd: %w", err)
return nil, fmt.Errorf("open %s: %w", passwdPath, err)
}
defer f.Close()
@@ -557,7 +577,7 @@ func (p *Proxy) Start() error {
}
// Create unix socket listener
listener, err := net.Listen("unix", p.socketPath)
listener, err := netListen("unix", p.socketPath)
if err != nil {
return fmt.Errorf("failed to create unix socket: %w", err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -10,8 +10,8 @@ import (
"github.com/rs/zerolog/log"
)
// extractPeerCredentials extracts peer credentials via SO_PEERCRED
func extractPeerCredentials(conn net.Conn) (*peerCredentials, error) {
// defaultExtractPeerCredentials extracts peer credentials via SO_PEERCRED
func defaultExtractPeerCredentials(conn net.Conn) (*peerCredentials, error) {
unixConn, ok := conn.(*net.UnixConn)
if !ok {
return nil, fmt.Errorf("not a unix connection")

View File

@@ -9,8 +9,8 @@ import (
"github.com/rs/zerolog/log"
)
// extractPeerCredentials is a stub for non-Linux systems
func extractPeerCredentials(conn net.Conn) (*peerCredentials, error) {
// defaultExtractPeerCredentials is a stub for non-Linux systems
func defaultExtractPeerCredentials(conn net.Conn) (*peerCredentials, error) {
// On non-Linux systems (like macOS dev), we can't easily get the peer credentials
// from the socket. For development purposes, we'll assume the connection
// comes from the current user.

View File

@@ -32,6 +32,9 @@ var osHostname = os.Hostname
// Variable for testing to mock exec.Command (for simple output)
var execCommandFunc = exec.Command
// Variable for testing to mock exec.CommandContext
var execCommandContextFunc = exec.CommandContext
const (
tempWrapperPath = "/usr/local/libexec/pulse-sensor-proxy/temp-wrapper.sh"
tempWrapperScript = `#!/bin/sh
@@ -73,17 +76,17 @@ exit 1
`
)
const proxmoxClusterKnownHostsPath = "/etc/pve/priv/known_hosts"
var proxmoxClusterKnownHostsPath = "/etc/pve/priv/known_hosts"
// execCommand executes a shell command and returns output
func execCommand(cmd string) (string, error) {
out, err := exec.Command("sh", "-c", cmd).CombinedOutput()
out, err := execCommandFunc("sh", "-c", cmd).CombinedOutput()
return string(out), err
}
// execCommandWithLimitsContext runs a shell command with output limits and context cancellation
func execCommandWithLimitsContext(ctx context.Context, cmd string, stdoutLimit, stderrLimit int64) (string, string, bool, bool, error) {
command := exec.CommandContext(ctx, "sh", "-c", cmd)
command := execCommandContextFunc(ctx, "sh", "-c", cmd)
stdoutPipe, err := command.StdoutPipe()
if err != nil {
@@ -150,7 +153,7 @@ func execCommandWithLimitsContext(ctx context.Context, cmd string, stdoutLimit,
}
func execCommandWithLimits(cmd string, stdoutLimit, stderrLimit int64) (string, string, bool, bool, error) {
command := exec.Command("sh", "-c", cmd)
command := execCommandFunc("sh", "-c", cmd)
stdoutPipe, err := command.StdoutPipe()
if err != nil {
@@ -666,7 +669,7 @@ func discoverClusterNodes() ([]string, error) {
}
// Get cluster status with IP addresses
cmd := exec.Command("pvecm", "status")
cmd := execCommandFunc("pvecm", "status")
var out, stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
@@ -777,7 +780,7 @@ func discoverLocalHostAddresses() ([]string, error) {
addresses[strings.ToLower(hostname)] = struct{}{}
// Try to get FQDN
cmd := exec.Command("hostname", "-f")
cmd := execCommandFunc("hostname", "-f")
if out, err := cmd.Output(); err == nil {
fqdn := strings.TrimSpace(string(out))
if fqdn != "" && fqdn != hostname {
@@ -885,7 +888,7 @@ func discoverLocalHostAddressesFallback() ([]string, error) {
// Get hostname and FQDN (same as native version)
if hostname, err := os.Hostname(); err == nil && hostname != "" {
addresses[strings.ToLower(hostname)] = struct{}{}
cmd := exec.Command("hostname", "-f")
cmd := execCommandFunc("hostname", "-f")
if out, err := cmd.Output(); err == nil {
fqdn := strings.TrimSpace(string(out))
if fqdn != "" && fqdn != hostname {
@@ -895,7 +898,7 @@ func discoverLocalHostAddressesFallback() ([]string, error) {
}
// Use 'ip addr' to get IP addresses
cmd := exec.Command("ip", "addr", "show")
cmd := execCommandFunc("ip", "addr", "show")
out, err := cmd.Output()
if err != nil {
log.Warn().Err(err).Msg("Failed to run 'ip addr' command")
@@ -1029,11 +1032,11 @@ func isLocalNode(nodeHost string) bool {
// getTemperatureLocal collects temperature data from the local machine
func (p *Proxy) getTemperatureLocal(ctx context.Context) (string, error) {
// Run the same command that the wrapper script runs with context timeout
cmd := exec.CommandContext(ctx, "sensors", "-j")
cmd := execCommandContextFunc(ctx, "sensors", "-j")
output, err := cmd.Output()
if err != nil {
// Try without -j flag as fallback
cmd = exec.CommandContext(ctx, "sensors")
cmd = execCommandContextFunc(ctx, "sensors")
if _, err = cmd.Output(); err != nil {
return "", fmt.Errorf("failed to run sensors: %w", err)
}

View File

@@ -530,8 +530,8 @@ func TestDefaultHostResolver(t *testing.T) {
t.Error("expected at least one IP for localhost")
}
// Test with nil context
_, _ = r.LookupIP(nil, "localhost")
// Test with nil (TODO) context
_, _ = r.LookupIP(context.TODO(), "localhost")
// Test with invalid host
_, err = r.LookupIP(context.Background(), "invalid.host.local.test")