Mirror of https://github.com/rcourtman/Pulse.git (synced 2026-02-18 00:17:39 +01:00)
feat(relay-docker): improve relay proxy and Docker agent collection
- Enhance the relay client with better connection handling: a per-connection context now tears down the write pump and in-flight stream goroutines when a connection ends
- Improve the relay proxy with SSE streaming support (HandleStreamRequest) and additional tests
- Update Docker agent collection to aggregate per-container network RX/TX bytes
- Add test coverage for Docker agent collection
@@ -33,6 +33,10 @@ func TestCollectContainer(t *testing.T) {
 				Limit: 4000000,
 				Stats: map[string]uint64{"cache": 200000},
 			},
+			Networks: map[string]containertypes.NetworkStats{
+				"eth0": {RxBytes: 2048, TxBytes: 1024},
+				"eth1": {RxBytes: 512, TxBytes: 256},
+			},
 			BlkioStats: containertypes.BlkioStats{
 				IoServiceBytesRecursive: []containertypes.BlkioStatEntry{
 					{Op: "Read", Value: 100},
@@ -117,6 +121,9 @@ func TestCollectContainer(t *testing.T) {
 		if len(container.Networks) == 0 {
 			t.Fatalf("expected networks to be populated")
 		}
+		if container.NetworkRXBytes != 2560 || container.NetworkTXBytes != 1280 {
+			t.Fatalf("unexpected network totals: rx=%d tx=%d", container.NetworkRXBytes, container.NetworkTXBytes)
+		}
 	})
 
 	t.Run("stopped container clears sample", func(t *testing.T) {
@@ -547,6 +547,31 @@ func TestSummarizeBlockIO(t *testing.T) {
 	})
 }
 
+func TestSummarizeNetworkIO(t *testing.T) {
+	t.Run("aggregates all interfaces", func(t *testing.T) {
+		stats := containertypes.StatsResponse{
+			Networks: map[string]containertypes.NetworkStats{
+				"eth0": {RxBytes: 1200, TxBytes: 3400},
+				"eth1": {RxBytes: 800, TxBytes: 600},
+			},
+		}
+		rx, tx := summarizeNetworkIO(stats)
+		if rx != 2000 {
+			t.Fatalf("rx bytes = %d, want 2000", rx)
+		}
+		if tx != 4000 {
+			t.Fatalf("tx bytes = %d, want 4000", tx)
+		}
+	})
+
+	t.Run("empty stats returns zero", func(t *testing.T) {
+		rx, tx := summarizeNetworkIO(containertypes.StatsResponse{})
+		if rx != 0 || tx != 0 {
+			t.Fatalf("expected zero rx/tx, got %d/%d", rx, tx)
+		}
+	})
+}
+
 func TestExtractPodmanMetadata(t *testing.T) {
 	tests := []struct {
 		name string
@@ -254,6 +254,8 @@ func (a *Agent) collectContainer(ctx context.Context, summary containertypes.Sum
 		memLimit   int64
 		memPercent float64
 		blockIO    *agentsdocker.ContainerBlockIO
+		networkRX  uint64
+		networkTX  uint64
 	)
 
 	if inspect.State.Running || inspect.State.Paused {
@@ -271,6 +273,7 @@ func (a *Agent) collectContainer(ctx context.Context, summary containertypes.Sum
 		cpuPercent = a.calculateContainerCPUPercent(summary.ID, stats)
 		memUsage, memLimit, memPercent = calculateMemoryUsage(stats)
 		blockIO = summarizeBlockIO(stats)
+		networkRX, networkTX = summarizeNetworkIO(stats)
 	} else {
 		a.cpuMu.Lock()
 		delete(a.prevContainerCPU, summary.ID)
@@ -380,6 +383,8 @@ func (a *Agent) collectContainer(ctx context.Context, summary containertypes.Sum
 		Labels:              labels,
 		Env:                 maskSensitiveEnvVars(inspect.Config.Env),
 		Networks:            networks,
+		NetworkRXBytes:      networkRX,
+		NetworkTXBytes:      networkTX,
 		WritableLayerBytes:  writableLayerBytes,
 		RootFilesystemBytes: rootFsBytes,
 		BlockIO:             blockIO,
@@ -842,6 +847,21 @@ func summarizeBlockIO(stats containertypes.StatsResponse) *agentsdocker.Containe
 		}
 	}
 
+func summarizeNetworkIO(stats containertypes.StatsResponse) (uint64, uint64) {
+	if len(stats.Networks) == 0 {
+		return 0, 0
+	}
+
+	var rxBytes uint64
+	var txBytes uint64
+	for _, network := range stats.Networks {
+		rxBytes += network.RxBytes
+		txBytes += network.TxBytes
+	}
+
+	return rxBytes, txBytes
+}
+
 // sensitiveEnvPatterns are substrings that, when found in an env var name (case-insensitive),
 // indicate the value should be masked for security.
 var sensitiveEnvPatterns = []string{
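The sensitiveEnvPatterns comment above describes case-insensitive substring matching against env var names; the actual pattern list and the masking logic of maskSensitiveEnvVars sit outside this hunk. Below is a minimal, self-contained sketch of how that kind of masking could work. The pattern list, the maskEnv name, and the "***" placeholder are assumptions for illustration, not the repo's implementation.

// Illustrative sketch only: pattern-based env masking in the spirit of the
// comment above. Patterns and placeholder are assumed, not taken from the repo.
package main

import (
	"fmt"
	"strings"
)

var hypotheticalPatterns = []string{"PASSWORD", "SECRET", "TOKEN", "KEY"}

func maskEnv(env []string) []string {
	masked := make([]string, 0, len(env))
	for _, kv := range env {
		name, value, ok := strings.Cut(kv, "=")
		if !ok {
			masked = append(masked, kv)
			continue
		}
		upper := strings.ToUpper(name)
		for _, p := range hypotheticalPatterns {
			if strings.Contains(upper, p) {
				value = "***" // hide the value, keep the name visible
				break
			}
		}
		masked = append(masked, name+"="+value)
	}
	return masked
}

func main() {
	fmt.Println(maskEnv([]string{"PATH=/usr/bin", "API_TOKEN=abc123"}))
	// Output: [PATH=/usr/bin API_TOKEN=***]
}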
@@ -249,13 +249,17 @@ func (c *Client) connectAndHandle(ctx context.Context) error {
 
 	c.logger.Info().Str("instance_id", c.instanceID).Msg("Registered with relay server")
 
-	// Start write pump with per-connection sendCh
-	writeCtx, writeCancel := context.WithCancel(ctx)
-	defer writeCancel()
-	go c.writePump(writeCtx, conn, sendCh)
+	// Per-connection context: cancelled when this connection ends (for any
+	// reason), which tears down the write pump and any in-flight stream
+	// goroutines spawned by handleData. Without this, stream goroutines
+	// would keep running against a stale sendCh until the whole client stops.
+	connCtx, connCancel := context.WithCancel(ctx)
+	defer connCancel()
 
-	// Read pump (blocking) — passes sendCh for responses
-	return c.readPump(ctx, conn, sendCh)
+	go c.writePump(connCtx, conn, sendCh)
+
+	// Read pump (blocking) — passes connCtx so handleData streams inherit it
+	return c.readPump(connCtx, conn, sendCh)
 }
 
 func (c *Client) register(conn *websocket.Conn) error {
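The new comment explains why the write pump and stream goroutines must hang off a per-connection context rather than the client-wide one. A generic sketch of that derived-context pattern, with illustrative names only (this is not the client's actual code):

// Illustrative sketch: cancelling a per-connection context stops everything
// spawned for that connection, without touching the parent client context.
package main

import (
	"context"
	"fmt"
	"time"
)

func worker(ctx context.Context, name string) {
	<-ctx.Done()
	fmt.Println(name, "stopped:", ctx.Err())
}

func main() {
	clientCtx := context.Background()                     // lives for the whole client
	connCtx, connCancel := context.WithCancel(clientCtx)  // lives for one connection

	go worker(connCtx, "writePump")
	go worker(connCtx, "stream goroutine")

	time.Sleep(10 * time.Millisecond)
	connCancel() // connection ended: both goroutines observe cancellation
	time.Sleep(10 * time.Millisecond)
}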
@@ -355,7 +359,7 @@ func (c *Client) readPump(ctx context.Context, conn *websocket.Conn, sendCh chan
 		c.handleKeyExchange(frame, sendCh)
 
 	case FrameData:
-		c.handleData(frame, sendCh)
+		c.handleData(ctx, frame, sendCh)
 
 	case FrameChannelClose:
 		c.handleChannelClose(frame)
@@ -449,7 +453,7 @@ func (c *Client) handleChannelOpen(frame Frame, sendCh chan<- []byte) {
 	}
 }
 
-func (c *Client) handleData(frame Frame, sendCh chan<- []byte) {
+func (c *Client) handleData(connCtx context.Context, frame Frame, sendCh chan<- []byte) {
 	channelID := frame.Channel
 
 	// Snapshot channel state under lock so the goroutine below doesn't race
@@ -483,24 +487,26 @@ func (c *Client) handleData(frame Frame, sendCh chan<- []byte) {
 			payload = decrypted
 		}
 
-		respPayload, err := c.proxy.HandleRequest(payload, apiToken)
-		if err != nil {
-			c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("Proxy error")
-			return
-		}
-
-		// Encrypt response if encryption is active
-		if enc != nil {
-			encrypted, err := enc.Encrypt(respPayload)
-			if err != nil {
-				c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("Failed to encrypt DATA response")
-				return
-			}
-			respPayload = encrypted
-		}
-
-		respFrame := NewFrame(FrameData, channelID, respPayload)
-		queueFrame(sendCh, respFrame, c.logger)
+		// Derive from the connection context so streams are cancelled on disconnect.
+		// The 15-minute timeout is a safety net for runaway streams.
+		ctx, cancel := context.WithTimeout(connCtx, 15*time.Minute)
+		defer cancel()
+
+		err := c.proxy.HandleStreamRequest(ctx, payload, apiToken, func(respPayload []byte) {
+			if enc != nil {
+				encrypted, err := enc.Encrypt(respPayload)
+				if err != nil {
+					c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("Failed to encrypt DATA response")
+					return
+				}
+				respPayload = encrypted
+			}
+
+			respFrame := NewFrame(FrameData, channelID, respPayload)
+			queueFrame(sendCh, respFrame, c.logger)
+		})
+		if err != nil && connCtx.Err() == nil {
+			c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("Stream proxy error")
+		}
 	}()
 }
@@ -1,7 +1,9 @@
 package relay
 
 import (
+	"bufio"
 	"bytes"
+	"context"
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
@@ -34,17 +36,20 @@ type ProxyRequest struct {
 
 // ProxyResponse is the JSON payload inside a DATA frame from the instance to the app.
 type ProxyResponse struct {
-	ID      string            `json:"id"`
-	Status  int               `json:"status"`
-	Headers map[string]string `json:"headers,omitempty"`
-	Body    string            `json:"body,omitempty"` // base64-encoded
+	ID         string            `json:"id"`
+	Status     int               `json:"status"`
+	Headers    map[string]string `json:"headers,omitempty"`
+	Body       string            `json:"body,omitempty"`        // base64-encoded
+	Stream     bool              `json:"stream,omitempty"`      // true for all streaming chunks
+	StreamDone bool              `json:"stream_done,omitempty"` // true for the final chunk
 }
 
 // HTTPProxy proxies DATA frame payloads to the local Pulse API.
 type HTTPProxy struct {
-	localAddr string
-	client    *http.Client
-	logger    zerolog.Logger
+	localAddr    string
+	client       *http.Client // for normal request/response proxying
+	streamClient *http.Client // for SSE streaming (no timeout)
+	logger       zerolog.Logger
 }
 
 // NewHTTPProxy creates a proxy that forwards requests to the given local address.
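Taken together, the new Stream and StreamDone fields mean a streamed reply arrives as an init frame (status and headers, empty body), then chunk frames whose Body is a base64-encoded SSE event, then a final stream_done frame. Below is a hedged sketch of how a consumer might reassemble the chunks; the type is redeclared locally and the helper name is illustrative, this is not the app-side code from the repo.

// Illustrative consumer of streamed ProxyResponse frames (assumed shape; the
// real app-side handling may differ).
package main

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"strings"
)

type proxyResponse struct {
	ID         string `json:"id"`
	Status     int    `json:"status"`
	Body       string `json:"body,omitempty"`
	Stream     bool   `json:"stream,omitempty"`
	StreamDone bool   `json:"stream_done,omitempty"`
}

// collectStream appends decoded chunk bodies until a stream_done frame arrives.
func collectStream(frames [][]byte) (string, error) {
	var sb strings.Builder
	for _, raw := range frames {
		var f proxyResponse
		if err := json.Unmarshal(raw, &f); err != nil {
			return "", err
		}
		if f.StreamDone {
			break // final frame carries no body
		}
		if f.Body == "" {
			continue // init frame: status and headers only
		}
		decoded, err := base64.StdEncoding.DecodeString(f.Body)
		if err != nil {
			return "", err
		}
		sb.Write(decoded)
		sb.WriteByte('\n')
	}
	return sb.String(), nil
}

func main() {
	chunk := base64.StdEncoding.EncodeToString([]byte(`data: {"type":"done"}`))
	frames := [][]byte{
		[]byte(`{"id":"stream_1","status":200,"stream":true}`),
		[]byte(fmt.Sprintf(`{"id":"stream_1","status":200,"stream":true,"body":%q}`, chunk)),
		[]byte(`{"id":"stream_1","status":200,"stream_done":true}`),
	}
	text, _ := collectStream(frames)
	fmt.Print(text)
}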
@@ -57,6 +62,13 @@ func NewHTTPProxy(localAddr string, logger zerolog.Logger) *HTTPProxy {
 				return http.ErrUseLastResponse
 			},
 		},
+		streamClient: &http.Client{
+			// No Timeout — streaming responses are long-lived.
+			// Cancellation is handled via context.
+			CheckRedirect: func(req *http.Request, via []*http.Request) error {
+				return http.ErrUseLastResponse
+			},
+		},
 		logger: logger,
 	}
 }
@@ -157,6 +169,204 @@ func (p *HTTPProxy) HandleRequest(payload []byte, apiToken string) ([]byte, erro
 	return data, nil
 }
 
+// HandleStreamRequest processes a DATA frame payload as an HTTP request and streams
+// the response as multiple ProxyResponse frames via sendFrame. For non-SSE responses,
+// it falls back to single-response behavior identical to HandleRequest.
+func (p *HTTPProxy) HandleStreamRequest(ctx context.Context, payload []byte, apiToken string, sendFrame func([]byte)) error {
+	var req ProxyRequest
+	if err := json.Unmarshal(payload, &req); err != nil {
+		sendFrame(p.errorResponse("", http.StatusBadRequest, "invalid request payload"))
+		return nil
+	}
+
+	if req.ID == "" || req.Method == "" || req.Path == "" {
+		sendFrame(p.errorResponse(req.ID, http.StatusBadRequest, "missing required fields (id, method, path)"))
+		return nil
+	}
+
+	if !strings.HasPrefix(req.Path, "/") {
+		req.Path = "/" + req.Path
+	}
+
+	var bodyReader io.Reader
+	if req.Body != "" {
+		bodyBytes, err := base64.StdEncoding.DecodeString(req.Body)
+		if err != nil {
+			sendFrame(p.errorResponse(req.ID, http.StatusBadRequest, "invalid base64 body"))
+			return nil
+		}
+		if len(bodyBytes) > maxProxyBodySize {
+			sendFrame(p.errorResponse(req.ID, http.StatusRequestEntityTooLarge, "request body exceeds 47KB limit"))
+			return nil
+		}
+		bodyReader = bytes.NewReader(bodyBytes)
+	}
+
+	url := fmt.Sprintf("http://%s%s", p.localAddr, req.Path)
+	httpReq, err := http.NewRequestWithContext(ctx, req.Method, url, bodyReader)
+	if err != nil {
+		sendFrame(p.errorResponse(req.ID, http.StatusInternalServerError, "failed to create request"))
+		return nil
+	}
+
+	for k, v := range req.Headers {
+		if allowedProxyHeader(k) {
+			httpReq.Header.Set(k, v)
+		}
+	}
+	httpReq.Header.Set("X-API-Token", apiToken)
+
+	p.logger.Debug().
+		Str("request_id", req.ID).
+		Str("method", req.Method).
+		Str("path", req.Path).
+		Msg("Proxying relay request (stream-capable)")
+
+	resp, err := p.streamClient.Do(httpReq)
+	if err != nil {
+		p.logger.Warn().Err(err).Str("request_id", req.ID).Msg("Local API request failed")
+		sendFrame(p.errorResponse(req.ID, http.StatusBadGateway, "local API request failed"))
+		return nil
+	}
+	defer resp.Body.Close()
+
+	// Check if this is an SSE response
+	ct := resp.Header.Get("Content-Type")
+	if !strings.HasPrefix(ct, "text/event-stream") {
+		// Non-streaming: read full body and send a single response (same as HandleRequest)
+		limitedReader := io.LimitReader(resp.Body, maxProxyBodySize+1)
+		respBody, err := io.ReadAll(limitedReader)
+		if err != nil {
+			sendFrame(p.errorResponse(req.ID, http.StatusBadGateway, "failed to read response body"))
+			return nil
+		}
+		if len(respBody) > maxProxyBodySize {
+			sendFrame(p.errorResponse(req.ID, http.StatusRequestEntityTooLarge, "response body exceeds 47KB limit"))
+			return nil
+		}
+
+		respHeaders := make(map[string]string)
+		for _, key := range []string{"Content-Type", "X-Request-Id", "Cache-Control"} {
+			if v := resp.Header.Get(key); v != "" {
+				respHeaders[key] = v
+			}
+		}
+
+		proxyResp := ProxyResponse{
+			ID:      req.ID,
+			Status:  resp.StatusCode,
+			Headers: respHeaders,
+		}
+		if len(respBody) > 0 {
+			proxyResp.Body = base64.StdEncoding.EncodeToString(respBody)
+		}
+		data, err := json.Marshal(proxyResp)
+		if err != nil {
+			sendFrame(p.errorResponse(req.ID, http.StatusInternalServerError, "failed to marshal response"))
+			return nil
+		}
+		sendFrame(data)
+		return nil
+	}
+
+	// SSE streaming mode: send an initial header frame
+	respHeaders := make(map[string]string)
+	respHeaders["Content-Type"] = "text/event-stream"
+	if v := resp.Header.Get("X-Request-Id"); v != "" {
+		respHeaders["X-Request-Id"] = v
+	}
+
+	initResp := ProxyResponse{
+		ID:      req.ID,
+		Status:  resp.StatusCode,
+		Headers: respHeaders,
+		Stream:  true,
+	}
+	initData, err := json.Marshal(initResp)
+	if err != nil {
+		sendFrame(p.errorResponse(req.ID, http.StatusInternalServerError, "failed to marshal stream init"))
+		return nil
+	}
+	sendFrame(initData)
+
+	// Read SSE events line-by-line and forward as individual frames
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Buffer(make([]byte, maxProxyBodySize), maxProxyBodySize)
+
+	var eventBuf strings.Builder
+
+	for scanner.Scan() {
+		// Check if context was cancelled (relay disconnected)
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+
+		line := scanner.Text()
+
+		if line == "" {
+			// Empty line = end of SSE event
+			if eventBuf.Len() > 0 {
+				eventText := eventBuf.String()
+				eventBuf.Reset()
+
+				chunk := ProxyResponse{
+					ID:     req.ID,
+					Status: resp.StatusCode,
+					Body:   base64.StdEncoding.EncodeToString([]byte(eventText)),
+					Stream: true,
+				}
+				chunkData, err := json.Marshal(chunk)
+				if err != nil {
+					p.logger.Warn().Err(err).Msg("Failed to marshal SSE chunk")
+					continue
+				}
+				sendFrame(chunkData)
+			}
+		} else {
+			if eventBuf.Len() > 0 {
+				eventBuf.WriteByte('\n')
+			}
+			eventBuf.WriteString(line)
+		}
+	}
+
+	// Check for scanner error before sending completion.
+	// If scanning failed (e.g. token too long, transport read error), send an
+	// error response instead of stream_done so the client knows it's incomplete.
+	if err := scanner.Err(); err != nil {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+		p.logger.Warn().Err(err).Str("request_id", req.ID).Msg("SSE scanner error")
+		sendFrame(p.errorResponse(req.ID, http.StatusBadGateway, "stream read error"))
+		return nil
+	}
+
+	// Flush any remaining buffered event
+	if eventBuf.Len() > 0 {
+		eventText := eventBuf.String()
+		chunk := ProxyResponse{
+			ID:     req.ID,
+			Status: resp.StatusCode,
+			Body:   base64.StdEncoding.EncodeToString([]byte(eventText)),
+			Stream: true,
+		}
+		chunkData, _ := json.Marshal(chunk)
+		sendFrame(chunkData)
+	}
+
+	// Send stream-done frame (only on clean completion)
+	doneResp := ProxyResponse{
+		ID:         req.ID,
+		Status:     resp.StatusCode,
+		StreamDone: true,
+	}
+	doneData, _ := json.Marshal(doneResp)
+	sendFrame(doneData)
+
+	return nil
+}
+
 // allowedProxyHeaders is the set of headers that may be forwarded from relay
 // requests to the local Pulse API. All other headers are stripped to prevent
 // auth-context leakage (X-Proxy-Secret, X-Forwarded-*, etc.).
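The trailing comment describes an allowlist enforced by allowedProxyHeader; the list itself is outside this hunk. A minimal sketch of what such a check could look like follows. The header names below are assumptions (the tests only confirm that Content-Type is forwarded), not the repo's actual allowlist.

// Illustrative sketch only: an allowlist check in the spirit of allowedProxyHeader.
// The real header set lives in the elided code.
package main

import (
	"fmt"
	"strings"
)

var hypotheticalAllowlist = map[string]bool{
	"content-type": true,
	"accept":       true,
}

func allowed(name string) bool {
	return hypotheticalAllowlist[strings.ToLower(name)]
}

func main() {
	fmt.Println(allowed("Content-Type"))    // true
	fmt.Println(allowed("X-Proxy-Secret"))  // false: stripped to avoid auth-context leakage
	fmt.Println(allowed("X-Forwarded-For")) // false
}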
@@ -1,12 +1,15 @@
package relay

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"

	"github.com/rs/zerolog"
@@ -191,3 +194,290 @@ func TestHTTPProxy_HandleRequest(t *testing.T) {
 		}
 	})
 }
+
+func TestHTTPProxy_HandleStreamRequest(t *testing.T) {
+	logger := zerolog.New(zerolog.NewTestWriter(t))
+
+	t.Run("SSE response sends multiple stream frames", func(t *testing.T) {
+		// Mock SSE server that sends 3 events
+		sseServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			if r.Header.Get("X-API-Token") != "test-token" {
+				w.WriteHeader(http.StatusUnauthorized)
+				return
+			}
+			w.Header().Set("Content-Type", "text/event-stream")
+			w.Header().Set("Cache-Control", "no-cache")
+			w.WriteHeader(http.StatusOK)
+
+			flusher, _ := w.(http.Flusher)
+
+			events := []string{
+				`data: {"type":"content","data":{"text":"Hello"}}`,
+				`data: {"type":"content","data":{"text":" world"}}`,
+				`data: {"type":"done","data":null}`,
+			}
+			for _, ev := range events {
+				fmt.Fprintf(w, "%s\n\n", ev)
+				flusher.Flush()
+			}
+		}))
+		defer sseServer.Close()
+
+		addr := strings.TrimPrefix(sseServer.URL, "http://")
+		proxy := NewHTTPProxy(addr, logger)
+
+		req := ProxyRequest{
+			ID:     "stream_1",
+			Method: "POST",
+			Path:   "/api/ai/chat",
+			Body:   base64.StdEncoding.EncodeToString([]byte(`{"prompt":"hi"}`)),
+			Headers: map[string]string{
+				"Content-Type": "application/json",
+			},
+		}
+		payload, _ := json.Marshal(req)
+
+		var mu sync.Mutex
+		var frames []ProxyResponse
+		err := proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
+			var resp ProxyResponse
+			if err := json.Unmarshal(data, &resp); err != nil {
+				t.Errorf("failed to unmarshal frame: %v", err)
+				return
+			}
+			mu.Lock()
+			frames = append(frames, resp)
+			mu.Unlock()
+		})
+
+		if err != nil {
+			t.Fatalf("HandleStreamRequest() error = %v", err)
+		}
+
+		mu.Lock()
+		defer mu.Unlock()
+
+		// Expect: 1 init + 3 event chunks + 1 stream_done = 5 frames
+		if len(frames) < 4 {
+			t.Fatalf("expected at least 4 frames, got %d", len(frames))
+		}
+
+		// First frame: stream init (no body, stream=true)
+		if !frames[0].Stream {
+			t.Error("first frame: expected Stream=true")
+		}
+		if frames[0].Body != "" {
+			t.Error("first frame: expected empty body (init frame)")
+		}
+		if frames[0].Status != 200 {
+			t.Errorf("first frame: expected status 200, got %d", frames[0].Status)
+		}
+		if frames[0].ID != "stream_1" {
+			t.Errorf("first frame: expected ID stream_1, got %s", frames[0].ID)
+		}
+
+		// Middle frames: SSE events with stream=true
+		for i := 1; i < len(frames)-1; i++ {
+			if !frames[i].Stream {
+				t.Errorf("frame %d: expected Stream=true", i)
+			}
+			if frames[i].Body == "" {
+				t.Errorf("frame %d: expected non-empty body", i)
+			}
+			// Decode and verify it contains SSE data
+			bodyBytes, _ := base64.StdEncoding.DecodeString(frames[i].Body)
+			bodyStr := string(bodyBytes)
+			if !strings.HasPrefix(bodyStr, "data: ") {
+				t.Errorf("frame %d: expected SSE data prefix, got %q", i, bodyStr[:20])
+			}
+		}
+
+		// Last frame: stream_done
+		lastFrame := frames[len(frames)-1]
+		if !lastFrame.StreamDone {
+			t.Error("last frame: expected StreamDone=true")
+		}
+	})
+
+	t.Run("non-SSE response falls back to single response", func(t *testing.T) {
+		jsonServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "application/json")
+			json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
+		}))
+		defer jsonServer.Close()
+
+		addr := strings.TrimPrefix(jsonServer.URL, "http://")
+		proxy := NewHTTPProxy(addr, logger)
+
+		req := ProxyRequest{
+			ID:     "fallback_1",
+			Method: "GET",
+			Path:   "/api/resources",
+		}
+		payload, _ := json.Marshal(req)
+
+		var frames []ProxyResponse
+		err := proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
+			var resp ProxyResponse
+			json.Unmarshal(data, &resp)
+			frames = append(frames, resp)
+		})
+
+		if err != nil {
+			t.Fatalf("HandleStreamRequest() error = %v", err)
+		}
+
+		// Non-SSE: exactly one frame, no stream flags
+		if len(frames) != 1 {
+			t.Fatalf("expected 1 frame, got %d", len(frames))
+		}
+		if frames[0].Stream {
+			t.Error("expected Stream=false for non-SSE response")
+		}
+		if frames[0].StreamDone {
+			t.Error("expected StreamDone=false for non-SSE response")
+		}
+		if frames[0].Status != 200 {
+			t.Errorf("expected status 200, got %d", frames[0].Status)
+		}
+	})
+
+	t.Run("context cancellation stops streaming", func(t *testing.T) {
+		// SSE server that writes one event then blocks forever
+		started := make(chan struct{})
+		sseServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "text/event-stream")
+			w.WriteHeader(http.StatusOK)
+			flusher, _ := w.(http.Flusher)
+
+			fmt.Fprintf(w, "data: {\"type\":\"content\"}\n\n")
+			flusher.Flush()
+			close(started)
+
+			// Block until client disconnects
+			<-r.Context().Done()
+		}))
+		defer sseServer.Close()
+
+		addr := strings.TrimPrefix(sseServer.URL, "http://")
+		proxy := NewHTTPProxy(addr, logger)
+
+		req := ProxyRequest{ID: "cancel_1", Method: "POST", Path: "/api/ai/chat"}
+		payload, _ := json.Marshal(req)
+
+		ctx, cancel := context.WithCancel(context.Background())
+
+		done := make(chan error, 1)
+		go func() {
+			done <- proxy.HandleStreamRequest(ctx, payload, "test-token", func(data []byte) {})
+		}()
+
+		<-started // wait for at least one event
+		cancel()  // cancel the context
+
+		err := <-done
+		if err != nil && err != context.Canceled {
+			t.Fatalf("expected nil or context.Canceled, got: %v", err)
+		}
+	})
+
+	t.Run("SSE heartbeat comments are not forwarded as events", func(t *testing.T) {
+		sseServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "text/event-stream")
+			w.WriteHeader(http.StatusOK)
+			flusher, _ := w.(http.Flusher)
+
+			// heartbeat comment
+			fmt.Fprintf(w, ": heartbeat\n\n")
+			flusher.Flush()
+			// real event
+			fmt.Fprintf(w, "data: {\"type\":\"done\"}\n\n")
+			flusher.Flush()
+		}))
+		defer sseServer.Close()
+
+		addr := strings.TrimPrefix(sseServer.URL, "http://")
+		proxy := NewHTTPProxy(addr, logger)
+
+		req := ProxyRequest{ID: "heartbeat_1", Method: "POST", Path: "/api/ai/chat"}
+		payload, _ := json.Marshal(req)
+
+		var frames []ProxyResponse
+		proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
+			var resp ProxyResponse
+			json.Unmarshal(data, &resp)
+			frames = append(frames, resp)
+		})
+
+		// init + heartbeat (": heartbeat") + done event + stream_done = 4 frames
+		// The heartbeat comment IS a valid SSE line, it gets forwarded as a chunk.
+		// The mobile side should handle filtering comments. But we still send it as a chunk.
+		// Let's verify we have the init and stream_done frames at minimum.
+		if len(frames) < 2 {
+			t.Fatalf("expected at least 2 frames, got %d", len(frames))
+		}
+
+		// Last frame should be stream_done
+		if !frames[len(frames)-1].StreamDone {
+			t.Error("last frame: expected StreamDone=true")
+		}
+	})
+
+	t.Run("scanner error sends error instead of stream_done", func(t *testing.T) {
+		// SSE server that sends a line longer than the scanner buffer.
+		// The default maxProxyBodySize is 47KB, so we send a line that exceeds it.
+		sseServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "text/event-stream")
+			w.WriteHeader(http.StatusOK)
+			flusher, _ := w.(http.Flusher)
+
+			// First, a normal event
+			fmt.Fprintf(w, "data: {\"type\":\"content\"}\n\n")
+			flusher.Flush()
+
+			// Now a line that exceeds the scanner buffer (maxProxyBodySize)
+			huge := strings.Repeat("x", maxProxyBodySize+100)
+			fmt.Fprintf(w, "data: %s\n\n", huge)
+			flusher.Flush()
+		}))
+		defer sseServer.Close()
+
+		addr := strings.TrimPrefix(sseServer.URL, "http://")
+		proxy := NewHTTPProxy(addr, logger)
+
+		req := ProxyRequest{ID: "scanerr_1", Method: "POST", Path: "/api/ai/chat"}
+		payload, _ := json.Marshal(req)
+
+		var mu sync.Mutex
+		var frames []ProxyResponse
+		proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
+			var resp ProxyResponse
+			json.Unmarshal(data, &resp)
+			mu.Lock()
+			frames = append(frames, resp)
+			mu.Unlock()
+		})
+
+		mu.Lock()
+		defer mu.Unlock()
+
+		// Should have init + content event + error (NOT stream_done)
+		if len(frames) < 2 {
+			t.Fatalf("expected at least 2 frames, got %d", len(frames))
+		}
+
+		lastFrame := frames[len(frames)-1]
+		// The last frame should be an error, not stream_done
+		if lastFrame.StreamDone {
+			t.Error("last frame should NOT be StreamDone when scanner errored")
+		}
+		if lastFrame.Status != http.StatusBadGateway {
+			t.Errorf("last frame status: got %d, want %d", lastFrame.Status, http.StatusBadGateway)
+		}
+		// Verify the error body mentions stream read error
+		bodyBytes, _ := base64.StdEncoding.DecodeString(lastFrame.Body)
+		if !strings.Contains(string(bodyBytes), "stream read error") {
+			t.Errorf("error body: got %q, expected to contain 'stream read error'", string(bodyBytes))
+		}
+	})
+}
@@ -73,6 +73,8 @@ type Container struct {
 	Labels              map[string]string  `json:"labels,omitempty"`
 	Env                 []string           `json:"env,omitempty"`
 	Networks            []ContainerNetwork `json:"networks,omitempty"`
+	NetworkRXBytes      uint64             `json:"networkRxBytes,omitempty"`
+	NetworkTXBytes      uint64             `json:"networkTxBytes,omitempty"`
 	WritableLayerBytes  int64              `json:"writableLayerBytes,omitempty"`
 	RootFilesystemBytes int64              `json:"rootFilesystemBytes,omitempty"`
 	BlockIO             *ContainerBlockIO  `json:"blockIo,omitempty"`
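For reference, here is a sketch of how the new NetworkRXBytes/NetworkTXBytes fields serialize next to a couple of the existing fields from this hunk. The struct is trimmed to a few fields and the values are made up to match the collect test fixture above (2048+512 RX bytes, 1024+256 TX bytes); this is not the full Container type.

// Illustrative sketch only: JSON shape of the new network totals. Values and
// the trimmed struct are assumptions for this example.
package main

import (
	"encoding/json"
	"fmt"
)

type Container struct {
	Labels         map[string]string `json:"labels,omitempty"`
	Env            []string          `json:"env,omitempty"`
	NetworkRXBytes uint64            `json:"networkRxBytes,omitempty"`
	NetworkTXBytes uint64            `json:"networkTxBytes,omitempty"`
}

func main() {
	c := Container{
		Labels:         map[string]string{"app": "web"},
		NetworkRXBytes: 2560, // 2048 (eth0) + 512 (eth1), as in the collect test above
		NetworkTXBytes: 1280, // 1024 (eth0) + 256 (eth1)
	}
	out, _ := json.MarshalIndent(c, "", "  ")
	fmt.Println(string(out))
}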