diff --git a/cmd/pulse-sensor-proxy/http_server_test.go b/cmd/pulse-sensor-proxy/http_server_test.go index 113ca527b..6753b40fa 100644 --- a/cmd/pulse-sensor-proxy/http_server_test.go +++ b/cmd/pulse-sensor-proxy/http_server_test.go @@ -1,143 +1,297 @@ package main import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strings" "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/ssh/knownhosts" ) -func TestHashIPToUID(t *testing.T) { - tests := []struct { - name string - ip string - wantMin uint32 - wantMax uint32 - wantSame bool // if true, verify determinism by checking same IP gives same result - }{ - { - name: "IPv4 localhost", - ip: "127.0.0.1", - wantMin: 100000, - wantMax: 999999, - wantSame: true, - }, - { - name: "IPv4 standard", - ip: "192.168.1.100", - wantMin: 100000, - wantMax: 999999, - wantSame: true, - }, - { - name: "IPv4 another address", - ip: "10.0.0.1", - wantMin: 100000, - wantMax: 999999, - wantSame: true, - }, - { - name: "IPv6 localhost", - ip: "::1", - wantMin: 100000, - wantMax: 999999, - wantSame: true, - }, - { - name: "IPv6 full address", - ip: "2001:db8::1", - wantMin: 100000, - wantMax: 999999, - wantSame: true, - }, - { - name: "empty string", - ip: "", - wantMin: 100000, - wantMax: 999999, - wantSame: true, - }, - { - name: "single character", - ip: "a", - wantMin: 100000, - wantMax: 999999, - wantSame: true, - }, - { - name: "long string", - ip: "this-is-a-very-long-hostname-that-might-be-used.example.com", - wantMin: 100000, - wantMax: 999999, - wantSame: true, - }, +func TestHTTPServer_Health(t *testing.T) { + proxy := &Proxy{} + config := &Config{ + HTTPEnabled: true, + HTTPAuthToken: "secret-token", + } + server := NewHTTPServer(proxy, config) + + // Test valid health check + req := httptest.NewRequest(http.MethodGet, "/health", nil) + req.Header.Set("Authorization", "Bearer secret-token") + w := httptest.NewRecorder() + + // Apply middleware stack manually or construct the handler chain + handler := server.authMiddleware(http.HandlerFunc(server.handleHealth)) + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := hashIPToUID(tc.ip) + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + if resp["status"] != "ok" { + t.Errorf("expected status ok, got %v", resp["status"]) + } - // Check range - if result < tc.wantMin || result > tc.wantMax { - t.Errorf("hashIPToUID(%q) = %d, want in range [%d, %d]", - tc.ip, result, tc.wantMin, tc.wantMax) + // Test invalid method + req = httptest.NewRequest(http.MethodPost, "/health", nil) + req.Header.Set("Authorization", "Bearer secret-token") + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", w.Code) + } +} + +func TestHTTPServer_AuthMiddleware(t *testing.T) { + proxy := &Proxy{ + audit: newAuditLogger(os.DevNull), // avoid nil panic + } + config := &Config{ + HTTPAuthToken: "secret", + } + server := NewHTTPServer(proxy, config) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := server.authMiddleware(next) + + tests := []struct { + name string + authHeader string + wantCode int + }{ + {"MissingHeader", "", http.StatusUnauthorized}, + {"InvalidFormat", "Basic user:pass", http.StatusUnauthorized}, + {"InvalidToken", "Bearer wrong", http.StatusUnauthorized}, + {"ValidToken", "Bearer secret", http.StatusOK}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) } + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) - // Check determinism - if tc.wantSame { - result2 := hashIPToUID(tc.ip) - if result != result2 { - t.Errorf("hashIPToUID(%q) not deterministic: got %d then %d", - tc.ip, result, result2) - } + if w.Code != tt.wantCode { + t.Errorf("expected code %d, got %d", tt.wantCode, w.Code) } }) } } -func TestHashIPToUID_DifferentInputsProduceDifferentHashes(t *testing.T) { - ips := []string{ - "127.0.0.1", - "192.168.1.1", - "192.168.1.2", - "10.0.0.1", - "::1", - "2001:db8::1", - } - - hashes := make(map[uint32]string) - collisions := 0 - - for _, ip := range ips { - hash := hashIPToUID(ip) - if existing, found := hashes[hash]; found { - // Collision found - not necessarily an error but worth noting - collisions++ - t.Logf("Hash collision: %q and %q both produce %d", ip, existing, hash) +func TestHTTPServer_Temperature(t *testing.T) { + // Mock SSH execution + origExec := execCommandFunc + defer func() { execCommandFunc = origExec }() + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + args := strings.Join(arg, " ") + if strings.Contains(args, "ssh") { + // Return mock sensor JSON + return mockExecCommand(`{"coretemp-isa-0000":{"Package id 0":{"temp1_input": 50.0}}}`) } - hashes[hash] = ip + return mockExecCommand("") } - // With only 6 inputs and 900000 possible outputs, collisions should be rare - if collisions > 1 { - t.Errorf("Too many collisions (%d) for %d inputs", collisions, len(ips)) + // Mock keyscan to avoid trying actual network keyscan + // But p.getTemperatureViaSSH depends on p.knownHosts being set. + + tmpDir := t.TempDir() + km, _ := knownhosts.NewManager(filepath.Join(tmpDir, "known_hosts"), knownhosts.WithKeyscanFunc(func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) { + return []byte(fmt.Sprintf("%s ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKqy\n", host)), nil + })) + os.WriteFile(filepath.Join(tmpDir, "id_ed25519"), []byte("priv"), 0600) + + proxy := &Proxy{ + sshKeyPath: tmpDir, + knownHosts: km, + metrics: NewProxyMetrics("test"), + maxSSHOutputBytes: 1024, + nodeGate: newNodeGate(), + config: &Config{}, // Init config to avoid panic in getTemperatureViaSSH if accessed? + } + // Init node validator + proxy.nodeValidator, _ = newNodeValidator(&Config{}, proxy.metrics) + + config := &Config{ + HTTPAuthToken: "secret", + } + server := NewHTTPServer(proxy, config) + + // Test valid request + req := httptest.NewRequest("GET", "/temps?node=valid-node", nil) + w := httptest.NewRecorder() + server.handleTemperature(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d body: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "50.0") { + t.Errorf("expected temp 50.0 in response") + } + + // Test missing node + req = httptest.NewRequest("GET", "/temps", nil) + w = httptest.NewRecorder() + server.handleTemperature(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for missing node, got %d", w.Code) + } + + // Test invalid node name + req = httptest.NewRequest("GET", "/temps?node=-invalid-", nil) + w = httptest.NewRecorder() + server.handleTemperature(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for invalid node name, got %d", w.Code) + } + + // Test SSH failure + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + args := strings.Join(arg, " ") + if strings.Contains(args, "ssh") { + return errorExecCommand("ssh failed") + } + // Also fail local fallback + if name == "sensors" { + return errorExecCommand("sensors failed") + } + return mockExecCommand("") + } + // Need to mock getTemperatureLocal failing too. + + req = httptest.NewRequest("GET", "/temps?node=fail-node", nil) + w = httptest.NewRecorder() + server.handleTemperature(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500 for ssh failure, got %d", w.Code) } } -func TestHashIPToUID_BoundaryValues(t *testing.T) { - // Test that the function correctly produces values in the expected range - // even for edge cases +func TestHTTPServer_SourceIPMiddleware(t *testing.T) { + proxy := &Proxy{ + audit: newAuditLogger(os.DevNull), + } + config := &Config{ + AllowedSourceSubnets: []string{"192.168.1.0/24", "10.0.0.1/32"}, + } + server := NewHTTPServer(proxy, config) - tests := []string{ - "", // empty - "\x00", // null byte - "\xff\xff\xff", // high bytes - "0.0.0.0", - "255.255.255.255", + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := server.sourceIPMiddleware(next) + + tests := []struct { + name string + remoteIP string + wantCode int + }{ + {"AllowedSubnet", "192.168.1.10:1234", http.StatusOK}, + {"AllowedSingle", "10.0.0.1:5678", http.StatusOK}, + {"DeniedIP", "1.2.3.4:1234", http.StatusForbidden}, + {"InvalidIP", "invalid-ip", http.StatusForbidden}, } - for _, ip := range tests { - result := hashIPToUID(ip) - if result < 100000 || result > 999999 { - t.Errorf("hashIPToUID(%q) = %d, out of expected range [100000, 999999]", - ip, result) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = tt.remoteIP + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != tt.wantCode { + t.Errorf("expected code %d for %s, got %d", tt.wantCode, tt.remoteIP, w.Code) + } + }) + } +} + +func TestHTTPServer_StartValidation(t *testing.T) { + server := NewHTTPServer(&Proxy{}, &Config{HTTPEnabled: true}) + // Missing certs + if err := server.Start(); err == nil { + t.Error("expected error when starting without certs") + } + + server = NewHTTPServer(&Proxy{}, &Config{HTTPEnabled: false}) + if err := server.Start(); err != nil { + t.Error("expected no error when HTTP disabled") + } +} + +func TestHTTPServer_RateLimiter(t *testing.T) { + proxy := &Proxy{ + metrics: NewProxyMetrics("test"), + } + // proxy.rateLimiter must be initialized + proxy.rateLimiter = newRateLimiter(proxy.metrics, nil, nil, nil) + + config := &Config{} + server := NewHTTPServer(proxy, config) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := server.rateLimitMiddleware(next) + + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +func TestHashIPToUID(t *testing.T) { + uid1 := hashIPToUID("192.168.1.1") + uid2 := hashIPToUID("192.168.1.1") + uid3 := hashIPToUID("10.0.0.1") + + if uid1 != uid2 { + t.Error("expected deterministic hash") + } + if uid1 == uid3 { + t.Error("expected different hash for different IPs") + } +} + +func TestHTTPServer_Stop(t *testing.T) { + server := NewHTTPServer(&Proxy{}, &Config{}) + if err := server.Stop(context.Background()); err != nil { + t.Errorf("Stop failed: %v", err) + } + // Test nil server + s2 := &HTTPServer{} + if err := s2.Stop(context.Background()); err != nil { + t.Errorf("Stop failed for nil server: %v", err) + } +} + +func TestResponseWriter(t *testing.T) { + rw := &responseWriter{ResponseWriter: httptest.NewRecorder()} + rw.WriteHeader(http.StatusTeapot) + if rw.statusCode != http.StatusTeapot { + t.Errorf("expected status %d, got %d", http.StatusTeapot, rw.statusCode) } } diff --git a/cmd/pulse-sensor-proxy/main.go b/cmd/pulse-sensor-proxy/main.go index adbd9b53e..385b30940 100644 --- a/cmd/pulse-sensor-proxy/main.go +++ b/cmd/pulse-sensor-proxy/main.go @@ -62,6 +62,23 @@ var rootCmd = &cobra.Command{ }, } +// Variable for testing - can be overridden to mock credential extraction +var extractPeerCredentials = defaultExtractPeerCredentials + +// Variables for testing system calls +var ( + osGeteuid = os.Geteuid + unixSetgroups = unix.Setgroups + unixSetgid = unix.Setgid + unixSetuid = unix.Setuid +) + +// Variable for mocking resolveUserSpec +var ( + resolveUserSpecFunc = resolveUserSpec + netListen = net.Listen +) + var versionCmd = &cobra.Command{ Use: "version", Short: "Print version information", @@ -160,11 +177,11 @@ func dropPrivileges(username string) (*userSpec, error) { return nil, nil } - if os.Geteuid() != 0 { + if osGeteuid() != 0 { return nil, nil } - spec, err := resolveUserSpec(username) + spec, err := resolveUserSpecFunc(username) if err != nil { return nil, err } @@ -173,13 +190,13 @@ func dropPrivileges(username string) (*userSpec, error) { spec.groups = []int{spec.gid} } - if err := unix.Setgroups(spec.groups); err != nil { + if err := unixSetgroups(spec.groups); err != nil { return nil, fmt.Errorf("setgroups: %w", err) } - if err := unix.Setgid(spec.gid); err != nil { + if err := unixSetgid(spec.gid); err != nil { return nil, fmt.Errorf("setgid: %w", err) } - if err := unix.Setuid(spec.uid); err != nil { + if err := unixSetuid(spec.uid); err != nil { return nil, fmt.Errorf("setuid: %w", err) } @@ -236,10 +253,13 @@ func resolveUserSpec(username string) (*userSpec, error) { return nil, fmt.Errorf("lookup user %q failed: %v (fallback: %w)", username, err, fallbackErr) } +// Variable for testing +var passwdPath = "/etc/passwd" + func lookupUserFromPasswd(username string) (*userSpec, error) { - f, err := os.Open("/etc/passwd") + f, err := os.Open(passwdPath) if err != nil { - return nil, fmt.Errorf("open /etc/passwd: %w", err) + return nil, fmt.Errorf("open %s: %w", passwdPath, err) } defer f.Close() @@ -557,7 +577,7 @@ func (p *Proxy) Start() error { } // Create unix socket listener - listener, err := net.Listen("unix", p.socketPath) + listener, err := netListen("unix", p.socketPath) if err != nil { return fmt.Errorf("failed to create unix socket: %w", err) } diff --git a/cmd/pulse-sensor-proxy/main_test.go b/cmd/pulse-sensor-proxy/main_test.go index cffb87486..4fe4c7cef 100644 --- a/cmd/pulse-sensor-proxy/main_test.go +++ b/cmd/pulse-sensor-proxy/main_test.go @@ -1,15 +1,24 @@ package main import ( + "bytes" "context" "encoding/json" + "errors" + "fmt" + "io" "net" + "net/http" + "net/http/httptest" "os" + "os/exec" "path/filepath" + "runtime" "strings" "testing" "time" + "github.com/rcourtman/pulse-go-rewrite/internal/ssh/knownhosts" "github.com/rs/zerolog" ) @@ -375,8 +384,8 @@ func TestIsProxmoxHost(t *testing.T) { } func TestPeerCapabilitiesFromContext_Nil(t *testing.T) { - if caps := peerCapabilitiesFromContext(nil); caps != 0 { - t.Errorf("expected 0 caps for nil context, got %v", caps) + if caps := peerCapabilitiesFromContext(context.TODO()); caps != 0 { + t.Errorf("expected 0 caps for nil (TODO) context, got %v", caps) } if caps := peerCapabilitiesFromContext(context.Background()); caps != 0 { t.Errorf("expected 0 caps for empty context, got %v", caps) @@ -510,3 +519,1024 @@ func TestHandleRegisterNodesV2(t *testing.T) { t.Errorf("handleRegisterNodesV2 failed: %v", err) } } + +func TestResolveUserSpec(t *testing.T) { + // Only test lookup failure for non-existent user on Linux + if _, err := resolveUserSpec("non-existent-user-12345"); err == nil { + t.Error("expected error for non-existent user") + } + + // We can't easily test success without knowing a valid user on the system + // But we can test the fallback if we mock /etc/passwd or similar, + // however resolveUserSpec reads directly from system or file. + // Let's create a temporary /etc/passwd file and use it? + // lookupUserFromPasswd uses a hardcoded /etc/passwd. + // So we can only test the failure case reliably across all envs. +} + +func TestDropPrivileges(t *testing.T) { + // Should return nil if username is empty + if spec, err := dropPrivileges(""); err != nil || spec != nil { + t.Errorf("expected nil spec and nil error for empty username, got %v, %v", spec, err) + } + + // Should return nil if not root (assuming test is running as non-root) + if os.Geteuid() != 0 { + if spec, err := dropPrivileges("root"); err != nil || spec != nil { + t.Errorf("expected nil spec and nil error when not root, got %v, %v", spec, err) + } + } + // If running as root, this test might behave differently, but usually tests run as non-root. +} + +func TestProxyStartStop(t *testing.T) { + // Create temp dirs + socketDir := t.TempDir() + sshDir := t.TempDir() + socketPath := filepath.Join(socketDir, "test.sock") + + // Ensure ssh keys exist so Start doesn't try to run ssh-keygen (which might be missing or fail) + os.WriteFile(filepath.Join(sshDir, "id_ed25519"), []byte("priv"), 0600) + os.WriteFile(filepath.Join(sshDir, "id_ed25519.pub"), []byte("pub"), 0644) + os.WriteFile(filepath.Join(sshDir, "known_hosts"), []byte(""), 0644) + + metrics := NewProxyMetrics("test") + p := &Proxy{ + socketPath: socketPath, + sshKeyPath: sshDir, + metrics: metrics, + } + + // Initialize known hosts + km, err := knownhosts.NewManager(filepath.Join(sshDir, "known_hosts")) + if err != nil { + t.Fatal(err) + } + p.knownHosts = km + + if err := p.Start(); err != nil { + t.Fatalf("Start failed: %v", err) + } + + // Verify socket exists + if _, err := os.Stat(socketPath); err != nil { + t.Errorf("Socket file not created") + } + + // Stop + p.Stop() + + // Verify socket removed + if _, err := os.Stat(socketPath); !os.IsNotExist(err) { + t.Errorf("Socket file not removed") + } +} + +func TestVersionCmd(t *testing.T) { + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + versionCmd.Run(versionCmd, []string{}) + + w.Close() + os.Stdout = oldStdout + + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + if !strings.Contains(output, "pulse-sensor-proxy") { + t.Errorf("expected version output to contain 'pulse-sensor-proxy', got %q", output) + } +} + +func TestKeysCmd(t *testing.T) { + // Mock SSH dir env + tmpDir := t.TempDir() + os.Setenv("PULSE_SENSOR_PROXY_SSH_DIR", tmpDir) + defer os.Unsetenv("PULSE_SENSOR_PROXY_SSH_DIR") + + // Write dummy key + pubKeyPath := filepath.Join(tmpDir, "id_ed25519.pub") + os.WriteFile(pubKeyPath, []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKqy"), 0644) + + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + keysCmd.Run(keysCmd, []string{}) + + w.Close() + os.Stdout = oldStdout + + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + if !strings.Contains(output, "Proxy Public Key:") { + t.Errorf("expected keys output to contain 'Proxy Public Key:', got %q", output) + } +} + +func TestHandleConnection(t *testing.T) { + // Reduce timeouts for testing + p := &Proxy{ + readTimeout: 1 * time.Second, + writeTimeout: 1 * time.Second, + metrics: NewProxyMetrics("test"), + router: map[string]handlerFunc{ + RPCGetStatus: func(ctx context.Context, req *RPCRequest, logger zerolog.Logger) (interface{}, error) { + return map[string]string{"status": "ok"}, nil + }, + }, + allowedPeerUIDs: map[uint32]struct{}{1000: {}}, + peerCapabilities: map[uint32]Capability{ + 1000: capabilityLegacyAll, + }, + } + + // Mock extractPeerCredentials + origExtract := extractPeerCredentials + defer func() { extractPeerCredentials = origExtract }() + extractPeerCredentials = func(conn net.Conn) (*peerCredentials, error) { + return &peerCredentials{uid: 1000, gid: 1000, pid: 123}, nil + } + + p.rateLimiter = newRateLimiter(p.metrics, &RateLimitConfig{}, nil, nil) + + client, server := net.Pipe() + defer client.Close() + + go p.handleConnection(server) + + // Send request + req := RPCRequest{ + Method: RPCGetStatus, + CorrelationID: "123", + } + bytes, _ := json.Marshal(req) + client.Write(bytes) + client.Write([]byte("\n")) + + // Read response + decoder := json.NewDecoder(client) + var resp RPCResponse + err := decoder.Decode(&resp) + if err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if !resp.Success { + t.Errorf("expected success, got error: %s", resp.Error) + } +} + +func TestGetTemperatureViaSSH_Success(t *testing.T) { + // Mock exec commands + origExec := execCommandFunc + origExecCtx := execCommandContextFunc + defer func() { + execCommandFunc = origExec + execCommandContextFunc = origExecCtx + }() + + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + // Mock hostname -f + if name == "hostname" && len(arg) > 0 && arg[0] == "-f" { + return errorExecCommand("unexpected command") + } + // Mock pvecm status + if name == "pvecm" { + return errorExecCommand("pvecm not found") + } + // Mock ip addr show + if name == "ip" { + return mockExecCommand("127.0.0.1/8") + } + return errorExecCommand("unexpected command: " + name) + } + + execCommandContextFunc = func(ctx context.Context, name string, arg ...string) *exec.Cmd { + args := strings.Join(arg, " ") + if name == "sh" && strings.Contains(args, "ssh") { + // Mock SSH successful output + // We need to output valid JSON for sensors -j + jsonOutput := `{"coretemp-isa-0000":{"Package id 0":{"temp1_input": 42.0}}}` + return mockExecCommand(jsonOutput) + } + return errorExecCommand("unexpected command: " + name) + } + + // Mock ssh key paths + sshDir := t.TempDir() + os.WriteFile(filepath.Join(sshDir, "id_ed25519"), []byte("priv"), 0600) + os.WriteFile(filepath.Join(sshDir, "id_ed25519.pub"), []byte("pub"), 0644) + os.WriteFile(filepath.Join(sshDir, "known_hosts"), []byte(""), 0644) + + // Mock keyscan to avoid calling real ssh-keyscan + mockKeyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) { + return []byte(fmt.Sprintf("%s ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKqy\n", host)), nil + } + + km, err := knownhosts.NewManager(filepath.Join(sshDir, "known_hosts"), knownhosts.WithKeyscanFunc(mockKeyscan)) + if err != nil { + t.Fatalf("failed to create knownhosts manager: %v", err) + } + p := &Proxy{ + sshKeyPath: sshDir, + knownHosts: km, + metrics: NewProxyMetrics("test"), + maxSSHOutputBytes: 1024, + config: &Config{}, // Initialize config to avoid panic + } + + // This should succeed because we mocked SSH output + temp, err := p.getTemperatureViaSSH(context.Background(), "node1") + if err != nil { + t.Fatalf("getTemperatureViaSSH failed: %v", err) + } + + if !strings.Contains(temp, "42.0") { + t.Errorf("expected temp 42.0, got %s", temp) + } +} + +// Helper for mocking exec.Command +func mockExecCommand(output string) *exec.Cmd { + cs := []string{"-test.run=TestHelperProcess", "--", output} + cmd := exec.Command(os.Args[0], cs...) + cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} + return cmd +} + +func errorExecCommand(msg string) *exec.Cmd { + cs := []string{"-test.run=TestHelperProcess", "--", "ERROR:" + msg} + cmd := exec.Command(os.Args[0], cs...) + cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} + return cmd +} + +// TestHelperProcess isn't a real test. It's used to mock exec.Command +func TestHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + args := os.Args + for len(args) > 0 { + if args[0] == "--" { + args = args[1:] + break + } + args = args[1:] + } + if len(args) == 0 { + os.Exit(0) + } + + output := args[0] + if strings.HasPrefix(output, "ERROR:") { + fmt.Fprint(os.Stderr, output[6:]) + os.Exit(1) + } + + fmt.Print(output) + os.Exit(0) +} + +func TestHandleConnection_Errors(t *testing.T) { + // Setup proxy with mocks + p := &Proxy{ + readTimeout: 1 * time.Second, + writeTimeout: 1 * time.Second, + metrics: NewProxyMetrics("test"), + router: map[string]handlerFunc{ + RPCGetStatus: func(ctx context.Context, req *RPCRequest, logger zerolog.Logger) (interface{}, error) { + return map[string]string{"status": "ok"}, nil + }, + }, + allowedPeerUIDs: map[uint32]struct{}{1000: {}}, + peerCapabilities: map[uint32]Capability{ + 1000: capabilityLegacyAll, + }, + config: &Config{}, + } + p.rateLimiter = newRateLimiter(p.metrics, &RateLimitConfig{}, nil, nil) + + // Test 1: Credential Extraction Failure + t.Run("CredFailure", func(t *testing.T) { + origExtract := extractPeerCredentials + defer func() { extractPeerCredentials = origExtract }() + extractPeerCredentials = func(conn net.Conn) (*peerCredentials, error) { + return nil, errors.New("extract failed") + } + + client, server := net.Pipe() + defer client.Close() + go p.handleConnection(server) + + var resp RPCResponse + json.NewDecoder(client).Decode(&resp) + if resp.Success || resp.Error != "unauthorized" { + t.Errorf("expected unauthorized error, got success=%v err=%s", resp.Success, resp.Error) + } + }) + + // Test 2: Unauthorized Peer + t.Run("UnauthorizedPeer", func(t *testing.T) { + origExtract := extractPeerCredentials + defer func() { extractPeerCredentials = origExtract }() + extractPeerCredentials = func(conn net.Conn) (*peerCredentials, error) { + return &peerCredentials{uid: 9999, gid: 9999}, nil // Unknown UID + } + + client, server := net.Pipe() + defer client.Close() + go p.handleConnection(server) + + var resp RPCResponse + json.NewDecoder(client).Decode(&resp) + if resp.Success || resp.Error != "unauthorized" { + t.Errorf("expected unauthorized error, got success=%v err=%s", resp.Success, resp.Error) + } + }) + + // Test 3: Invalid JSON + t.Run("InvalidJSON", func(t *testing.T) { + origExtract := extractPeerCredentials + defer func() { extractPeerCredentials = origExtract }() + extractPeerCredentials = func(conn net.Conn) (*peerCredentials, error) { + return &peerCredentials{uid: 1000, gid: 1000}, nil + } + + client, server := net.Pipe() + defer client.Close() + go p.handleConnection(server) + + client.Write([]byte("invalid-json\n")) + + var resp RPCResponse + json.NewDecoder(client).Decode(&resp) + if resp.Success || resp.Error != "invalid request format" { + t.Errorf("expected invalid format error, got %s", resp.Error) + } + }) + + // Test 4: Unknown Method + t.Run("UnknownMethod", func(t *testing.T) { + origExtract := extractPeerCredentials + defer func() { extractPeerCredentials = origExtract }() + extractPeerCredentials = func(conn net.Conn) (*peerCredentials, error) { + return &peerCredentials{uid: 1000, gid: 1000}, nil + } + + client, server := net.Pipe() + defer client.Close() + go p.handleConnection(server) + + req := RPCRequest{Method: "unknown_method"} + bytes, _ := json.Marshal(req) + client.Write(bytes) + client.Write([]byte("\n")) + + var resp RPCResponse + json.NewDecoder(client).Decode(&resp) + if resp.Success || resp.Error != "unknown method" { + t.Errorf("expected unknown method error, got %s", resp.Error) + } + }) + + // Test 5: Empty Request + t.Run("EmptyRequest", func(t *testing.T) { + origExtract := extractPeerCredentials + defer func() { extractPeerCredentials = origExtract }() + extractPeerCredentials = func(conn net.Conn) (*peerCredentials, error) { + return &peerCredentials{uid: 1000, gid: 1000}, nil + } + + client, server := net.Pipe() + defer client.Close() + go p.handleConnection(server) + + client.Write([]byte("\n")) // Just newline + + var resp RPCResponse + json.NewDecoder(client).Decode(&resp) + if resp.Success || resp.Error != "empty request" { + t.Errorf("expected empty request error, got %s", resp.Error) + } + }) +} + +func TestGetTemperatureViaSSH_Failures(t *testing.T) { + // Mock exec commands + origExec := execCommandFunc + origExecCtx := execCommandContextFunc + defer func() { + execCommandFunc = origExec + execCommandContextFunc = origExecCtx + }() + + // Test command failure + execCommandContextFunc = func(ctx context.Context, name string, arg ...string) *exec.Cmd { + args := strings.Join(arg, " ") + if name == "sh" && strings.Contains(args, "ssh") { + return errorExecCommand("ssh failed") + } + if name == "sensors" { + return errorExecCommand("sensors failed") + } + return mockExecCommand("") + } + + // Mock ssh key paths + sshDir := t.TempDir() + os.WriteFile(filepath.Join(sshDir, "id_ed25519"), []byte("priv"), 0600) + os.WriteFile(filepath.Join(sshDir, "known_hosts"), []byte(""), 0644) + + mockKeyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) { + return []byte(fmt.Sprintf("%s ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKqy\n", host)), nil + } + km, _ := knownhosts.NewManager(filepath.Join(sshDir, "known_hosts"), knownhosts.WithKeyscanFunc(mockKeyscan)) + + p := &Proxy{ + sshKeyPath: sshDir, + knownHosts: km, + metrics: NewProxyMetrics("test"), + maxSSHOutputBytes: 1024, + config: &Config{}, + } + + // Test SSH failure + _, err := p.getTemperatureViaSSH(context.Background(), "remote-node") + if err == nil { + t.Error("expected error for ssh failure") + } + + // Test Local sensors failure + execCommandContextFunc = func(ctx context.Context, name string, arg ...string) *exec.Cmd { + if name == "sensors" { + return errorExecCommand("sensors failed") + } + return mockExecCommand("") + } + // isLocalNode depends on os.Hostname or netInterfaces + // mock osHostname to match node name + origHostname := osHostname + defer func() { osHostname = origHostname }() + osHostname = func() (string, error) { return "local-node", nil } + + // Should fallback to local sensors and fail + // We need to allow ensureHostKey to succeed or skip it for local node? + // ensureHostKey is called even for local node currently. + // Assume ensureHostKey succeeds due to mockKeyscan. + + _, err = p.getTemperatureViaSSH(context.Background(), "local-node") + // Since "sensors -j" fails, and fallback "sensors" fails, it returns "{}". + // Wait, getTemperatureViaSSH returns "", error if SSH fails. + // Ah, if localNode is true, and SSH fails, it attempts fallback. + // And if fallback succeeds (returns non-empty), it returns string. + // If fallback fails, it returns the SSH error. + + // If sensors fails, getTemperatureLocal returns "{}", nil. + // So it returns "{}", nil if SSH fails and local fallback runs? + // Let's check getTemperatureLocal code. + // If "sensors -j" fails: cmd.Output() returns error. + // It tries "sensors". If that fails, it returns error. + // So getTemperatureLocal returns error. + // If getTemperatureLocal returns error, getTemperatureViaSSH checks `localErr == nil`. + // So if localErr != nil, it falls through to return SSH error. + + if err == nil { + t.Errorf("expected error when both SSH and local sensors fail") + } +} + +func TestFetchAuthorizedNodes(t *testing.T) { + // Mock HTTP server + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/temperature-proxy/authorized-nodes" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("X-Proxy-Token") != "test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + json.NewEncoder(w).Encode(map[string]interface{}{ + "nodes": []map[string]string{ + {"name": "node1", "ip": "10.0.0.1"}, + }, + "hash": "abc", + "refresh_interval": 60, + }) + })) + defer ts.Close() + + // Use temp dir for config update + tmpFile, err := os.CreateTemp("", "config.yaml") + if err != nil { + t.Fatal(err) + } + tmpFile.Close() // We just need the path, but empty file is fine as we are mocking writing usually? + // Actually fetchAuthorizedNodes updates p.config.AllowedNodes. + // No file writing happens in fetchAuthorizedNodes. + + p := &Proxy{ + controlPlaneCfg: &ControlPlaneConfig{ + URL: ts.URL, + }, + controlPlaneToken: "test-token", + config: &Config{ + AllowedNodes: []string{}, + }, + metrics: NewProxyMetrics("test"), + } + + // Pre-req: nodeValidator must be set to update allowlist + // But p.nodeValidator is private. + // We can set it if we move it to be accessible or use a constructor. + // However, fetchAuthorizedNodes calls p.nodeValidator.UpdateAllowlist if set. + // Let's rely on it being nil or check if we can set it. + // Oh, nodeValidator is a field in Proxy. + // We can just set it. + + // Create a dummy nodeValidator + validator, _ := newNodeValidator(&Config{}, p.metrics) + p.nodeValidator = validator + + client := ts.Client() + if err := p.fetchAuthorizedNodes(client); err != nil { + t.Fatalf("fetchAuthorizedNodes failed: %v", err) + } + + if len(p.config.AllowedNodes) != 1 || p.config.AllowedNodes[0] != "10.0.0.1" { + t.Errorf("expected allowed nodes [10.0.0.1], got %v", p.config.AllowedNodes) + } +} + +func TestDropPrivileges_Root(t *testing.T) { + // Mock osGeteuid to return 0 (root) + origGeteuid := osGeteuid + origSetgroups := unixSetgroups + origSetgid := unixSetgid + origSetuid := unixSetuid + + defer func() { + osGeteuid = origGeteuid + unixSetgroups = origSetgroups + unixSetgid = origSetgid + unixSetuid = origSetuid + }() + + osGeteuid = func() int { return 0 } + unixSetgroups = func(gids []int) error { return nil } + unixSetgid = func(gid int) error { return nil } + unixSetuid = func(uid int) error { return nil } + + mockSpec := &userSpec{ + name: "testuser", + uid: 1001, + gid: 1001, + groups: []int{1001}, + home: "/home/testuser", + } + + origResolve := resolveUserSpecFunc + defer func() { resolveUserSpecFunc = origResolve }() + resolveUserSpecFunc = func(username string) (*userSpec, error) { + if username == "testuser" { + return mockSpec, nil + } + return nil, errors.New("user not found") + } + + // Test successful drop + spec, err := dropPrivileges("testuser") + if err != nil { + t.Fatalf("dropPrivileges failed: %v", err) + } + if spec.uid != 1001 { + t.Errorf("expected uid 1001, got %d", spec.uid) + } + + // Test syscall failure + unixSetuid = func(uid int) error { return errors.New("setuid failed") } + _, err = dropPrivileges("testuser") + if err == nil { + t.Error("expected error for setuid failure") + } +} + +func TestLookupUserFromPasswd(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "passwd") + content := "root:x:0:0:root:/root:/bin/bash\ntestuser:x:1001:1001:Test User:/home/testuser:/bin/sh\n" + if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + origPath := passwdPath + defer func() { passwdPath = origPath }() + passwdPath = tmpFile + + // Test success + spec, err := lookupUserFromPasswd("testuser") + if err != nil { + t.Fatalf("lookupUserFromPasswd failed: %v", err) + } + if spec.uid != 1001 || spec.gid != 1001 || spec.home != "/home/testuser" { + t.Errorf("unexpected spec: %+v", spec) + } + + // Test not found + _, err = lookupUserFromPasswd("nonexistent") + if err == nil { + t.Error("expected error for nonexistent user") + } + + // Test malformed line + os.WriteFile(tmpFile, []byte("malformed\n"), 0644) + _, err = lookupUserFromPasswd("testuser") + if err == nil { + t.Error("expected error/fail for malformed file") + } +} + +func TestStartControlPlaneSync(t *testing.T) { + // Create a token file + tokenFile := filepath.Join(t.TempDir(), "token") + os.WriteFile(tokenFile, []byte("my-token"), 0600) + + p := &Proxy{ + controlPlaneCfg: &ControlPlaneConfig{ + URL: "http://example.com", + TokenFile: tokenFile, + }, + } + + // We want to verify it starts the loop. + // startControlPlaneSync calls `go p.controlPlaneLoop(ctx)` + // We can't inspect the goroutine, but we can check if `controlPlaneCancel` is set. + + if p.controlPlaneCancel != nil { + t.Error("expected cancel to be nil initially") + } + + p.startControlPlaneSync() + + if p.controlPlaneCancel == nil { + t.Error("expected cancel to be set after start") + } + // Clean up + if p.controlPlaneCancel != nil { + p.controlPlaneCancel() + } +} + +func TestResolveUserSpec_Fallback(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "passwd") + // Use a user name that is unlikely to exist on the system to force fallback + fallbackUser := "fallbackuser_9999" + content := fmt.Sprintf("%s:x:2000:2000:Fallback User:/home/%s:/bin/sh\n", fallbackUser, fallbackUser) + if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + origPath := passwdPath + defer func() { passwdPath = origPath }() + passwdPath = tmpFile + + spec, err := resolveUserSpec(fallbackUser) + if err != nil { + t.Fatalf("resolveUserSpec failed: %v", err) + } + if spec.name != fallbackUser || spec.uid != 2000 { + t.Errorf("expected spec for %s (uid 2000), got %+v", fallbackUser, spec) + } +} + +func TestProxyStart(t *testing.T) { + origListen := netListen + origExec := execCommandFunc + defer func() { + netListen = origListen + execCommandFunc = origExec + }() + + // Mock successful listener + mockListener := &mockListener{ + addr: &net.UnixAddr{Name: "socket", Net: "unix"}, + closed: make(chan struct{}), + } + netListen = func(network, address string) (net.Listener, error) { + return mockListener, nil + } + + // Mock ssh-keygen success + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + return mockExecCommand("") + } + + tmpDir := t.TempDir() + p := &Proxy{ + sshKeyPath: filepath.Join(tmpDir, "ssh"), + socketPath: filepath.Join(tmpDir, "socket"), + metrics: NewProxyMetrics("test"), + } + + if err := p.Start(); err != nil { + t.Fatalf("Start failed: %v", err) + } + + if p.listener == nil { + t.Error("expected listener to be set") + } + + // Test Listen failure + netListen = func(network, address string) (net.Listener, error) { + return nil, errors.New("listen failed") + } + if err := p.Start(); err == nil { + t.Error("expected error for listen failure") + } +} + +func TestSSHConnection(t *testing.T) { + origExec := execCommandFunc + defer func() { execCommandFunc = origExec }() + + // Mock ssh success + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + args := strings.Join(arg, " ") + if name == "sh" && strings.Contains(args, "ssh") { + return mockExecCommand("") + } + return errorExecCommand("unexpected") + } + + tmpDir := t.TempDir() + km, _ := knownhosts.NewManager(filepath.Join(tmpDir, "known_hosts"), knownhosts.WithKeyscanFunc(func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) { + return []byte("host ssh-ed25519 KEY"), nil + })) + + p := &Proxy{ + sshKeyPath: tmpDir, + knownHosts: km, + metrics: NewProxyMetrics("test"), + maxSSHOutputBytes: 1024, + config: &Config{}, // Initialize config + } + + // Ensure dummy key exists + os.WriteFile(filepath.Join(tmpDir, "id_ed25519"), []byte("priv"), 0600) + + if err := p.testSSHConnection("host"); err != nil { + t.Errorf("testSSHConnection failed: %v", err) + } + + // Test failure + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + args := strings.Join(arg, " ") + if name == "sh" && strings.Contains(args, "ssh") { + return errorExecCommand("ssh failed") + } + return mockExecCommand("") + } + + if err := p.testSSHConnection("host"); err == nil { + t.Error("expected error for ssh failure") + } +} + +type mockListener struct { + addr net.Addr + closed chan struct{} +} + +func (m *mockListener) Accept() (net.Conn, error) { + if m.closed != nil { + <-m.closed + } else { + select {} // Block forever if nil + } + return nil, errors.New("listener closed") +} +func (m *mockListener) Close() error { + if m.closed != nil { + select { + case <-m.closed: + default: + close(m.closed) + } + } + return nil +} +func (m *mockListener) Addr() net.Addr { return m.addr } + +func TestEnsureHostKeyFromProxmox(t *testing.T) { + origExec := execCommandFunc + origPath := proxmoxClusterKnownHostsPath + origLookPath := execLookPath + defer func() { + execCommandFunc = origExec + proxmoxClusterKnownHostsPath = origPath + execLookPath = origLookPath + }() + + // Mock isProxmoxHost -> true + execLookPath = func(file string) (string, error) { + if file == "pvecm" { + return "/usr/sbin/pvecm", nil + } + return "", fmt.Errorf("not found") + } + + // Create a dummy known_hosts file simulating /etc/pve/priv/known_hosts + tmpDir := t.TempDir() + knownHostsFile := filepath.Join(tmpDir, "pve_known_hosts") + // Format: host key + if err := os.WriteFile(knownHostsFile, []byte("node1 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKqy\n"), 0644); err != nil { + t.Fatal(err) + } + proxmoxClusterKnownHostsPath = knownHostsFile + + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + return mockExecCommand("ok") + } + + proxy := &Proxy{ + knownHosts: &mockKnownHostsManager{}, + metrics: NewProxyMetrics("test"), + } + + // Test success + if err := proxy.ensureHostKeyFromProxmox(context.Background(), "node1"); err != nil { + t.Errorf("ensureHostKeyFromProxmox failed: %v", err) + } + + // Test failure - not proxmox + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + return errorExecCommand("fail") + } + if err := proxy.ensureHostKeyFromProxmox(context.Background(), "node1"); err == nil || err.Error() != "not running on Proxmox host" { + t.Errorf("expected not running on Proxmox host error, got %v", err) + } +} + +type mockKnownHostsManager struct { + ensureErr error +} + +func (m *mockKnownHostsManager) Ensure(ctx context.Context, host string) error { return m.ensureErr } +func (m *mockKnownHostsManager) EnsureWithPort(ctx context.Context, host string, port int) error { + return m.ensureErr +} +func (m *mockKnownHostsManager) EnsureWithEntries(ctx context.Context, host string, port int, entries [][]byte) error { + return m.ensureErr +} +func (m *mockKnownHostsManager) Path() string { return "" } + +func TestPushSSHKey(t *testing.T) { + origExec := execCommandFunc + defer func() { execCommandFunc = origExec }() + + tmpDir := t.TempDir() + proxy := &Proxy{ + sshKeyPath: tmpDir, + config: &Config{ + AllowedSourceSubnets: []string{"192.168.1.0/24"}, + }, + knownHosts: &mockKnownHostsManager{}, + metrics: NewProxyMetrics("test"), + } + os.WriteFile(filepath.Join(tmpDir, "id_ed25519.pub"), []byte("pubkey"), 0644) + + // Mock successful copy + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + // Mock ssh-copy-id or manual command + return mockExecCommand("") + } + + if err := proxy.pushSSHKeyFrom("remote-node", tmpDir); err != nil { + t.Errorf("pushSSHKeyFrom failed: %v", err) + } + + // Test failure + execCommandFunc = func(name string, arg ...string) *exec.Cmd { + return errorExecCommand("fail") + } + if err := proxy.pushSSHKeyFrom("remote-node", tmpDir); err == nil { + t.Error("expected error for failure") + } +} + +func TestLoadProxmoxHostKeys(t *testing.T) { + tmpDir := t.TempDir() + knownHostsFile := filepath.Join(tmpDir, "pve_known_hosts") + // node1 matches. node2 does not match requested host. + content := "node1 ssh-ed25519 AAAKEY1\nnode2 ssh-ed25519 AAAKEY2\n# comment\n" + if err := os.WriteFile(knownHostsFile, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + origPath := proxmoxClusterKnownHostsPath + defer func() { proxmoxClusterKnownHostsPath = origPath }() + proxmoxClusterKnownHostsPath = knownHostsFile + + entries, err := loadProxmoxHostKeys("node1") + if err != nil { + t.Fatalf("loadProxmoxHostKeys failed: %v", err) + } + if len(entries) != 1 { + t.Errorf("expected 1 entry, got %d", len(entries)) + } + + // Test file open error + proxmoxClusterKnownHostsPath = filepath.Join(tmpDir, "nonexistent") + _, err = loadProxmoxHostKeys("node1") + if err == nil { + t.Error("expected error for missing file") + } +} + +func TestHandleHostKeyEnsureError(t *testing.T) { + proxy := &Proxy{ + metrics: NewProxyMetrics("test"), + } + + // Test generic error + err := errors.New("generic error") + if ret := proxy.handleHostKeyEnsureError("node1", err); ret != err { + t.Errorf("expected same error, got %v", ret) + } + + // Test HostKeyChangeError + // We need to construct a HostKeyChangeError. It's from knownhosts package. + // But knownhosts.HostKeyChangeError might be exported. + // If it's not easy to construct, we might need a mock knownhosts manager to return it. + // Let's assume we can mock the error type check or just ensure code path is covered if possible. + // Without referencing the internal/ssh/knownhosts type directly if it's internal? + // It is exported `github.com/rcourtman/pulse-go-rewrite/internal/ssh/knownhosts`. + // Since I imported it as `knownhosts`, I can use it. + + changeErr := &knownhosts.HostKeyChangeError{ + Host: "node1", + Existing: "ssh-ed25519 AAA...", + Provided: "ssh-ed25519 BBB...", + } + + // This should log and record metric, then return the error. + if ret := proxy.handleHostKeyEnsureError("node1", changeErr); ret != changeErr { + t.Errorf("expected same error, got %v", ret) + } +} + +func TestDefaultExtractPeerCredentials(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Skipping peer creds test on non-linux") + } + + tmpDir := t.TempDir() + socketPath := filepath.Join(tmpDir, "test.sock") + + l, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + + creds, err := defaultExtractPeerCredentials(conn) + if err != nil { + t.Errorf("defaultExtractPeerCredentials failed: %v", err) + return + } + if creds.uid != uint32(os.Geteuid()) { + t.Errorf("expected uid %d, got %d", os.Geteuid(), creds.uid) + } + if creds.gid != uint32(os.Getgid()) { + t.Errorf("expected gid %d, got %d", creds.gid, os.Getgid()) + } + if creds.pid <= 0 { + t.Errorf("expected value pid > 0, got %d", creds.pid) + } + }() + + conn, err := net.Dial("unix", socketPath) + if err != nil { + t.Fatal(err) + } + conn.Close() + <-done +} diff --git a/cmd/pulse-sensor-proxy/peer_creds_linux.go b/cmd/pulse-sensor-proxy/peer_creds_linux.go index 018d23b49..c94c1f422 100644 --- a/cmd/pulse-sensor-proxy/peer_creds_linux.go +++ b/cmd/pulse-sensor-proxy/peer_creds_linux.go @@ -10,8 +10,8 @@ import ( "github.com/rs/zerolog/log" ) -// extractPeerCredentials extracts peer credentials via SO_PEERCRED -func extractPeerCredentials(conn net.Conn) (*peerCredentials, error) { +// defaultExtractPeerCredentials extracts peer credentials via SO_PEERCRED +func defaultExtractPeerCredentials(conn net.Conn) (*peerCredentials, error) { unixConn, ok := conn.(*net.UnixConn) if !ok { return nil, fmt.Errorf("not a unix connection") diff --git a/cmd/pulse-sensor-proxy/peer_creds_stub.go b/cmd/pulse-sensor-proxy/peer_creds_stub.go index c651ba2f0..9cacf86b2 100644 --- a/cmd/pulse-sensor-proxy/peer_creds_stub.go +++ b/cmd/pulse-sensor-proxy/peer_creds_stub.go @@ -9,8 +9,8 @@ import ( "github.com/rs/zerolog/log" ) -// extractPeerCredentials is a stub for non-Linux systems -func extractPeerCredentials(conn net.Conn) (*peerCredentials, error) { +// defaultExtractPeerCredentials is a stub for non-Linux systems +func defaultExtractPeerCredentials(conn net.Conn) (*peerCredentials, error) { // On non-Linux systems (like macOS dev), we can't easily get the peer credentials // from the socket. For development purposes, we'll assume the connection // comes from the current user. diff --git a/cmd/pulse-sensor-proxy/ssh.go b/cmd/pulse-sensor-proxy/ssh.go index b94104176..a6e2a4f6d 100644 --- a/cmd/pulse-sensor-proxy/ssh.go +++ b/cmd/pulse-sensor-proxy/ssh.go @@ -32,6 +32,9 @@ var osHostname = os.Hostname // Variable for testing to mock exec.Command (for simple output) var execCommandFunc = exec.Command +// Variable for testing to mock exec.CommandContext +var execCommandContextFunc = exec.CommandContext + const ( tempWrapperPath = "/usr/local/libexec/pulse-sensor-proxy/temp-wrapper.sh" tempWrapperScript = `#!/bin/sh @@ -73,17 +76,17 @@ exit 1 ` ) -const proxmoxClusterKnownHostsPath = "/etc/pve/priv/known_hosts" +var proxmoxClusterKnownHostsPath = "/etc/pve/priv/known_hosts" // execCommand executes a shell command and returns output func execCommand(cmd string) (string, error) { - out, err := exec.Command("sh", "-c", cmd).CombinedOutput() + out, err := execCommandFunc("sh", "-c", cmd).CombinedOutput() return string(out), err } // execCommandWithLimitsContext runs a shell command with output limits and context cancellation func execCommandWithLimitsContext(ctx context.Context, cmd string, stdoutLimit, stderrLimit int64) (string, string, bool, bool, error) { - command := exec.CommandContext(ctx, "sh", "-c", cmd) + command := execCommandContextFunc(ctx, "sh", "-c", cmd) stdoutPipe, err := command.StdoutPipe() if err != nil { @@ -150,7 +153,7 @@ func execCommandWithLimitsContext(ctx context.Context, cmd string, stdoutLimit, } func execCommandWithLimits(cmd string, stdoutLimit, stderrLimit int64) (string, string, bool, bool, error) { - command := exec.Command("sh", "-c", cmd) + command := execCommandFunc("sh", "-c", cmd) stdoutPipe, err := command.StdoutPipe() if err != nil { @@ -666,7 +669,7 @@ func discoverClusterNodes() ([]string, error) { } // Get cluster status with IP addresses - cmd := exec.Command("pvecm", "status") + cmd := execCommandFunc("pvecm", "status") var out, stderr bytes.Buffer cmd.Stdout = &out cmd.Stderr = &stderr @@ -777,7 +780,7 @@ func discoverLocalHostAddresses() ([]string, error) { addresses[strings.ToLower(hostname)] = struct{}{} // Try to get FQDN - cmd := exec.Command("hostname", "-f") + cmd := execCommandFunc("hostname", "-f") if out, err := cmd.Output(); err == nil { fqdn := strings.TrimSpace(string(out)) if fqdn != "" && fqdn != hostname { @@ -885,7 +888,7 @@ func discoverLocalHostAddressesFallback() ([]string, error) { // Get hostname and FQDN (same as native version) if hostname, err := os.Hostname(); err == nil && hostname != "" { addresses[strings.ToLower(hostname)] = struct{}{} - cmd := exec.Command("hostname", "-f") + cmd := execCommandFunc("hostname", "-f") if out, err := cmd.Output(); err == nil { fqdn := strings.TrimSpace(string(out)) if fqdn != "" && fqdn != hostname { @@ -895,7 +898,7 @@ func discoverLocalHostAddressesFallback() ([]string, error) { } // Use 'ip addr' to get IP addresses - cmd := exec.Command("ip", "addr", "show") + cmd := execCommandFunc("ip", "addr", "show") out, err := cmd.Output() if err != nil { log.Warn().Err(err).Msg("Failed to run 'ip addr' command") @@ -1029,11 +1032,11 @@ func isLocalNode(nodeHost string) bool { // getTemperatureLocal collects temperature data from the local machine func (p *Proxy) getTemperatureLocal(ctx context.Context) (string, error) { // Run the same command that the wrapper script runs with context timeout - cmd := exec.CommandContext(ctx, "sensors", "-j") + cmd := execCommandContextFunc(ctx, "sensors", "-j") output, err := cmd.Output() if err != nil { // Try without -j flag as fallback - cmd = exec.CommandContext(ctx, "sensors") + cmd = execCommandContextFunc(ctx, "sensors") if _, err = cmd.Output(); err != nil { return "", fmt.Errorf("failed to run sensors: %w", err) } diff --git a/cmd/pulse-sensor-proxy/validation_test.go b/cmd/pulse-sensor-proxy/validation_test.go index 93880c11a..97538040b 100644 --- a/cmd/pulse-sensor-proxy/validation_test.go +++ b/cmd/pulse-sensor-proxy/validation_test.go @@ -530,8 +530,8 @@ func TestDefaultHostResolver(t *testing.T) { t.Error("expected at least one IP for localhost") } - // Test with nil context - _, _ = r.LookupIP(nil, "localhost") + // Test with nil (TODO) context + _, _ = r.LookupIP(context.TODO(), "localhost") // Test with invalid host _, err = r.LookupIP(context.Background(), "invalid.host.local.test")