// Mirror of https://github.com/rcourtman/Pulse.git
// Synced 2026-02-18 23:41:48 +01:00
// 152 lines, 4.7 KiB, Go
package api
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
pulsews "github.com/rcourtman/pulse-go-rewrite/internal/websocket"
|
|
)
|
|
|
|
func newWebSocketRouter(t *testing.T, allowedOrigins []string, tokenRecord config.APITokenRecord) (*httptest.Server, func()) {
|
|
t.Helper()
|
|
|
|
cfg := newTestConfigWithTokens(t, tokenRecord)
|
|
|
|
hub := pulsews.NewHub(nil)
|
|
hub.SetAllowedOrigins(allowedOrigins)
|
|
go hub.Run()
|
|
|
|
router := NewRouter(cfg, nil, nil, hub, nil, "1.0.0")
|
|
server := httptest.NewServer(router.Handler())
|
|
|
|
cleanup := func() {
|
|
server.Close()
|
|
hub.Stop()
|
|
}
|
|
|
|
return server, cleanup
|
|
}
|
|
|
|
func TestWebSocketOriginRejectedWhenNotAllowed(t *testing.T) {
|
|
rawToken := "ws-origin-reject-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{"https://allowed.example.com"}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://evil.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err == nil {
|
|
conn.Close()
|
|
t.Fatalf("expected websocket origin rejection")
|
|
}
|
|
if resp == nil {
|
|
t.Fatalf("expected HTTP response for rejected origin")
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketOriginAllowedWhenConfigured(t *testing.T) {
|
|
rawToken := "ws-origin-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{"https://allowed.example.com"}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://allowed.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestSocketIOWebSocketOriginRejected(t *testing.T) {
|
|
rawToken := "socket-origin-reject-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{"https://allowed.example.com"}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/socket.io/?transport=websocket"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://evil.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err == nil {
|
|
conn.Close()
|
|
t.Fatalf("expected websocket origin rejection for socket.io")
|
|
}
|
|
if resp == nil {
|
|
t.Fatalf("expected HTTP response for rejected socket.io origin")
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketOriginRejectedWhenNoAllowedOriginsAndPublicOrigin(t *testing.T) {
|
|
rawToken := "ws-origin-default-reject-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://evil.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err == nil {
|
|
conn.Close()
|
|
t.Fatalf("expected websocket origin rejection with empty allowed origins")
|
|
}
|
|
if resp == nil {
|
|
t.Fatalf("expected HTTP response for rejected origin")
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketOriginAllowsPrivateWhenNoAllowedOrigins(t *testing.T) {
|
|
rawToken := "ws-origin-default-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "http://localhost:3000")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|