Files
Pulse/internal/agentexec/server_websocket_test.go
rcourtman b7a94bad9f security: fix websocket scope and agent impersonation
1. Enforce monitoring:read scope on WebSocket upgrades
   - Prevents low-privilege tokens (e.g. host-agent:report) from accessing
     full infra state via requestData on the main WebSocket.

2. Enforce agent token binding to prevent impersonation
   - Added Metadata field to APITokenRecord to support bound_agent_id
   - Updated agentexec server to validate token-to-agent binding if present
   - Prevents agent:exec tokens from registering as arbitrary agent IDs
2026-02-03 20:40:08 +00:00

365 lines
8.3 KiB
Go

package agentexec
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
)
// wsRawMessage is the envelope the tests decode incoming WebSocket messages
// into. The payload is kept as raw JSON so that each test can unmarshal it
// into whichever concrete payload type it expects.
type wsRawMessage struct {
Type MessageType `json:"type"`
ID string `json:"id,omitempty"`
Timestamp time.Time `json:"timestamp"`
// Payload is nil when the decoded message carries no payload; callers in
// this file check for nil before dereferencing it.
Payload *json.RawMessage `json:"payload,omitempty"`
}
// newWSServer starts an httptest server whose single handler forwards every
// request to s.HandleWebSocket. No cleanup is registered here, so the caller
// must Close the returned server.
func newWSServer(t *testing.T, s *Server) *httptest.Server {
	t.Helper()

	forward := func(w http.ResponseWriter, r *http.Request) {
		s.HandleWebSocket(w, r)
	}
	return httptest.NewServer(http.HandlerFunc(forward))
}
// wsURLForHTTP converts an HTTP(S) server URL into its WebSocket equivalent
// by swapping a leading "http" for "ws" ("http://" becomes "ws://" and
// "https://" becomes "wss://"). Input without that prefix simply gets "ws"
// prepended.
func wsURLForHTTP(serverURL string) string {
	const httpPrefix = "http"

	remainder := serverURL
	if strings.HasPrefix(remainder, httpPrefix) {
		remainder = remainder[len(httpPrefix):]
	}
	return "ws" + remainder
}
// wsWriteMessage sends msg as JSON on conn with a two-second write deadline
// and fails the test immediately if either the deadline cannot be set or the
// write itself fails.
func wsWriteMessage(t *testing.T, conn *websocket.Conn, msg Message) {
	t.Helper()

	// Report a deadline failure explicitly instead of discarding it, so a
	// write is never attempted without the intended bound.
	if err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("SetWriteDeadline: %v", err)
	}
	if err := conn.WriteJSON(msg); err != nil {
		t.Fatalf("WriteJSON: %v", err)
	}
}
// wsReadRawMessage reads the next message from conn, allowing two seconds
// for it to arrive, and fails the test if the read or decode returns an
// error.
func wsReadRawMessage(t *testing.T, conn *websocket.Conn) wsRawMessage {
	t.Helper()

	const readTimeout = 2 * time.Second

	envelope, readErr := wsReadRawMessageWithTimeout(conn, readTimeout)
	if readErr != nil {
		t.Fatalf("ReadMessage: %v", readErr)
	}
	return envelope
}
// wsReadRegisteredPayload reads the next message from conn, requires it to
// be a MsgTypeRegistered message with a payload, and returns that payload
// decoded as a RegisteredPayload. Any violation fails the test.
func wsReadRegisteredPayload(t *testing.T, conn *websocket.Conn) RegisteredPayload {
	t.Helper()

	envelope := wsReadRawMessage(t, conn)
	switch {
	case envelope.Type != MsgTypeRegistered:
		t.Fatalf("message type = %q, want %q", envelope.Type, MsgTypeRegistered)
	case envelope.Payload == nil:
		t.Fatalf("registered payload missing")
	}

	var registered RegisteredPayload
	decodeErr := json.Unmarshal(*envelope.Payload, &registered)
	if decodeErr != nil {
		t.Fatalf("unmarshal registered payload: %v", decodeErr)
	}
	return registered
}
// wsReadRawMessageWithTimeout reads one message from conn, waiting at most
// timeout, and decodes it into a wsRawMessage.
//
// It never touches *testing.T, so it is safe to call from goroutines other
// than the one running the test. An error is returned when the read deadline
// cannot be set, when the read fails (including on timeout), or when the
// data is not valid JSON for wsRawMessage; the returned message is the zero
// value in every error case.
func wsReadRawMessageWithTimeout(conn *websocket.Conn, timeout time.Duration) (wsRawMessage, error) {
	// Without a deadline ReadMessage could block indefinitely, so a failure
	// to set one is reported rather than ignored.
	if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
		return wsRawMessage{}, err
	}

	_, data, err := conn.ReadMessage()
	if err != nil {
		return wsRawMessage{}, err
	}

	var msg wsRawMessage
	if err := json.Unmarshal(data, &msg); err != nil {
		return wsRawMessage{}, err
	}
	return msg, nil
}
// waitFor polls cond every 10ms until it returns true or timeout elapses, in
// which case the test is failed.
//
// cond is always evaluated at least once, even for a zero or negative
// timeout, and is evaluated one final time after the last sleep. A condition
// that became true during that final sleep is therefore not reported as a
// timeout.
func waitFor(t *testing.T, timeout time.Duration, cond func() bool) {
	t.Helper()

	deadline := time.Now().Add(timeout)
	for {
		if cond() {
			return
		}
		// Check the deadline only after evaluating cond, so every sleep is
		// followed by another evaluation.
		if !time.Now().Before(deadline) {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}
	t.Fatalf("condition not met within %v", timeout)
}
// TestHandleWebSocket_RegistrationSuccessAndDisconnectRemovesAgent registers
// agent "a1" with a token the validator accepts, checks that the server
// reports the agent as connected, then closes the client connection and
// waits for the server to stop reporting the agent.
func TestHandleWebSocket_RegistrationSuccessAndDisconnectRemovesAgent(t *testing.T) {
	s := NewServer(func(token string, agentID string) bool { return token == "ok" })
	ts := newWSServer(t, s)
	defer ts.Close()

	conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	// Safety net: release the connection even if an assertion below aborts
	// the test before the explicit Close. The second Close on the happy path
	// only yields an error, which is discarded.
	defer conn.Close()

	wsWriteMessage(t, conn, Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Version:  "1.2.3",
			Platform: "linux",
			Tags:     []string{"tag1"},
			Token:    "ok",
		},
	})

	reg := wsReadRegisteredPayload(t, conn)
	if !reg.Success {
		t.Fatalf("registration failed: %q", reg.Message)
	}
	if !s.IsAgentConnected("a1") {
		t.Fatalf("expected agent to be connected")
	}

	// Dropping the client side is the event under test: the server should
	// notice and remove the agent.
	conn.Close()
	waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") })
}
// TestHandleWebSocket_InvalidTokenRejected uses a validator that rejects
// every token and expects the server to answer the registration with an
// unsuccessful Registered message, to not list the agent as connected, and
// to deliver no further message on the connection.
func TestHandleWebSocket_InvalidTokenRejected(t *testing.T) {
	s := NewServer(func(string, string) bool { return false })
	ts := newWSServer(t, s)
	defer ts.Close()

	conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer conn.Close()

	wsWriteMessage(t, conn, Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Version:  "1.2.3",
			Platform: "linux",
			Token:    "bad",
		},
	})

	reg := wsReadRegisteredPayload(t, conn)
	if reg.Success {
		t.Fatalf("expected registration to be rejected")
	}
	waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") })

	// The read below must be bounded; if the deadline cannot be set the
	// test would otherwise risk blocking, so fail loudly instead.
	if err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
		t.Fatalf("SetReadDeadline: %v", err)
	}
	// NOTE(review): a read timeout also surfaces as an error here, so this
	// proves "no further message within 200ms" rather than strictly "closed".
	_, _, err = conn.ReadMessage()
	if err == nil {
		t.Fatalf("expected connection to be closed by server")
	}
}
// TestHandleWebSocket_FirstMessageMustBeRegister sends a ping as the very
// first message on a fresh connection and expects the subsequent read to
// fail, i.e. the server must not answer an unregistered client.
func TestHandleWebSocket_FirstMessageMustBeRegister(t *testing.T) {
	s := NewServer(nil)
	ts := newWSServer(t, s)
	defer ts.Close()

	conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer conn.Close()

	wsWriteMessage(t, conn, Message{
		Type:      MsgTypeAgentPing,
		Timestamp: time.Now(),
	})

	// Bound the read; an unset deadline could leave the test blocked if the
	// server neither replies nor closes.
	if err := conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil {
		t.Fatalf("SetReadDeadline: %v", err)
	}
	// NOTE(review): a timeout also counts as an error, so this asserts that
	// no reply arrives within 500ms rather than strictly that the server
	// closed the connection.
	_, _, err = conn.ReadMessage()
	if err == nil {
		t.Fatalf("expected server to close connection")
	}
}
// TestHandleWebSocket_AgentPingRespondsWithPong registers agent "a1" and
// then verifies that the next message the server sends in response to an
// agent ping is a pong.
func TestHandleWebSocket_AgentPingRespondsWithPong(t *testing.T) {
	server := NewServer(nil)
	httpSrv := newWSServer(t, server)
	defer httpSrv.Close()

	endpoint := wsURLForHTTP(httpSrv.URL)
	conn, _, dialErr := websocket.DefaultDialer.Dial(endpoint, nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer conn.Close()

	// Registration comes first; the helper already asserts that a
	// Registered message comes back, so its payload is not needed here.
	registration := Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Version:  "1.2.3",
			Platform: "linux",
			Token:    "any",
		},
	}
	wsWriteMessage(t, conn, registration)
	wsReadRegisteredPayload(t, conn)

	ping := Message{
		Type:      MsgTypeAgentPing,
		Timestamp: time.Now(),
	}
	wsWriteMessage(t, conn, ping)

	if reply := wsReadRawMessage(t, conn); reply.Type != MsgTypePong {
		t.Fatalf("message type = %q, want %q", reply.Type, MsgTypePong)
	}
}
// TestExecuteCommand_RoundTripViaWebSocket registers agent "a1", runs a fake
// agent in a goroutine that answers the first execute-command message with a
// successful result, and checks that Server.ExecuteCommand returns that
// result to the caller.
func TestExecuteCommand_RoundTripViaWebSocket(t *testing.T) {
	s := NewServer(nil)
	ts := newWSServer(t, s)
	defer ts.Close()

	conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer conn.Close()

	wsWriteMessage(t, conn, Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Version:  "1.2.3",
			Platform: "linux",
			Token:    "any",
		},
	})
	_ = wsReadRegisteredPayload(t, conn)

	// The fake agent must not call t.Fatalf (it runs off the test
	// goroutine), so it reports its outcome through agentErr instead. The
	// channel is buffered because the goroutine sends exactly once and the
	// test may bail out before receiving.
	agentDone := make(chan struct{})
	agentErr := make(chan error, 1)
	go func() {
		defer close(agentDone)
		for {
			msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second)
			if err != nil {
				agentErr <- err
				return
			}
			// Skip anything that is not the command we are waiting for.
			if msg.Type != MsgTypeExecuteCmd {
				continue
			}
			if msg.Payload == nil {
				agentErr <- nil
				return
			}

			var payload ExecuteCommandPayload
			if err := json.Unmarshal(*msg.Payload, &payload); err != nil {
				agentErr <- err
				return
			}

			// Surface a deadline failure through the error channel rather
			// than silently writing without a bound.
			if err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
				agentErr <- err
				return
			}
			// Echo the request ID back so the server can correlate the
			// result with the pending ExecuteCommand call.
			if err := conn.WriteJSON(Message{
				Type:      MsgTypeCommandResult,
				Timestamp: time.Now(),
				Payload: CommandResultPayload{
					RequestID: payload.RequestID,
					Success:   true,
					Stdout:    "ok",
					ExitCode:  0,
					Duration:  1,
				},
			}); err != nil {
				agentErr <- err
				return
			}
			agentErr <- nil
			return
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	result, err := s.ExecuteCommand(ctx, "a1", ExecuteCommandPayload{
		RequestID: "req1",
		Command:   "echo ok",
		Timeout:   1,
	})
	if err != nil {
		t.Fatalf("ExecuteCommand: %v", err)
	}
	if result == nil || !result.Success || result.Stdout != "ok" || result.ExitCode != 0 {
		t.Fatalf("unexpected result: %#v", result)
	}

	select {
	case <-agentDone:
	case <-time.After(2 * time.Second):
		t.Fatalf("agent goroutine did not finish")
	}
	if err := <-agentErr; err != nil {
		t.Fatalf("agent error: %v", err)
	}
}
// TestHandleWebSocket_ReconnectSameAgentIDClosesOldConnection registers
// agent "a1" on one connection, registers the same agent ID again on a
// second connection, and expects a read on the first connection to fail
// afterwards.
func TestHandleWebSocket_ReconnectSameAgentIDClosesOldConnection(t *testing.T) {
	s := NewServer(nil)
	ts := newWSServer(t, s)
	defer ts.Close()

	dial := func() *websocket.Conn {
		t.Helper()
		conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		return conn
	}

	c1 := dial()
	defer c1.Close()
	wsWriteMessage(t, c1, Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Version:  "1.2.3",
			Platform: "linux",
			Token:    "any",
		},
	})
	_ = wsReadRegisteredPayload(t, c1)

	// Second connection claims the same agent ID.
	c2 := dial()
	defer c2.Close()
	wsWriteMessage(t, c2, Message{
		Type:      MsgTypeAgentRegister,
		Timestamp: time.Now(),
		Payload: AgentRegisterPayload{
			AgentID:  "a1",
			Hostname: "host1",
			Version:  "1.2.3",
			Platform: "linux",
			Token:    "any",
		},
	})
	_ = wsReadRegisteredPayload(t, c2)

	// Bound the read on the old connection; fail loudly if that is not
	// possible instead of risking a blocked test.
	if err := c1.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil {
		t.Fatalf("SetReadDeadline: %v", err)
	}
	// NOTE(review): a timeout also yields an error, so this asserts that
	// nothing more arrives on c1 within 500ms rather than strictly that the
	// server closed it.
	_, _, err := c1.ReadMessage()
	if err == nil {
		t.Fatalf("expected old connection to be closed")
	}
}