mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-18 00:17:39 +01:00
Test: add coverage for auth and security handlers
Add additional tests for OIDC, SAML, and tenant middleware to improve coverage of security-critical paths.
This commit is contained in:
139
internal/api/middleware_tenant_additional_test.go
Normal file
139
internal/api/middleware_tenant_additional_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func TestNewTenantMiddlewareWithConfig(t *testing.T) {
|
||||
persistence := config.NewMultiTenantPersistence(t.TempDir())
|
||||
checker := stubAuthorizationChecker{}
|
||||
|
||||
mw := NewTenantMiddlewareWithConfig(TenantMiddlewareConfig{
|
||||
Persistence: persistence,
|
||||
AuthChecker: checker,
|
||||
})
|
||||
|
||||
if mw.persistence != persistence {
|
||||
t.Fatalf("expected persistence to be set")
|
||||
}
|
||||
if mw.authChecker == nil {
|
||||
t.Fatalf("expected auth checker to be set")
|
||||
}
|
||||
}
|
||||
|
||||
type stubAuthorizationChecker struct{}
|
||||
|
||||
func (stubAuthorizationChecker) TokenCanAccessOrg(*config.APITokenRecord, string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (stubAuthorizationChecker) UserCanAccessOrg(string, string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (stubAuthorizationChecker) CheckAccess(*config.APITokenRecord, string, string) AuthorizationResult {
|
||||
return AuthorizationResult{Allowed: true}
|
||||
}
|
||||
|
||||
func TestWriteJSONError(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
writeJSONError(rec, http.StatusBadRequest, "bad", "message")
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var payload map[string]string
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if payload["error"] != "bad" || payload["message"] != "message" {
|
||||
t.Fatalf("unexpected payload: %+v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantMiddleware_OrgExtraction(t *testing.T) {
|
||||
mw := NewTenantMiddleware(nil)
|
||||
|
||||
handler := mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := GetOrgID(r.Context())
|
||||
org := GetOrganization(r.Context())
|
||||
if orgID == "" || org == nil || org.ID != orgID {
|
||||
t.Fatalf("unexpected org context: %q %+v", orgID, org)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "header-org")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "pulse_org_id", Value: "cookie-org"})
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
func TestTenantMiddleware_InvalidOrg(t *testing.T) {
|
||||
persistence := config.NewMultiTenantPersistence(t.TempDir())
|
||||
mw := NewTenantMiddleware(persistence)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "../bad")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next handler should not be called")
|
||||
})).ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantMiddleware_MultiTenantDisabled(t *testing.T) {
|
||||
orig := IsMultiTenantEnabled()
|
||||
SetMultiTenantEnabled(false)
|
||||
t.Cleanup(func() { SetMultiTenantEnabled(orig) })
|
||||
|
||||
mw := NewTenantMiddleware(nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "tenant-1")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next handler should not be called")
|
||||
})).ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("expected 501, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantMiddleware_MultiTenantLicenseRequired(t *testing.T) {
|
||||
orig := IsMultiTenantEnabled()
|
||||
SetMultiTenantEnabled(true)
|
||||
t.Cleanup(func() { SetMultiTenantEnabled(orig) })
|
||||
|
||||
mw := NewTenantMiddleware(nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "tenant-2")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next handler should not be called")
|
||||
})).ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusPaymentRequired {
|
||||
t.Fatalf("expected 402, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
318
internal/api/oidc_handlers_more_test.go
Normal file
318
internal/api/oidc_handlers_more_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func newTestOIDCConfig() *config.OIDCConfig {
|
||||
return &config.OIDCConfig{
|
||||
Enabled: true,
|
||||
IssuerURL: "https://issuer.example.com",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
RedirectURL: "https://app.example.com/api/oidc/callback",
|
||||
Scopes: []string{"openid", "email"},
|
||||
UsernameClaim: "preferred_username",
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
}
|
||||
}
|
||||
|
||||
func newTestOIDCService(cfg *config.OIDCConfig, authURL, tokenURL string) *OIDCService {
|
||||
return &OIDCService{
|
||||
snapshot: oidcSnapshot{
|
||||
issuer: cfg.IssuerURL,
|
||||
clientID: cfg.ClientID,
|
||||
clientSecret: cfg.ClientSecret,
|
||||
redirectURL: cfg.RedirectURL,
|
||||
scopes: append([]string{}, cfg.Scopes...),
|
||||
caBundle: cfg.CABundle,
|
||||
},
|
||||
oauth2Cfg: &oauth2.Config{
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
RedirectURL: cfg.RedirectURL,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: authURL,
|
||||
TokenURL: tokenURL,
|
||||
},
|
||||
Scopes: append([]string{}, cfg.Scopes...),
|
||||
},
|
||||
stateStore: newOIDCStateStore(),
|
||||
}
|
||||
}
|
||||
|
||||
func newOIDCRouterWithService(t *testing.T, authURL, tokenURL string) (*Router, *OIDCService) {
|
||||
t.Helper()
|
||||
cfg := newTestOIDCConfig()
|
||||
svc := newTestOIDCService(cfg, authURL, tokenURL)
|
||||
router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc}
|
||||
t.Cleanup(func() {
|
||||
if svc.stateStore != nil {
|
||||
svc.stateStore.Stop()
|
||||
}
|
||||
})
|
||||
return router, svc
|
||||
}
|
||||
|
||||
func TestHandleOIDCLogin_MethodNotAllowed(t *testing.T) {
|
||||
router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/oidc/login", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCLogin(rec, req)
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCLogin_InvalidJSON(t *testing.T) {
|
||||
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", strings.NewReader("{"))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCLogin(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if payload["code"] != "invalid_request" {
|
||||
t.Fatalf("code = %v, want invalid_request", payload["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCLogin_GetSuccess(t *testing.T) {
|
||||
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/oidc/login?returnTo=/dashboard", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCLogin(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
||||
}
|
||||
location := rec.Header().Get("Location")
|
||||
u, err := url.Parse(location)
|
||||
if err != nil {
|
||||
t.Fatalf("parse location: %v", err)
|
||||
}
|
||||
if u.Host != "auth.example.com" {
|
||||
t.Fatalf("unexpected auth host: %q", u.Host)
|
||||
}
|
||||
state := u.Query().Get("state")
|
||||
if state == "" {
|
||||
t.Fatalf("expected state param in redirect")
|
||||
}
|
||||
entry, ok := svc.consumeState(state)
|
||||
if !ok {
|
||||
t.Fatalf("expected state entry to be stored")
|
||||
}
|
||||
if entry.ReturnTo != "/dashboard" {
|
||||
t.Fatalf("returnTo = %q, want /dashboard", entry.ReturnTo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCLogin_PostSuccess(t *testing.T) {
|
||||
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
||||
body := strings.NewReader(`{"returnTo":"/home"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", body)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCLogin(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
var payload struct {
|
||||
AuthorizationURL string `json:"authorizationUrl"`
|
||||
}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if payload.AuthorizationURL == "" {
|
||||
t.Fatalf("expected authorizationUrl in response")
|
||||
}
|
||||
u, err := url.Parse(payload.AuthorizationURL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse authorizationUrl: %v", err)
|
||||
}
|
||||
state := u.Query().Get("state")
|
||||
if state == "" {
|
||||
t.Fatalf("expected state param in authorizationUrl")
|
||||
}
|
||||
entry, ok := svc.consumeState(state)
|
||||
if !ok {
|
||||
t.Fatalf("expected state entry to be stored")
|
||||
}
|
||||
if entry.ReturnTo != "/home" {
|
||||
t.Fatalf("returnTo = %q, want /home", entry.ReturnTo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOIDCService_ReturnsCachedService(t *testing.T) {
|
||||
cfg := newTestOIDCConfig()
|
||||
svc := newTestOIDCService(cfg, "https://auth.example.com/authorize", "https://token.example.com")
|
||||
router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc}
|
||||
defer svc.stateStore.Stop()
|
||||
|
||||
got, err := router.getOIDCService(context.Background(), cfg.RedirectURL)
|
||||
if err != nil {
|
||||
t.Fatalf("getOIDCService error: %v", err)
|
||||
}
|
||||
if got != svc {
|
||||
t.Fatalf("expected cached service to be returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCCallback_MethodNotAllowed(t *testing.T) {
|
||||
router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/oidc/callback", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCCallback(rec, req)
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCCallback_ErrorParam(t *testing.T) {
|
||||
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?error=access_denied", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCCallback(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
||||
}
|
||||
location := rec.Header().Get("Location")
|
||||
if !strings.Contains(location, "oidc_error=access_denied") {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCCallback_MissingState(t *testing.T) {
|
||||
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?code=abc", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCCallback(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
||||
}
|
||||
location := rec.Header().Get("Location")
|
||||
if !strings.Contains(location, "oidc_error=missing_state") {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCCallback_InvalidState(t *testing.T) {
|
||||
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCCallback(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
||||
}
|
||||
location := rec.Header().Get("Location")
|
||||
if !strings.Contains(location, "oidc_error=invalid_state") {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCCallback_MissingCode(t *testing.T) {
|
||||
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
||||
state, _, err := svc.newStateEntry("/dashboard")
|
||||
if err != nil {
|
||||
t.Fatalf("newStateEntry error: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCCallback(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
||||
}
|
||||
location := rec.Header().Get("Location")
|
||||
if !strings.Contains(location, "oidc_error=missing_code") {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
if !strings.HasPrefix(location, "/dashboard") {
|
||||
t.Fatalf("expected redirect back to /dashboard, got %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCCallback_ExchangeFailed(t *testing.T) {
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
|
||||
svc.httpClient = tokenServer.Client()
|
||||
state, _, err := svc.newStateEntry("/dashboard")
|
||||
if err != nil {
|
||||
t.Fatalf("newStateEntry error: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCCallback(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
||||
}
|
||||
location := rec.Header().Get("Location")
|
||||
if !strings.Contains(location, "oidc_error=exchange_failed") {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOIDCCallback_MissingIDToken(t *testing.T) {
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"access_token":"access","token_type":"Bearer","expires_in":3600}`))
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
|
||||
svc.httpClient = tokenServer.Client()
|
||||
state, _, err := svc.newStateEntry("/dashboard")
|
||||
if err != nil {
|
||||
t.Fatalf("newStateEntry error: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCCallback(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
||||
}
|
||||
location := rec.Header().Get("Location")
|
||||
if !strings.Contains(location, "oidc_error=missing_id_token") {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
161
internal/api/router_auth_additional_test.go
Normal file
161
internal/api/router_auth_additional_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/pkg/auth"
|
||||
)
|
||||
|
||||
func newAuthRouter(t *testing.T) *Router {
|
||||
t.Helper()
|
||||
hashed, err := auth.HashPassword("currentpassword")
|
||||
if err != nil {
|
||||
t.Fatalf("hash password: %v", err)
|
||||
}
|
||||
return &Router{
|
||||
config: &config.Config{
|
||||
AuthUser: "admin",
|
||||
AuthPass: hashed,
|
||||
ConfigPath: t.TempDir(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChangePassword_InvalidJSON(t *testing.T) {
|
||||
router := newAuthRouter(t)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader("{"))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleChangePassword(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if payload["code"] != "invalid_request" {
|
||||
t.Fatalf("expected invalid_request, got %#v", payload["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChangePassword_InvalidPassword(t *testing.T) {
|
||||
router := newAuthRouter(t)
|
||||
body := `{"currentPassword":"currentpassword","newPassword":"short"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleChangePassword(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if payload["code"] != "invalid_password" {
|
||||
t.Fatalf("expected invalid_password, got %#v", payload["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChangePassword_MissingCurrent(t *testing.T) {
|
||||
router := newAuthRouter(t)
|
||||
body := `{"currentPassword":"","newPassword":"newpassword123"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleChangePassword(rec, req)
|
||||
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if payload["code"] != "unauthorized" {
|
||||
t.Fatalf("expected unauthorized, got %#v", payload["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChangePassword_SuccessDocker(t *testing.T) {
|
||||
router := newAuthRouter(t)
|
||||
t.Setenv("PULSE_DOCKER", "true")
|
||||
|
||||
body := `{"currentPassword":"currentpassword","newPassword":"newpassword123"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/change-password", strings.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleChangePassword(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if ok, _ := payload["success"].(bool); !ok {
|
||||
t.Fatalf("expected success=true, got %#v", payload["success"])
|
||||
}
|
||||
envPath := filepath.Join(router.config.ConfigPath, ".env")
|
||||
if _, err := os.Stat(envPath); err != nil {
|
||||
t.Fatalf("expected .env to be written, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResetLockout_InvalidJSON(t *testing.T) {
|
||||
router := newAuthRouter(t)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader("{"))
|
||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleResetLockout(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResetLockout_MissingIdentifier(t *testing.T) {
|
||||
router := newAuthRouter(t)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader(`{"identifier":""}`))
|
||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleResetLockout(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResetLockout_Success(t *testing.T) {
|
||||
router := newAuthRouter(t)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/reset-lockout", strings.NewReader(`{"identifier":"user1"}`))
|
||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:currentpassword")))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleResetLockout(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if ok, _ := payload["success"].(bool); !ok {
|
||||
t.Fatalf("expected success=true, got %#v", payload["success"])
|
||||
}
|
||||
}
|
||||
217
internal/api/saml_handlers_more_test.go
Normal file
217
internal/api/saml_handlers_more_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func newTestSAMLService(t *testing.T, providerID string, metadataXML string) *SAMLService {
|
||||
t.Helper()
|
||||
service, err := NewSAMLService(context.Background(), providerID, &config.SAMLProviderConfig{
|
||||
IDPMetadataXML: metadataXML,
|
||||
}, "https://pulse.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewSAMLService: %v", err)
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
func TestHandleSAMLACS_ProcessResponseError(t *testing.T) {
|
||||
router := newSAMLRouter(t, testSAMLProvider("okta", true))
|
||||
router.samlManager.services["okta"] = &SAMLService{}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleSAMLACS(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
|
||||
}
|
||||
if loc := rec.Header().Get("Location"); !strings.Contains(loc, "saml_error=saml_validation_failed") {
|
||||
t.Fatalf("expected validation failed redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSAMLMetadata_InvalidMethod(t *testing.T) {
|
||||
router := newSAMLRouter(t, testSAMLProvider("okta", true))
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/metadata", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleSAMLMetadata(rec, req)
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSAMLMetadata_InvalidProviderID(t *testing.T) {
|
||||
router := newSAMLRouter(t, testSAMLProvider("okta", true))
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/saml/invalid$id/metadata", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleSAMLMetadata(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSAMLSessionInfo_NoCookie(t *testing.T) {
|
||||
router := &Router{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
if info := router.getSAMLSessionInfo(req); info != nil {
|
||||
t.Fatalf("expected nil session info without cookie")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSAMLSessionInfo_ReturnsInfo(t *testing.T) {
|
||||
InitSessionStore(t.TempDir())
|
||||
|
||||
token := generateSessionToken()
|
||||
GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
|
||||
ProviderID: "okta",
|
||||
NameID: "name-id",
|
||||
SessionIndex: "sess-1",
|
||||
})
|
||||
|
||||
router := &Router{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
|
||||
|
||||
info := router.getSAMLSessionInfo(req)
|
||||
if info == nil {
|
||||
t.Fatalf("expected session info")
|
||||
}
|
||||
if info.ProviderID != "okta" || info.NameID != "name-id" || info.SessionIndex != "sess-1" {
|
||||
t.Fatalf("unexpected session info: %#v", info)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSession(t *testing.T) {
|
||||
router := &Router{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.clearSession(rec, req)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
if len(cookies) != 1 {
|
||||
t.Fatalf("expected 1 cookie, got %d", len(cookies))
|
||||
}
|
||||
cookie := cookies[0]
|
||||
if cookie.Name != "session" {
|
||||
t.Fatalf("expected session cookie name, got %q", cookie.Name)
|
||||
}
|
||||
if cookie.MaxAge != -1 {
|
||||
t.Fatalf("expected MaxAge -1, got %d", cookie.MaxAge)
|
||||
}
|
||||
if !cookie.HttpOnly {
|
||||
t.Fatalf("expected HttpOnly cookie")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSAMLSLO_Redirects(t *testing.T) {
|
||||
router := &Router{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/slo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleSAMLSLO(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
|
||||
}
|
||||
if loc := rec.Header().Get("Location"); loc != "/?logout=success" {
|
||||
t.Fatalf("unexpected redirect location %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSAMLLogout_SLOUnavailable(t *testing.T) {
|
||||
InitSessionStore(t.TempDir())
|
||||
|
||||
router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
|
||||
metadataXML := `<?xml version="1.0"?>
|
||||
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp">
|
||||
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
|
||||
</IDPSSODescriptor>
|
||||
</EntityDescriptor>`
|
||||
router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML)
|
||||
|
||||
token := generateSessionToken()
|
||||
GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
|
||||
ProviderID: "okta",
|
||||
NameID: "name-id",
|
||||
SessionIndex: "sess-1",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleSAMLLogout(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
|
||||
}
|
||||
if loc := rec.Header().Get("Location"); loc != "/?logout=success" {
|
||||
t.Fatalf("unexpected redirect location %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSAMLLogout_SLOSuccess(t *testing.T) {
|
||||
InitSessionStore(t.TempDir())
|
||||
|
||||
router := &Router{samlManager: NewSAMLServiceManager("https://pulse.example.com")}
|
||||
metadataXML := `<?xml version="1.0"?>
|
||||
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp">
|
||||
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
|
||||
<SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
|
||||
</IDPSSODescriptor>
|
||||
</EntityDescriptor>`
|
||||
router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML)
|
||||
|
||||
token := generateSessionToken()
|
||||
GetSessionStore().CreateSAMLSession(token, time.Hour, "agent", "127.0.0.1", "user", &SAMLTokenInfo{
|
||||
ProviderID: "okta",
|
||||
NameID: "name-id",
|
||||
SessionIndex: "sess-1",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/logout", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "pulse_session", Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.handleSAMLLogout(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusFound, rec.Code)
|
||||
}
|
||||
loc := rec.Header().Get("Location")
|
||||
if !strings.Contains(loc, "https://idp.example.com/slo") || !strings.Contains(loc, "SAMLRequest=") {
|
||||
t.Fatalf("unexpected SLO redirect location %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSAMLProviderID(t *testing.T) {
|
||||
if got := extractSAMLProviderID("/api/saml/okta/login", "login"); got != "okta" {
|
||||
t.Fatalf("expected okta, got %q", got)
|
||||
}
|
||||
if got := extractSAMLProviderID("/api/saml/okta/logout", "login"); got != "" {
|
||||
t.Fatalf("expected empty provider, got %q", got)
|
||||
}
|
||||
if got := extractSAMLProviderID("/api/saml/okta/login/extra", "login"); got != "okta" {
|
||||
t.Fatalf("expected okta for extra path, got %q", got)
|
||||
}
|
||||
if got := extractSAMLProviderID("/api/other/okta/login", "login"); got != "" {
|
||||
t.Fatalf("expected empty provider for non-saml path, got %q", got)
|
||||
}
|
||||
}
|
||||
133
internal/api/saml_service_additional_test.go
Normal file
133
internal/api/saml_service_additional_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func TestParseIDPMetadataXML_EmptyEntities(t *testing.T) {
|
||||
wrapped := `<?xml version="1.0"?>
|
||||
<EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata">
|
||||
</EntitiesDescriptor>`
|
||||
if _, err := parseIDPMetadataXML([]byte(wrapped)); err == nil {
|
||||
t.Fatal("expected error for empty entities descriptor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractAttribute(t *testing.T) {
|
||||
service := &SAMLService{}
|
||||
attrs := map[string][]string{
|
||||
"email": {"user@example.com"},
|
||||
}
|
||||
|
||||
if got := service.extractAttribute(attrs, "", "fallback"); got != "fallback" {
|
||||
t.Fatalf("expected fallback value, got %q", got)
|
||||
}
|
||||
if got := service.extractAttribute(attrs, "email", ""); got != "user@example.com" {
|
||||
t.Fatalf("unexpected attribute value: %q", got)
|
||||
}
|
||||
if got := service.extractAttribute(attrs, "missing", "default"); got != "default" {
|
||||
t.Fatalf("unexpected missing attribute value: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessResponse_InvalidResponse(t *testing.T) {
|
||||
service := &SAMLService{
|
||||
sp: &saml.ServiceProvider{},
|
||||
}
|
||||
|
||||
body := strings.NewReader("RelayState=/dashboard&SAMLResponse=")
|
||||
req := httptest.NewRequest(http.MethodPost, "/acs", body)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
_, relay, err := service.ProcessResponse(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid response")
|
||||
}
|
||||
if relay != "/dashboard" {
|
||||
t.Fatalf("unexpected relay state: %q", relay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLServiceIdentifiers(t *testing.T) {
|
||||
service := &SAMLService{
|
||||
providerID: "provider-1",
|
||||
sp: &saml.ServiceProvider{EntityID: "sp-entity"},
|
||||
idpMetadata: &saml.EntityDescriptor{
|
||||
EntityID: "idp-entity",
|
||||
},
|
||||
}
|
||||
|
||||
if service.ProviderID() != "provider-1" {
|
||||
t.Fatalf("unexpected provider id")
|
||||
}
|
||||
if service.GetSPEntityID() != "sp-entity" {
|
||||
t.Fatalf("unexpected sp entity id")
|
||||
}
|
||||
if service.GetIDPEntityID() != "idp-entity" {
|
||||
t.Fatalf("unexpected idp entity id")
|
||||
}
|
||||
|
||||
service.sp = nil
|
||||
if service.GetSPEntityID() != "" {
|
||||
t.Fatalf("expected empty sp entity id when nil")
|
||||
}
|
||||
service.idpMetadata = nil
|
||||
if service.GetIDPEntityID() != "" {
|
||||
t.Fatalf("expected empty idp entity id when nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshMetadata_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<?xml version="1.0"?>
|
||||
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp-refresh">
|
||||
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"></IDPSSODescriptor>
|
||||
</EntityDescriptor>`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.SAMLProviderConfig{IDPMetadataURL: server.URL}
|
||||
service := &SAMLService{
|
||||
providerID: "idp",
|
||||
config: cfg,
|
||||
baseURL: "http://localhost",
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
|
||||
if err := service.RefreshMetadata(context.Background()); err != nil {
|
||||
t.Fatalf("refresh metadata: %v", err)
|
||||
}
|
||||
if service.sp == nil {
|
||||
t.Fatal("expected service provider to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchIDPMetadataFromURL_NonOK(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
service := &SAMLService{httpClient: server.Client()}
|
||||
if _, err := service.fetchIDPMetadataFromURL(context.Background(), server.URL); err == nil {
|
||||
t.Fatal("expected error for non-200 status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddIDPCertificate_InvalidPEM(t *testing.T) {
|
||||
service := &SAMLService{config: &config.SAMLProviderConfig{IDPCertificate: "not-pem"}}
|
||||
metadata := &saml.EntityDescriptor{
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{{}},
|
||||
}
|
||||
if err := service.addIDPCertificate(metadata); err == nil {
|
||||
t.Fatal("expected error for invalid certificate pem")
|
||||
}
|
||||
}
|
||||
95
internal/api/security_oidc_handlers_additional_test.go
Normal file
95
internal/api/security_oidc_handlers_additional_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func TestSecurityOIDCHandlers_GetConfig(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
PublicURL: "https://pulse.example.com",
|
||||
OIDC: &config.OIDCConfig{
|
||||
Enabled: true,
|
||||
IssuerURL: "https://issuer.example.com",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "super-secret",
|
||||
RedirectURL: "https://pulse.example.com/oidc/callback",
|
||||
Scopes: []string{"openid"},
|
||||
UsernameClaim: "sub",
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
},
|
||||
}
|
||||
router := &Router{config: cfg}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/security/oidc", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCConfig(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
var resp oidcResponse
|
||||
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if !resp.Enabled {
|
||||
t.Fatalf("expected OIDC to be enabled")
|
||||
}
|
||||
if resp.IssuerURL != cfg.OIDC.IssuerURL {
|
||||
t.Fatalf("unexpected issuer url")
|
||||
}
|
||||
if !resp.ClientSecretSet {
|
||||
t.Fatalf("expected client secret to be marked as set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityOIDCHandlers_UpdateSaveFailure(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
PublicURL: "https://pulse.example.com",
|
||||
OIDC: &config.OIDCConfig{
|
||||
ClientSecret: "old-secret",
|
||||
CABundle: "old-bundle",
|
||||
},
|
||||
}
|
||||
router := &Router{config: cfg}
|
||||
|
||||
payload := map[string]any{
|
||||
"enabled": true,
|
||||
"issuerUrl": "https://issuer.example.com",
|
||||
"clientId": "client-id",
|
||||
"clientSecret": "new-secret",
|
||||
"redirectUrl": "https://pulse.example.com/oidc/callback",
|
||||
"scopes": []string{"openid"},
|
||||
"usernameClaim": "sub",
|
||||
"emailClaim": "email",
|
||||
"groupsClaim": "groups",
|
||||
"clearClientSecret": true,
|
||||
"caBundle": "new-bundle",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/security/oidc", bytes.NewReader(body))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
router.handleOIDCConfig(rr, req)
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
|
||||
var apiErr APIError
|
||||
if err := json.NewDecoder(rr.Body).Decode(&apiErr); err != nil {
|
||||
t.Fatalf("decode error response: %v", err)
|
||||
}
|
||||
if apiErr.Code != "save_failed" {
|
||||
t.Fatalf("unexpected error code: %s", apiErr.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user