mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-19 07:50:43 +01:00
- Enhance relay client with better connection handling - Improve relay proxy with additional functionality and tests - Update Docker agent collect with improved metrics gathering - Add test coverage for Docker agent collection
484 lines
14 KiB
Go
484 lines
14 KiB
Go
package relay
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
func TestHTTPProxy_HandleRequest(t *testing.T) {
|
|
logger := zerolog.New(zerolog.NewTestWriter(t))
|
|
|
|
// Set up a mock local Pulse API
|
|
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Verify the API token was injected
|
|
if token := r.Header.Get("X-API-Token"); token != "test-token" {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"})
|
|
return
|
|
}
|
|
|
|
switch r.URL.Path {
|
|
case "/api/resources":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
|
|
|
case "/api/echo":
|
|
body, _ := io.ReadAll(r.Body)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write(body)
|
|
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}))
|
|
defer mockAPI.Close()
|
|
|
|
// Extract host:port from mock server URL
|
|
addr := strings.TrimPrefix(mockAPI.URL, "http://")
|
|
proxy := NewHTTPProxy(addr, logger)
|
|
|
|
t.Run("GET request", func(t *testing.T) {
|
|
req := ProxyRequest{
|
|
ID: "req_1",
|
|
Method: "GET",
|
|
Path: "/api/resources",
|
|
}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
respPayload, err := proxy.HandleRequest(payload, "test-token")
|
|
if err != nil {
|
|
t.Fatalf("HandleRequest() error = %v", err)
|
|
}
|
|
|
|
var resp ProxyResponse
|
|
if err := json.Unmarshal(respPayload, &resp); err != nil {
|
|
t.Fatalf("unmarshal response: %v", err)
|
|
}
|
|
|
|
if resp.ID != "req_1" {
|
|
t.Errorf("ID: got %q, want %q", resp.ID, "req_1")
|
|
}
|
|
if resp.Status != 200 {
|
|
t.Errorf("Status: got %d, want 200", resp.Status)
|
|
}
|
|
|
|
// Decode body
|
|
bodyBytes, _ := base64.StdEncoding.DecodeString(resp.Body)
|
|
var body map[string]string
|
|
json.Unmarshal(bodyBytes, &body)
|
|
if body["status"] != "ok" {
|
|
t.Errorf("Body status: got %q, want %q", body["status"], "ok")
|
|
}
|
|
})
|
|
|
|
t.Run("POST request with body", func(t *testing.T) {
|
|
bodyContent := `{"data":"hello"}`
|
|
req := ProxyRequest{
|
|
ID: "req_2",
|
|
Method: "POST",
|
|
Path: "/api/echo",
|
|
Body: base64.StdEncoding.EncodeToString([]byte(bodyContent)),
|
|
Headers: map[string]string{
|
|
"Content-Type": "application/json",
|
|
},
|
|
}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
respPayload, err := proxy.HandleRequest(payload, "test-token")
|
|
if err != nil {
|
|
t.Fatalf("HandleRequest() error = %v", err)
|
|
}
|
|
|
|
var resp ProxyResponse
|
|
json.Unmarshal(respPayload, &resp)
|
|
|
|
if resp.Status != 200 {
|
|
t.Errorf("Status: got %d, want 200", resp.Status)
|
|
}
|
|
|
|
bodyBytes, _ := base64.StdEncoding.DecodeString(resp.Body)
|
|
if string(bodyBytes) != bodyContent {
|
|
t.Errorf("Body: got %q, want %q", string(bodyBytes), bodyContent)
|
|
}
|
|
})
|
|
|
|
t.Run("strips sensitive headers", func(t *testing.T) {
|
|
// The mock API verifies X-API-Token is "test-token",
|
|
// so if Authorization header leaked through, it would be ignored anyway.
|
|
// This test verifies the proxy doesn't forward them.
|
|
req := ProxyRequest{
|
|
ID: "req_3",
|
|
Method: "GET",
|
|
Path: "/api/resources",
|
|
Headers: map[string]string{
|
|
"Authorization": "Bearer should-be-stripped",
|
|
"Cookie": "session=should-be-stripped",
|
|
"Host": "evil.example.com",
|
|
"Accept": "application/json",
|
|
},
|
|
}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
respPayload, err := proxy.HandleRequest(payload, "test-token")
|
|
if err != nil {
|
|
t.Fatalf("HandleRequest() error = %v", err)
|
|
}
|
|
|
|
var resp ProxyResponse
|
|
json.Unmarshal(respPayload, &resp)
|
|
|
|
if resp.Status != 200 {
|
|
t.Errorf("Status: got %d, want 200 (headers should have been stripped)", resp.Status)
|
|
}
|
|
})
|
|
|
|
t.Run("invalid JSON payload", func(t *testing.T) {
|
|
respPayload, err := proxy.HandleRequest([]byte("not json"), "test-token")
|
|
if err != nil {
|
|
t.Fatalf("HandleRequest() error = %v", err)
|
|
}
|
|
|
|
var resp ProxyResponse
|
|
json.Unmarshal(respPayload, &resp)
|
|
|
|
if resp.Status != http.StatusBadRequest {
|
|
t.Errorf("Status: got %d, want %d", resp.Status, http.StatusBadRequest)
|
|
}
|
|
})
|
|
|
|
t.Run("missing required fields", func(t *testing.T) {
|
|
req := ProxyRequest{ID: "req_4"} // missing method and path
|
|
payload, _ := json.Marshal(req)
|
|
|
|
respPayload, err := proxy.HandleRequest(payload, "test-token")
|
|
if err != nil {
|
|
t.Fatalf("HandleRequest() error = %v", err)
|
|
}
|
|
|
|
var resp ProxyResponse
|
|
json.Unmarshal(respPayload, &resp)
|
|
|
|
if resp.Status != http.StatusBadRequest {
|
|
t.Errorf("Status: got %d, want %d", resp.Status, http.StatusBadRequest)
|
|
}
|
|
})
|
|
|
|
t.Run("404 from backend", func(t *testing.T) {
|
|
req := ProxyRequest{
|
|
ID: "req_5",
|
|
Method: "GET",
|
|
Path: "/api/nonexistent",
|
|
}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
respPayload, err := proxy.HandleRequest(payload, "test-token")
|
|
if err != nil {
|
|
t.Fatalf("HandleRequest() error = %v", err)
|
|
}
|
|
|
|
var resp ProxyResponse
|
|
json.Unmarshal(respPayload, &resp)
|
|
|
|
if resp.Status != http.StatusNotFound {
|
|
t.Errorf("Status: got %d, want %d", resp.Status, http.StatusNotFound)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHTTPProxy_HandleStreamRequest(t *testing.T) {
|
|
logger := zerolog.New(zerolog.NewTestWriter(t))
|
|
|
|
t.Run("SSE response sends multiple stream frames", func(t *testing.T) {
|
|
// Mock SSE server that sends 3 events
|
|
sseServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("X-API-Token") != "test-token" {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
flusher, _ := w.(http.Flusher)
|
|
|
|
events := []string{
|
|
`data: {"type":"content","data":{"text":"Hello"}}`,
|
|
`data: {"type":"content","data":{"text":" world"}}`,
|
|
`data: {"type":"done","data":null}`,
|
|
}
|
|
for _, ev := range events {
|
|
fmt.Fprintf(w, "%s\n\n", ev)
|
|
flusher.Flush()
|
|
}
|
|
}))
|
|
defer sseServer.Close()
|
|
|
|
addr := strings.TrimPrefix(sseServer.URL, "http://")
|
|
proxy := NewHTTPProxy(addr, logger)
|
|
|
|
req := ProxyRequest{
|
|
ID: "stream_1",
|
|
Method: "POST",
|
|
Path: "/api/ai/chat",
|
|
Body: base64.StdEncoding.EncodeToString([]byte(`{"prompt":"hi"}`)),
|
|
Headers: map[string]string{
|
|
"Content-Type": "application/json",
|
|
},
|
|
}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
var mu sync.Mutex
|
|
var frames []ProxyResponse
|
|
err := proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
|
|
var resp ProxyResponse
|
|
if err := json.Unmarshal(data, &resp); err != nil {
|
|
t.Errorf("failed to unmarshal frame: %v", err)
|
|
return
|
|
}
|
|
mu.Lock()
|
|
frames = append(frames, resp)
|
|
mu.Unlock()
|
|
})
|
|
|
|
if err != nil {
|
|
t.Fatalf("HandleStreamRequest() error = %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Expect: 1 init + 3 event chunks + 1 stream_done = 5 frames
|
|
if len(frames) < 4 {
|
|
t.Fatalf("expected at least 4 frames, got %d", len(frames))
|
|
}
|
|
|
|
// First frame: stream init (no body, stream=true)
|
|
if !frames[0].Stream {
|
|
t.Error("first frame: expected Stream=true")
|
|
}
|
|
if frames[0].Body != "" {
|
|
t.Error("first frame: expected empty body (init frame)")
|
|
}
|
|
if frames[0].Status != 200 {
|
|
t.Errorf("first frame: expected status 200, got %d", frames[0].Status)
|
|
}
|
|
if frames[0].ID != "stream_1" {
|
|
t.Errorf("first frame: expected ID stream_1, got %s", frames[0].ID)
|
|
}
|
|
|
|
// Middle frames: SSE events with stream=true
|
|
for i := 1; i < len(frames)-1; i++ {
|
|
if !frames[i].Stream {
|
|
t.Errorf("frame %d: expected Stream=true", i)
|
|
}
|
|
if frames[i].Body == "" {
|
|
t.Errorf("frame %d: expected non-empty body", i)
|
|
}
|
|
// Decode and verify it contains SSE data
|
|
bodyBytes, _ := base64.StdEncoding.DecodeString(frames[i].Body)
|
|
bodyStr := string(bodyBytes)
|
|
if !strings.HasPrefix(bodyStr, "data: ") {
|
|
t.Errorf("frame %d: expected SSE data prefix, got %q", i, bodyStr[:20])
|
|
}
|
|
}
|
|
|
|
// Last frame: stream_done
|
|
lastFrame := frames[len(frames)-1]
|
|
if !lastFrame.StreamDone {
|
|
t.Error("last frame: expected StreamDone=true")
|
|
}
|
|
})
|
|
|
|
t.Run("non-SSE response falls back to single response", func(t *testing.T) {
|
|
jsonServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
|
}))
|
|
defer jsonServer.Close()
|
|
|
|
addr := strings.TrimPrefix(jsonServer.URL, "http://")
|
|
proxy := NewHTTPProxy(addr, logger)
|
|
|
|
req := ProxyRequest{
|
|
ID: "fallback_1",
|
|
Method: "GET",
|
|
Path: "/api/resources",
|
|
}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
var frames []ProxyResponse
|
|
err := proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
|
|
var resp ProxyResponse
|
|
json.Unmarshal(data, &resp)
|
|
frames = append(frames, resp)
|
|
})
|
|
|
|
if err != nil {
|
|
t.Fatalf("HandleStreamRequest() error = %v", err)
|
|
}
|
|
|
|
// Non-SSE: exactly one frame, no stream flags
|
|
if len(frames) != 1 {
|
|
t.Fatalf("expected 1 frame, got %d", len(frames))
|
|
}
|
|
if frames[0].Stream {
|
|
t.Error("expected Stream=false for non-SSE response")
|
|
}
|
|
if frames[0].StreamDone {
|
|
t.Error("expected StreamDone=false for non-SSE response")
|
|
}
|
|
if frames[0].Status != 200 {
|
|
t.Errorf("expected status 200, got %d", frames[0].Status)
|
|
}
|
|
})
|
|
|
|
t.Run("context cancellation stops streaming", func(t *testing.T) {
|
|
// SSE server that writes one event then blocks forever
|
|
started := make(chan struct{})
|
|
sseServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.WriteHeader(http.StatusOK)
|
|
flusher, _ := w.(http.Flusher)
|
|
|
|
fmt.Fprintf(w, "data: {\"type\":\"content\"}\n\n")
|
|
flusher.Flush()
|
|
close(started)
|
|
|
|
// Block until client disconnects
|
|
<-r.Context().Done()
|
|
}))
|
|
defer sseServer.Close()
|
|
|
|
addr := strings.TrimPrefix(sseServer.URL, "http://")
|
|
proxy := NewHTTPProxy(addr, logger)
|
|
|
|
req := ProxyRequest{ID: "cancel_1", Method: "POST", Path: "/api/ai/chat"}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- proxy.HandleStreamRequest(ctx, payload, "test-token", func(data []byte) {})
|
|
}()
|
|
|
|
<-started // wait for at least one event
|
|
cancel() // cancel the context
|
|
|
|
err := <-done
|
|
if err != nil && err != context.Canceled {
|
|
t.Fatalf("expected nil or context.Canceled, got: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("SSE heartbeat comments are not forwarded as events", func(t *testing.T) {
|
|
sseServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.WriteHeader(http.StatusOK)
|
|
flusher, _ := w.(http.Flusher)
|
|
|
|
// heartbeat comment
|
|
fmt.Fprintf(w, ": heartbeat\n\n")
|
|
flusher.Flush()
|
|
// real event
|
|
fmt.Fprintf(w, "data: {\"type\":\"done\"}\n\n")
|
|
flusher.Flush()
|
|
}))
|
|
defer sseServer.Close()
|
|
|
|
addr := strings.TrimPrefix(sseServer.URL, "http://")
|
|
proxy := NewHTTPProxy(addr, logger)
|
|
|
|
req := ProxyRequest{ID: "heartbeat_1", Method: "POST", Path: "/api/ai/chat"}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
var frames []ProxyResponse
|
|
proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
|
|
var resp ProxyResponse
|
|
json.Unmarshal(data, &resp)
|
|
frames = append(frames, resp)
|
|
})
|
|
|
|
// init + heartbeat (": heartbeat") + done event + stream_done = 4 frames
|
|
// The heartbeat comment IS a valid SSE line, it gets forwarded as a chunk.
|
|
// The mobile side should handle filtering comments. But we still send it as a chunk.
|
|
// Let's verify we have the init and stream_done frames at minimum.
|
|
if len(frames) < 2 {
|
|
t.Fatalf("expected at least 2 frames, got %d", len(frames))
|
|
}
|
|
|
|
// Last frame should be stream_done
|
|
if !frames[len(frames)-1].StreamDone {
|
|
t.Error("last frame: expected StreamDone=true")
|
|
}
|
|
})
|
|
|
|
t.Run("scanner error sends error instead of stream_done", func(t *testing.T) {
|
|
// SSE server that sends a line longer than the scanner buffer.
|
|
// The default maxProxyBodySize is 47KB, so we send a line that exceeds it.
|
|
sseServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.WriteHeader(http.StatusOK)
|
|
flusher, _ := w.(http.Flusher)
|
|
|
|
// First, a normal event
|
|
fmt.Fprintf(w, "data: {\"type\":\"content\"}\n\n")
|
|
flusher.Flush()
|
|
|
|
// Now a line that exceeds the scanner buffer (maxProxyBodySize)
|
|
huge := strings.Repeat("x", maxProxyBodySize+100)
|
|
fmt.Fprintf(w, "data: %s\n\n", huge)
|
|
flusher.Flush()
|
|
}))
|
|
defer sseServer.Close()
|
|
|
|
addr := strings.TrimPrefix(sseServer.URL, "http://")
|
|
proxy := NewHTTPProxy(addr, logger)
|
|
|
|
req := ProxyRequest{ID: "scanerr_1", Method: "POST", Path: "/api/ai/chat"}
|
|
payload, _ := json.Marshal(req)
|
|
|
|
var mu sync.Mutex
|
|
var frames []ProxyResponse
|
|
proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
|
|
var resp ProxyResponse
|
|
json.Unmarshal(data, &resp)
|
|
mu.Lock()
|
|
frames = append(frames, resp)
|
|
mu.Unlock()
|
|
})
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Should have init + content event + error (NOT stream_done)
|
|
if len(frames) < 2 {
|
|
t.Fatalf("expected at least 2 frames, got %d", len(frames))
|
|
}
|
|
|
|
lastFrame := frames[len(frames)-1]
|
|
// The last frame should be an error, not stream_done
|
|
if lastFrame.StreamDone {
|
|
t.Error("last frame should NOT be StreamDone when scanner errored")
|
|
}
|
|
if lastFrame.Status != http.StatusBadGateway {
|
|
t.Errorf("last frame status: got %d, want %d", lastFrame.Status, http.StatusBadGateway)
|
|
}
|
|
// Verify the error body mentions stream read error
|
|
bodyBytes, _ := base64.StdEncoding.DecodeString(lastFrame.Body)
|
|
if !strings.Contains(string(bodyBytes), "stream read error") {
|
|
t.Errorf("error body: got %q, expected to contain 'stream read error'", string(bodyBytes))
|
|
}
|
|
})
|
|
}
|