// Files: Pulse/internal/agentexec/server_coverage_test.go
// 2025-12-29 17:25:21 +00:00
//
// 538 lines
// 12 KiB
// Go

package agentexec
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
)
// noHijackResponseWriter is a minimal http.ResponseWriter test double. It
// implements only Header, Write and WriteHeader, so it does not satisfy
// http.Hijacker; handlers that need to take over the underlying connection
// cannot do so through it.
type noHijackResponseWriter struct {
	header http.Header
}

// Compile-time check that the stub keeps satisfying http.ResponseWriter.
var _ http.ResponseWriter = (*noHijackResponseWriter)(nil)

// Header returns the writer's header map. The map is allocated on first use
// so a zero-value noHijackResponseWriter can be handed to code that sets
// response headers without panicking on a nil map.
func (w *noHijackResponseWriter) Header() http.Header {
	if w.header == nil {
		w.header = make(http.Header)
	}
	return w.header
}

// Write discards p and reports it as fully written. Returning len(p) (rather
// than 0) keeps the stub within the io.Writer contract, which requires a
// non-nil error whenever fewer than len(p) bytes are reported.
func (w *noHijackResponseWriter) Write(p []byte) (int, error) {
	return len(p), nil
}

// WriteHeader ignores the status code; nothing is recorded.
func (w *noHijackResponseWriter) WriteHeader(int) {}
// newConnPair starts a throwaway httptest server whose handler upgrades the
// incoming request to a WebSocket, dials it, and returns both ends of the
// resulting connection: the server side first, then the client side. The
// third result closes both connections and the server.
//
// The test is failed immediately if the dial fails or if the server side of
// the connection does not show up within two seconds.
func newConnPair(t *testing.T) (*websocket.Conn, *websocket.Conn, func()) {
	t.Helper()

	// Capacity 1 lets the handler hand over its connection without waiting
	// for the receiver below.
	accepted := make(chan *websocket.Conn, 1)
	handler := func(w http.ResponseWriter, r *http.Request) {
		upgraded, upgradeErr := upgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade: %v", upgradeErr)
			return
		}
		accepted <- upgraded
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(srv.URL), nil)
	if dialErr != nil {
		srv.Close()
		t.Fatalf("Dial: %v", dialErr)
	}

	var server *websocket.Conn
	select {
	case server = <-accepted:
	case <-time.After(2 * time.Second):
		client.Close()
		srv.Close()
		t.Fatal("timed out waiting for server connection")
	}

	return server, client, func() {
		client.Close()
		server.Close()
		srv.Close()
	}
}
// TestHandleWebSocket_UpgradeFailureAndDeadlineErrors invokes HandleWebSocket
// with a noHijackResponseWriter, i.e. a ResponseWriter that cannot be
// hijacked. The test makes no assertions of its own: the call only has to
// return without panicking.
func TestHandleWebSocket_UpgradeFailureAndDeadlineErrors(t *testing.T) {
	server := NewServer(nil)
	request := httptest.NewRequest(http.MethodGet, "http://example/ws", nil)
	writer := &noHijackResponseWriter{header: make(http.Header)}

	server.HandleWebSocket(writer, request)
}
// TestHandleWebSocket_RegistrationReadError dials the WebSocket endpoint and
// hangs up straight away without sending any message. There are no
// assertions beyond the dial succeeding; judging by the name, the goal is to
// drive the server's "could not read registration" path.
func TestHandleWebSocket_RegistrationReadError(t *testing.T) {
	server := NewServer(nil)
	httpSrv := newWSServer(t, server)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}

	// Disconnect before a registration message is ever written.
	client.Close()
}
// TestHandleWebSocket_RegistrationMessageJSONError sends a truncated JSON
// document ("{") as the first frame and then expects the next client read to
// fail within 500ms, which the test interprets as the server having dropped
// the connection.
func TestHandleWebSocket_RegistrationMessageJSONError(t *testing.T) {
	server := NewServer(nil)
	httpSrv := newWSServer(t, server)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	writeErr := client.WriteMessage(websocket.TextMessage, []byte("{"))
	if writeErr != nil {
		t.Fatalf("WriteMessage: %v", writeErr)
	}

	// Any read error (including the deadline expiring) satisfies the check.
	client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected server to close on invalid JSON")
	}
}
// TestHandleWebSocket_RegistrationPayloadMarshalError replaces the
// package-level jsonMarshal hook with one that always fails, sends an
// otherwise well-formed registration message, and expects the next client
// read to fail within 500ms. The original hook is restored on cleanup.
func TestHandleWebSocket_RegistrationPayloadMarshalError(t *testing.T) {
	saved := jsonMarshal
	jsonMarshal = func(any) ([]byte, error) {
		return nil, errors.New("boom")
	}
	t.Cleanup(func() { jsonMarshal = saved })

	server := NewServer(nil)
	httpSrv := newWSServer(t, server)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	registration := Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Token:    "any",
		},
	}
	wsWriteMessage(t, client, registration)

	client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected server to close on marshal error")
	}
}
// TestHandleWebSocket_RegistrationPayloadUnmarshalError sends an
// "agent_register" message whose payload is a JSON string instead of an
// object, then expects the next client read to fail within 500ms.
func TestHandleWebSocket_RegistrationPayloadUnmarshalError(t *testing.T) {
	// Valid JSON envelope, but the payload has the wrong JSON type.
	const badRegistration = `{"type":"agent_register","payload":"oops"}`

	server := NewServer(nil)
	httpSrv := newWSServer(t, server)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	writeErr := client.WriteMessage(websocket.TextMessage, []byte(badRegistration))
	if writeErr != nil {
		t.Fatalf("WriteMessage: %v", writeErr)
	}

	client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected server to close on invalid payload")
	}
}
// TestHandleWebSocket_PongHandler registers agent "a1", sends an unsolicited
// pong control frame from the client, closes the client connection, and then
// waits up to two seconds for the server to report the agent as disconnected.
func TestHandleWebSocket_PongHandler(t *testing.T) {
	server := NewServer(nil)
	httpSrv := newWSServer(t, server)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	registration := Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Token:    "any",
		},
	}
	wsWriteMessage(t, client, registration)
	// The registration reply is consumed but its contents are not inspected.
	_ = wsReadRegisteredPayload(t, client)

	pongDeadline := time.Now().Add(time.Second)
	if pongErr := client.WriteControl(websocket.PongMessage, []byte("pong"), pongDeadline); pongErr != nil {
		t.Fatalf("WriteControl pong: %v", pongErr)
	}

	client.Close()
	agentGone := func() bool { return !server.IsAgentConnected("a1") }
	waitFor(t, 2*time.Second, agentGone)
}
// TestReadLoopDone calls readLoop synchronously for an agent whose done
// channel is already closed and checks that, once readLoop returns, agent
// "a1" is no longer reported as connected.
func TestReadLoopDone(t *testing.T) {
	server := NewServer(nil)
	srvConn, cliConn, cleanup := newConnPair(t)
	defer cleanup()

	// Close the done channel up front so readLoop starts with it signalled.
	finished := make(chan struct{})
	close(finished)
	agent := &agentConn{
		conn:  srvConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  finished,
	}

	server.mu.Lock()
	server.agents["a1"] = agent
	server.mu.Unlock()

	server.readLoop(agent)

	if stillThere := server.IsAgentConnected("a1"); stillThere {
		t.Fatalf("expected agent to be removed")
	}
	cliConn.Close()
}
// TestReadLoopUnexpectedCloseError runs readLoop in a goroutine, has the
// client send a close frame carrying websocket.CloseProtocolError, and
// requires readLoop to return within two seconds.
//
// The close frame is the only stimulus in this test, so a failure to send it
// is reported directly instead of surfacing two seconds later as a misleading
// "readLoop did not exit".
func TestReadLoopUnexpectedCloseError(t *testing.T) {
	s := NewServer(nil)
	serverConn, clientConn, cleanup := newConnPair(t)
	// cleanup closes both connections, which also unblocks the readLoop
	// goroutine if the test bails out early.
	defer cleanup()

	ac := &agentConn{
		conn:  serverConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	s.mu.Lock()
	s.agents["a1"] = ac
	s.mu.Unlock()

	done := make(chan struct{})
	go func() {
		defer close(done)
		s.readLoop(ac)
	}()

	closeFrame := websocket.FormatCloseMessage(websocket.CloseProtocolError, "bye")
	if err := clientConn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(time.Second)); err != nil {
		t.Fatalf("WriteControl close: %v", err)
	}

	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatalf("readLoop did not exit")
	}
}
// TestReadLoopCommandResultBranches feeds readLoop a sequence of frames that
// differ in how well-formed they are, closes the client, and requires
// readLoop to return within two seconds. Nothing about how each frame was
// handled is asserted.
func TestReadLoopCommandResultBranches(t *testing.T) {
	server := NewServer(nil)
	srvConn, cliConn, cleanup := newConnPair(t)
	defer cleanup()

	agent := &agentConn{
		conn:  srvConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	server.mu.Lock()
	server.agents["a1"] = agent
	// Unbuffered and never received from in this test.
	server.pendingReqs["req-full"] = make(chan CommandResultPayload)
	server.mu.Unlock()

	loopExited := make(chan struct{})
	go func() {
		server.readLoop(agent)
		close(loopExited)
	}()

	frames := []string{
		// Truncated JSON.
		"{",
		// request_id is a number rather than a string.
		`{"type":"command_result","payload":{"request_id":123}}`,
		// request_id that has a pending channel registered above.
		`{"type":"command_result","payload":{"request_id":"req-full","success":true}}`,
		// request_id with no pending channel registered.
		`{"type":"command_result","payload":{"request_id":"req-missing","success":true}}`,
	}
	for _, frame := range frames {
		// Write errors are intentionally not checked here.
		_ = cliConn.WriteMessage(websocket.TextMessage, []byte(frame))
	}
	cliConn.Close()

	select {
	case <-loopExited:
	case <-time.After(2 * time.Second):
		t.Fatalf("readLoop did not exit")
	}

	server.mu.Lock()
	delete(server.pendingReqs, "req-full")
	server.mu.Unlock()
}
// TestPingLoopSuccessAndStop shortens the package-level pingInterval to 5ms,
// lets pingLoop run for two intervals on a live connection, closes the stop
// channel, and requires pingLoop to return within two seconds. The original
// interval is restored on cleanup.
func TestPingLoopSuccessAndStop(t *testing.T) {
	savedInterval := pingInterval
	t.Cleanup(func() { pingInterval = savedInterval })
	pingInterval = 5 * time.Millisecond

	server := NewServer(nil)
	srvConn, _, cleanup := newConnPair(t)
	defer cleanup()

	agent := &agentConn{
		conn:  srvConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}

	stopCh := make(chan struct{})
	loopExited := make(chan struct{})
	go func() {
		server.pingLoop(agent, stopCh)
		close(loopExited)
	}()

	// Give the loop time for at least one tick before asking it to stop.
	time.Sleep(2 * pingInterval)
	close(stopCh)

	select {
	case <-loopExited:
	case <-time.After(2 * time.Second):
		t.Fatalf("pingLoop did not exit")
	}
}
// TestPingLoopFailuresClose shortens pingInterval to 5ms and starts pingLoop
// on a server connection that has already been closed. The stop channel is
// never closed, so the test passes only if pingLoop returns on its own within
// two seconds.
func TestPingLoopFailuresClose(t *testing.T) {
	savedInterval := pingInterval
	t.Cleanup(func() { pingInterval = savedInterval })
	pingInterval = 5 * time.Millisecond

	server := NewServer(nil)
	srvConn, _, cleanup := newConnPair(t)
	defer cleanup()

	agent := &agentConn{
		conn:  srvConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	// Close the connection before the loop ever gets to use it.
	srvConn.Close()

	stopCh := make(chan struct{})
	loopExited := make(chan struct{})
	go func() {
		server.pingLoop(agent, stopCh)
		close(loopExited)
	}()

	select {
	case <-loopExited:
	case <-time.After(2 * time.Second):
		t.Fatalf("pingLoop did not exit after failures")
	}
}
// TestSendMessageMarshalError calls sendMessage with a nil connection and a
// message whose payload is a channel value, and expects an error back.
// (encoding/json cannot encode channels.)
func TestSendMessageMarshalError(t *testing.T) {
	server := NewServer(nil)
	unencodable := Message{Payload: make(chan int)}

	err := server.sendMessage(nil, unencodable)
	if err == nil {
		t.Fatalf("expected marshal error")
	}
}
// TestExecuteCommandSendError registers agent "a1" on a server connection
// that has already been closed and expects ExecuteCommand to return an error.
func TestExecuteCommandSendError(t *testing.T) {
	server := NewServer(nil)
	srvConn, _, cleanup := newConnPair(t)
	defer cleanup()

	// Close first so the connection is unusable by the time it is registered.
	srvConn.Close()
	agent := &agentConn{
		conn:  srvConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	server.mu.Lock()
	server.agents["a1"] = agent
	server.mu.Unlock()

	request := ExecuteCommandPayload{RequestID: "r1", Timeout: 1}
	if _, err := server.ExecuteCommand(context.Background(), "a1", request); err == nil {
		t.Fatalf("expected send error")
	}
}
// TestExecuteCommandTimeoutAndCancel covers two ways ExecuteCommand can give
// up while waiting for a result. The client end of the connection is never
// read from in this test, so no result is ever delivered.
//
//  1. With Timeout set to 1 and a background context, the returned error must
//     contain "timed out".
//  2. With an already-cancelled context, any non-nil error is accepted.
func TestExecuteCommandTimeoutAndCancel(t *testing.T) {
	server := NewServer(nil)
	srvConn, _, cleanup := newConnPair(t)
	defer cleanup()

	agent := &agentConn{
		conn:  srvConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	server.mu.Lock()
	server.agents["a1"] = agent
	server.mu.Unlock()

	timeoutReq := ExecuteCommandPayload{RequestID: "r-timeout", Timeout: 1}
	_, err := server.ExecuteCommand(context.Background(), "a1", timeoutReq)
	if !(err != nil && strings.Contains(err.Error(), "timed out")) {
		t.Fatalf("expected timeout error, got %v", err)
	}

	cancelled, cancel := context.WithCancel(context.Background())
	cancel()
	cancelReq := ExecuteCommandPayload{RequestID: "r-cancel", Timeout: 1}
	_, err = server.ExecuteCommand(cancelled, "a1", cancelReq)
	if err == nil {
		t.Fatalf("expected cancel error")
	}
}
// TestExecuteCommandDefaultTimeout calls ExecuteCommand with the payload's
// Timeout left at its zero value and expects a successful result. The result
// is injected by a helper goroutine that polls pendingReqs for the
// "r-default" channel and sends a success payload on it.
//
// The helper goroutine is bound to the test's lifetime through the stop
// channel: if ExecuteCommand returns before the pending channel is found, or
// before the result is received, the goroutine exits instead of polling or
// blocking forever.
func TestExecuteCommandDefaultTimeout(t *testing.T) {
	s := NewServer(nil)
	serverConn, _, cleanup := newConnPair(t)
	defer cleanup()

	ac := &agentConn{
		conn:  serverConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	s.mu.Lock()
	s.agents["a1"] = ac
	s.mu.Unlock()

	stop := make(chan struct{})
	defer close(stop)

	go func() {
		ticker := time.NewTicker(2 * time.Millisecond)
		defer ticker.Stop()
		for {
			s.mu.RLock()
			ch := s.pendingReqs["r-default"]
			s.mu.RUnlock()
			if ch != nil {
				// Deliver the result unless the test has already finished.
				select {
				case ch <- CommandResultPayload{RequestID: "r-default", Success: true}:
				case <-stop:
				}
				return
			}
			// Not registered yet: retry on the next tick or give up on stop.
			select {
			case <-ticker.C:
			case <-stop:
				return
			}
		}
	}()

	result, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-default"})
	if err != nil || result == nil || !result.Success {
		t.Fatalf("expected success, got result=%v err=%v", result, err)
	}
}
// TestReadFileRoundTrip registers agent "a1" over a real WebSocket, runs a
// fake agent in a goroutine that answers the first read-file request with a
// successful command result whose Stdout is "data", and checks that ReadFile
// returns that result. The fake agent's own outcome is checked afterwards.
func TestReadFileRoundTrip(t *testing.T) {
	server := NewServer(nil)
	httpSrv := newWSServer(t, server)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	registration := Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Token:    "any",
		},
	}
	wsWriteMessage(t, client, registration)
	_ = wsReadRegisteredPayload(t, client)

	// fakeAgent reads frames until it sees a read-file request with a
	// payload, replies to it once, and returns the first error encountered
	// (or the result of writing the reply).
	fakeAgent := func() error {
		for {
			msg, readErr := wsReadRawMessageWithTimeout(client, 2*time.Second)
			if readErr != nil {
				return readErr
			}
			if msg.Type == MsgTypeReadFile && msg.Payload != nil {
				var request ReadFilePayload
				if decodeErr := json.Unmarshal(*msg.Payload, &request); decodeErr != nil {
					return decodeErr
				}
				reply := Message{
					Type:      MsgTypeCommandResult,
					Timestamp: time.Now(),
					Payload: CommandResultPayload{
						RequestID: request.RequestID,
						Success:   true,
						Stdout:    "data",
						ExitCode:  0,
					},
				}
				return client.WriteJSON(reply)
			}
			// Any other frame is skipped.
		}
	}

	// Capacity 1 so the agent goroutine can finish even if nobody receives.
	agentDone := make(chan error, 1)
	go func() { agentDone <- fakeAgent() }()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	result, err := server.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-1", Path: "/etc/hosts"})
	if err != nil || result == nil || result.Stdout != "data" {
		t.Fatalf("unexpected read file result=%v err=%v", result, err)
	}
	if agentErr := <-agentDone; agentErr != nil {
		t.Fatalf("agent error: %v", agentErr)
	}
}
// TestReadFileTimeoutCancelAndSendError lowers the package-level
// readFileTimeout to 10ms (restored on cleanup) and expects ReadFile to
// return an error in three situations, in order:
//
//  1. nobody answers the request (the client end is never read from);
//  2. the context is already cancelled;
//  3. the server-side connection has been closed before the call.
func TestReadFileTimeoutCancelAndSendError(t *testing.T) {
	savedTimeout := readFileTimeout
	t.Cleanup(func() { readFileTimeout = savedTimeout })
	readFileTimeout = 10 * time.Millisecond

	server := NewServer(nil)
	srvConn, _, cleanup := newConnPair(t)
	defer cleanup()

	agent := &agentConn{
		conn:  srvConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	server.mu.Lock()
	server.agents["a1"] = agent
	server.mu.Unlock()

	// 1. No response ever arrives.
	_, timeoutErr := server.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-timeout"})
	if timeoutErr == nil {
		t.Fatalf("expected timeout error")
	}

	// 2. Context cancelled before the call.
	cancelled, cancel := context.WithCancel(context.Background())
	cancel()
	_, cancelErr := server.ReadFile(cancelled, "a1", ReadFilePayload{RequestID: "read-cancel"})
	if cancelErr == nil {
		t.Fatalf("expected cancel error")
	}

	// 3. Connection closed before the call.
	srvConn.Close()
	_, sendErr := server.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-send"})
	if sendErr == nil {
		t.Fatalf("expected send error")
	}
}