From 0df04de4bfe3964b73530cf829f70ffae7d57802 Mon Sep 17 00:00:00 2001 From: rcourtman Date: Tue, 10 Feb 2026 22:10:24 +0000 Subject: [PATCH] feat(cloud): implement Hosted MSP account layer (M-1 through M-6) Add the full Hosted MSP portal backend to the control plane. MSP accounts own multiple container-per-tenant workspaces with consolidated billing, account-level RBAC, and signed JWT handoff for tenant switching. - M-1: Account model + schema (accounts, users, account_memberships tables, tenant.account_id FK, ID generators, full CRUD) - M-2: Account-level RBAC handlers (invite/update/remove members, last-owner protection) - M-3: Portal dashboard API (workspace list with fleet health summaries, workspace detail) - M-4: Tenant handoff auth (JWT HS256 minting with per-tenant secret, auto-submit HTML POST binding, tenant-side exchange endpoint with JTI replay protection) - M-5: Workspace provisioning from portal (create/list/update/delete workspaces, handoff.key generation, billing.json, Docker lifecycle) - M-6: Consolidated billing (stripe_accounts table, MSP-aware webhook routing, subscription state propagation across account tenants) --- go.mod | 1 + go.sum | 2 + internal/api/cloud_handoff_handlers.go | 233 +++++++ internal/cloudcp/account/handlers.go | 346 ++++++++++ internal/cloudcp/account/handlers_test.go | 343 +++++++++ internal/cloudcp/account/tenant_handlers.go | 279 ++++++++ .../cloudcp/account/tenant_handlers_test.go | 220 ++++++ internal/cloudcp/handoff/handler.go | 165 +++++ internal/cloudcp/handoff/handler_test.go | 116 ++++ internal/cloudcp/handoff/handoff.go | 131 ++++ internal/cloudcp/handoff/handoff_test.go | 167 +++++ internal/cloudcp/portal/handlers.go | 202 ++++++ internal/cloudcp/portal/handlers_test.go | 310 +++++++++ internal/cloudcp/registry/models.go | 87 +++ internal/cloudcp/registry/registry.go | 649 +++++++++++++++++- internal/cloudcp/registry/registry_test.go | 390 +++++++++++ internal/cloudcp/routes.go | 10 + internal/cloudcp/stripe/webhook.go | 45 +- 18 files changed, 3684 insertions(+), 12 deletions(-) create mode 100644 internal/api/cloud_handoff_handlers.go create mode 100644 internal/cloudcp/account/handlers.go create mode 100644 internal/cloudcp/account/handlers_test.go create mode 100644 internal/cloudcp/account/tenant_handlers.go create mode 100644 internal/cloudcp/account/tenant_handlers_test.go create mode 100644 internal/cloudcp/handoff/handler.go create mode 100644 internal/cloudcp/handoff/handler_test.go create mode 100644 internal/cloudcp/handoff/handoff.go create mode 100644 internal/cloudcp/handoff/handoff_test.go create mode 100644 internal/cloudcp/portal/handlers.go create mode 100644 internal/cloudcp/portal/handlers_test.go diff --git a/go.mod b/go.mod index bf029497f..5fac24d27 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,7 @@ require ( github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.23.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/go-cmp v0.7.0 // indirect diff --git a/go.sum b/go.sum index 8b51c1da1..b6318ba29 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= diff --git a/internal/api/cloud_handoff_handlers.go b/internal/api/cloud_handoff_handlers.go new file mode 100644 index 000000000..97209c5ac --- /dev/null +++ b/internal/api/cloud_handoff_handlers.go @@ -0,0 +1,233 @@ +package api + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" + + _ "modernc.org/sqlite" +) + +const ( + cloudHandoffIssuer = "pulse-cloud-control-plane" +) + +type cloudHandoffClaims struct { + AccountID string `json:"account_id"` + Email string `json:"email"` + Role string `json:"role"` + jwt.RegisteredClaims +} + +type jtiReplayStore struct { + once sync.Once + db *sql.DB + mu sync.Mutex + + configDir string + initErr error +} + +func (s *jtiReplayStore) init() { + s.once.Do(func() { + dir := filepath.Clean(s.configDir) + if strings.TrimSpace(dir) == "" { + s.initErr = fmt.Errorf("configDir is required") + return + } + if err := os.MkdirAll(dir, 0o755); err != nil { + s.initErr = fmt.Errorf("create config dir: %w", err) + return + } + + dbPath := filepath.Join(dir, "handoff_jti.db") + dsn := dbPath + "?" + url.Values{ + "_pragma": []string{ + "busy_timeout(30000)", + "journal_mode(WAL)", + "synchronous(NORMAL)", + }, + }.Encode() + + db, err := sql.Open("sqlite", dsn) + if err != nil { + s.initErr = fmt.Errorf("open handoff jti db: %w", err) + return + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(0) + + schema := ` + CREATE TABLE IF NOT EXISTS handoff_jti ( + jti TEXT PRIMARY KEY, + expires_at INTEGER NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_handoff_jti_expires_at ON handoff_jti(expires_at); + ` + if _, err := db.Exec(schema); err != nil { + _ = db.Close() + s.initErr = fmt.Errorf("init handoff jti schema: %w", err) + return + } + + s.db = db + }) +} + +func (s *jtiReplayStore) checkAndStore(jti string, expiresAt time.Time) (stored bool, err error) { + s.init() + if s.initErr != nil { + return false, s.initErr + } + if s.db == nil { + return false, fmt.Errorf("handoff jti store not initialized") + } + jti = strings.TrimSpace(jti) + if jti == "" { + return false, fmt.Errorf("jti is required") + } + expiresAt = expiresAt.UTC() + + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now().UTC().Unix() + if _, err := s.db.Exec(`DELETE FROM handoff_jti WHERE expires_at <= ?`, now); err != nil { + return false, fmt.Errorf("cleanup handoff jti: %w", err) + } + + _, err = s.db.Exec(`INSERT INTO handoff_jti (jti, expires_at) VALUES (?, ?)`, jti, expiresAt.Unix()) + if err != nil { + if isSQLiteUniqueViolation(err) { + return false, nil + } + return false, fmt.Errorf("store handoff jti: %w", err) + } + return true, nil +} + +func isSQLiteUniqueViolation(err error) bool { + if err == nil { + return false + } + s := err.Error() + return strings.Contains(s, "UNIQUE constraint failed") || strings.Contains(s, "constraint failed") +} + +func tenantIDFromRequest(r *http.Request) string { + if v := strings.TrimSpace(os.Getenv("PULSE_TENANT_ID")); v != "" { + return v + } + host := strings.TrimSpace(r.Host) + if host == "" { + return "" + } + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + // Host is expected to be ".". + if i := strings.IndexByte(host, '.'); i > 0 { + return host[:i] + } + return host +} + +// HandleHandoffExchange verifies a control-plane-minted handoff JWT and records its jti to prevent replay. +// +// Route (wiring happens elsewhere): POST /api/cloud/handoff/exchange +// +// This minimal implementation returns success JSON containing user info derived from the token. +// Wiring into RBAC/session minting is intentionally deferred. +func HandleHandoffExchange(configDir string) http.HandlerFunc { + configDir = filepath.Clean(configDir) + keyPath := filepath.Join(configDir, "secrets", "handoff.key") + replay := &jtiReplayStore{configDir: configDir} + + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + key, err := os.ReadFile(keyPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + http.Error(w, "not found", http.StatusNotFound) + return + } + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + tenantID := tenantIDFromRequest(r) + if tenantID == "" { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + tokenStr := strings.TrimSpace(r.FormValue("token")) + if tokenStr == "" { + http.Error(w, "missing token", http.StatusBadRequest) + return + } + + var claims cloudHandoffClaims + parsed, err := jwt.ParseWithClaims( + tokenStr, + &claims, + func(t *jwt.Token) (any, error) { return key, nil }, + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithIssuer(cloudHandoffIssuer), + jwt.WithAudience(tenantID), + ) + if err != nil || parsed == nil || !parsed.Valid { + http.Error(w, "invalid token", http.StatusUnauthorized) + return + } + + if claims.ExpiresAt == nil { + http.Error(w, "invalid token", http.StatusUnauthorized) + return + } + + ok, err := replay.checkAndStore(claims.ID, claims.ExpiresAt.Time) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if !ok { + http.Error(w, "replayed token", http.StatusUnauthorized) + return + } + + resp := map[string]any{ + "ok": true, + "tenant_id": tenantID, + "account_id": claims.AccountID, + "user_id": claims.Subject, + "email": claims.Email, + "role": claims.Role, + "jti": claims.ID, + "exp": claims.ExpiresAt.Time.UTC().Format(time.RFC3339), + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + } +} diff --git a/internal/cloudcp/account/handlers.go b/internal/cloudcp/account/handlers.go new file mode 100644 index 000000000..c449ee25d --- /dev/null +++ b/internal/cloudcp/account/handlers.go @@ -0,0 +1,346 @@ +package account + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "strings" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +type memberResponse struct { + UserID string `json:"user_id"` + Email string `json:"email"` + Role registry.MemberRole `json:"role"` + CreatedAt time.Time `json:"created_at"` +} + +// HandleListMembers returns an authenticated handler that lists all members of an account. +func HandleListMembers(reg *registry.TenantRegistry) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + if accountID == "" { + http.Error(w, "missing account_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + memberships, err := reg.ListMembersByAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if memberships == nil { + memberships = []*registry.AccountMembership{} + } + + resp := make([]memberResponse, 0, len(memberships)) + for _, m := range memberships { + u, err := reg.GetUser(m.UserID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if u == nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + resp = append(resp, memberResponse{ + UserID: m.UserID, + Email: u.Email, + Role: m.Role, + CreatedAt: m.CreatedAt, + }) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + } +} + +type inviteMemberRequest struct { + Email string `json:"email"` + Role string `json:"role"` +} + +// HandleInviteMember returns an authenticated handler that invites a user to an account. +func HandleInviteMember(reg *registry.TenantRegistry) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + if accountID == "" { + http.Error(w, "missing account_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + var req inviteMemberRequest + if err := decodeJSON(w, r, &req); err != nil { + return + } + + email := normalizeEmail(req.Email) + if email == "" { + http.Error(w, "invalid email", http.StatusBadRequest) + return + } + + role, ok := parseMemberRole(req.Role) + if !ok { + http.Error(w, "invalid role", http.StatusBadRequest) + return + } + + u, err := reg.GetUserByEmail(email) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if u == nil { + userID, err := registry.GenerateUserID() + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + u = ®istry.User{ + ID: userID, + Email: email, + } + if err := reg.CreateUser(u); err != nil { + // If a concurrent request created the user, fall back to lookup. + u2, gerr := reg.GetUserByEmail(email) + if gerr != nil || u2 == nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + u = u2 + } + } + + if err := reg.CreateMembership(®istry.AccountMembership{ + AccountID: accountID, + UserID: u.ID, + Role: role, + }); err != nil { + if isUniqueViolation(err) { + http.Error(w, "membership already exists", http.StatusConflict) + return + } + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusCreated) + } +} + +type updateMemberRoleRequest struct { + Role string `json:"role"` +} + +// HandleUpdateMemberRole returns an authenticated handler that updates a member's role. +func HandleUpdateMemberRole(reg *registry.TenantRegistry) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPatch { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + userID := strings.TrimSpace(r.PathValue("user_id")) + if accountID == "" || userID == "" { + http.Error(w, "missing account_id or user_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + var req updateMemberRoleRequest + if err := decodeJSON(w, r, &req); err != nil { + return + } + + role, ok := parseMemberRole(req.Role) + if !ok { + http.Error(w, "invalid role", http.StatusBadRequest) + return + } + + if err := reg.UpdateMembershipRole(accountID, userID, role); err != nil { + if isNotFoundErr(err) { + http.Error(w, "membership not found", http.StatusNotFound) + return + } + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + } +} + +// HandleRemoveMember returns an authenticated handler that removes a user from an account. +func HandleRemoveMember(reg *registry.TenantRegistry) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + userID := strings.TrimSpace(r.PathValue("user_id")) + if accountID == "" || userID == "" { + http.Error(w, "missing account_id or user_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + m, err := reg.GetMembership(accountID, userID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if m == nil { + http.Error(w, "membership not found", http.StatusNotFound) + return + } + + if m.Role == registry.MemberRoleOwner { + memberships, err := reg.ListMembersByAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + owners := 0 + for _, mm := range memberships { + if mm.Role == registry.MemberRoleOwner { + owners++ + } + } + if owners <= 1 { + http.Error(w, "cannot remove last owner", http.StatusConflict) + return + } + } + + if err := reg.DeleteMembership(accountID, userID); err != nil { + if isNotFoundErr(err) { + http.Error(w, "membership not found", http.StatusNotFound) + return + } + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) + } +} + +func normalizeEmail(s string) string { + s = strings.TrimSpace(s) + s = strings.ToLower(s) + // Minimal sanity; deeper validation comes later with session auth flows. + if s == "" || !strings.Contains(s, "@") { + return "" + } + return s +} + +func parseMemberRole(s string) (registry.MemberRole, bool) { + switch registry.MemberRole(strings.TrimSpace(s)) { + case registry.MemberRoleOwner: + return registry.MemberRoleOwner, true + case registry.MemberRoleAdmin: + return registry.MemberRoleAdmin, true + case registry.MemberRoleTech: + return registry.MemberRoleTech, true + case registry.MemberRoleReadOnly: + return registry.MemberRoleReadOnly, true + default: + return "", false + } +} + +func decodeJSON(w http.ResponseWriter, r *http.Request, dst any) error { + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(dst); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return err + } + if err := dec.Decode(&struct{}{}); err != io.EOF { + if err == nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return errors.New("multiple JSON values") + } + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return err + } + return nil +} + +func isNotFoundErr(err error) bool { + if err == nil { + return false + } + // Registry uses fmt.Errorf("... not found") (no sentinel errors yet). + return strings.Contains(err.Error(), "not found") +} + +func isUniqueViolation(err error) bool { + if err == nil { + return false + } + // modernc.org/sqlite returns strings containing "UNIQUE constraint failed". + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "unique constraint failed") +} diff --git a/internal/cloudcp/account/handlers_test.go b/internal/cloudcp/account/handlers_test.go new file mode 100644 index 000000000..095160f2d --- /dev/null +++ b/internal/cloudcp/account/handlers_test.go @@ -0,0 +1,343 @@ +package account + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "sort" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +func newTestRegistry(t *testing.T) *registry.TenantRegistry { + t.Helper() + dir := t.TempDir() + reg, err := registry.NewTenantRegistry(dir) + if err != nil { + t.Fatalf("NewTenantRegistry: %v", err) + } + t.Cleanup(func() { _ = reg.Close() }) + return reg +} + +func newTestMux(reg *registry.TenantRegistry) *http.ServeMux { + mux := http.NewServeMux() + + listMembers := HandleListMembers(reg) + inviteMember := HandleInviteMember(reg) + updateRole := HandleUpdateMemberRole(reg) + removeMember := HandleRemoveMember(reg) + + collection := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + listMembers(w, r) + case http.MethodPost: + inviteMember(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + member := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPatch: + updateRole(w, r) + case http.MethodDelete: + removeMember(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + + mux.Handle("/api/accounts/{account_id}/members", admin.AdminKeyMiddleware("secret-key", collection)) + mux.Handle("/api/accounts/{account_id}/members/{user_id}", admin.AdminKeyMiddleware("secret-key", member)) + return mux +} + +func doRequest(t *testing.T, h http.Handler, req *http.Request) *httptest.ResponseRecorder { + t.Helper() + req.Header.Set("X-Admin-Key", "secret-key") + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + return rec +} + +func TestInviteMember(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil { + t.Fatal(err) + } + + body := `{"email":"tech@msp.com","role":"tech"}` + req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/members", bytes.NewBufferString(body)) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String()) + } + + u, err := reg.GetUserByEmail("tech@msp.com") + if err != nil { + t.Fatal(err) + } + if u == nil { + t.Fatal("expected user to be created") + } + + m, err := reg.GetMembership(accountID, u.ID) + if err != nil { + t.Fatal(err) + } + if m == nil { + t.Fatal("expected membership to be created") + } + if m.Role != registry.MemberRoleTech { + t.Fatalf("role = %q, want %q", m.Role, registry.MemberRoleTech) + } +} + +func TestInviteExistingUser(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil { + t.Fatal(err) + } + + userID, err := registry.GenerateUserID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(®istry.User{ID: userID, Email: "existing@msp.com"}); err != nil { + t.Fatal(err) + } + + body := `{"email":"existing@msp.com","role":"tech"}` + req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/members", bytes.NewBufferString(body)) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String()) + } + + u, err := reg.GetUserByEmail("existing@msp.com") + if err != nil { + t.Fatal(err) + } + if u == nil || u.ID != userID { + t.Fatalf("user = %+v, want id=%q", u, userID) + } + + m, err := reg.GetMembership(accountID, userID) + if err != nil { + t.Fatal(err) + } + if m == nil { + t.Fatal("expected membership to be created") + } +} + +func TestListMembers(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil { + t.Fatal(err) + } + + u1ID, err := registry.GenerateUserID() + if err != nil { + t.Fatal(err) + } + u2ID, err := registry.GenerateUserID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(®istry.User{ID: u1ID, Email: "owner@msp.com"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(®istry.User{ID: u2ID, Email: "tech@msp.com"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: u1ID, Role: registry.MemberRoleOwner}); err != nil { + t.Fatal(err) + } + if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: u2ID, Role: registry.MemberRoleTech}); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/accounts/"+accountID+"/members", nil) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String()) + } + + var got []struct { + UserID string `json:"user_id"` + Email string `json:"email"` + Role registry.MemberRole `json:"role"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(got) != 2 { + t.Fatalf("expected 2 members, got %d (%+v)", len(got), got) + } + + sort.Slice(got, func(i, j int) bool { return got[i].Email < got[j].Email }) + if got[0].Email != "owner@msp.com" || got[0].Role != registry.MemberRoleOwner { + t.Fatalf("member[0]=%+v, want owner@msp.com owner", got[0]) + } + if got[1].Email != "tech@msp.com" || got[1].Role != registry.MemberRoleTech { + t.Fatalf("member[1]=%+v, want tech@msp.com tech", got[1]) + } +} + +func TestUpdateMemberRole(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil { + t.Fatal(err) + } + + userID, err := registry.GenerateUserID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(®istry.User{ID: userID, Email: "tech@msp.com"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: userID, Role: registry.MemberRoleTech}); err != nil { + t.Fatal(err) + } + + body := `{"role":"admin"}` + req := httptest.NewRequest(http.MethodPatch, "/api/accounts/"+accountID+"/members/"+userID, bytes.NewBufferString(body)) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String()) + } + + m, err := reg.GetMembership(accountID, userID) + if err != nil { + t.Fatal(err) + } + if m == nil { + t.Fatal("expected membership to exist") + } + if m.Role != registry.MemberRoleAdmin { + t.Fatalf("role = %q, want %q", m.Role, registry.MemberRoleAdmin) + } +} + +func TestRemoveMember(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil { + t.Fatal(err) + } + + ownerID, err := registry.GenerateUserID() + if err != nil { + t.Fatal(err) + } + techID, err := registry.GenerateUserID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(®istry.User{ID: ownerID, Email: "owner@msp.com"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(®istry.User{ID: techID, Email: "tech@msp.com"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: ownerID, Role: registry.MemberRoleOwner}); err != nil { + t.Fatal(err) + } + if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: techID, Role: registry.MemberRoleTech}); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/members/"+techID, nil) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusNoContent, rec.Body.String()) + } + + m, err := reg.GetMembership(accountID, techID) + if err != nil { + t.Fatal(err) + } + if m != nil { + t.Fatalf("expected membership to be deleted, got %+v", m) + } +} + +func TestCannotRemoveLastOwner(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil { + t.Fatal(err) + } + + ownerID, err := registry.GenerateUserID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(®istry.User{ID: ownerID, Email: "owner@msp.com"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: ownerID, Role: registry.MemberRoleOwner}); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/members/"+ownerID, nil) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusConflict && rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d or %d (body=%q)", rec.Code, http.StatusConflict, http.StatusBadRequest, rec.Body.String()) + } + + m, err := reg.GetMembership(accountID, ownerID) + if err != nil { + t.Fatal(err) + } + if m == nil { + t.Fatal("expected owner membership to remain") + } +} diff --git a/internal/cloudcp/account/tenant_handlers.go b/internal/cloudcp/account/tenant_handlers.go new file mode 100644 index 000000000..d77f3c4bc --- /dev/null +++ b/internal/cloudcp/account/tenant_handlers.go @@ -0,0 +1,279 @@ +package account + +import ( + "context" + "encoding/json" + "net/http" + "strings" + + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +// WorkspaceProvisioner is the minimal interface needed by the MSP portal tenant handlers. +// Implemented by internal/cloudcp/stripe.Provisioner. +type WorkspaceProvisioner interface { + ProvisionWorkspace(ctx context.Context, accountID, displayName string) (*registry.Tenant, error) + DeprovisionWorkspaceContainer(ctx context.Context, tenant *registry.Tenant) error +} + +// HandleListTenants lists all tenants for an account. +// Route: GET /api/accounts/{account_id}/tenants +func HandleListTenants(reg *registry.TenantRegistry) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + if accountID == "" { + http.Error(w, "missing account_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + tenants, err := reg.ListByAccountID(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if tenants == nil { + tenants = []*registry.Tenant{} + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(tenants) + } +} + +type createTenantRequest struct { + DisplayName string `json:"display_name"` +} + +// HandleCreateTenant creates a new tenant under an account. +// Route: POST /api/accounts/{account_id}/tenants +func HandleCreateTenant(reg *registry.TenantRegistry, provisioner WorkspaceProvisioner) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + if accountID == "" { + http.Error(w, "missing account_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + var req createTenantRequest + if err := decodeJSON(w, r, &req); err != nil { + return + } + displayName := strings.TrimSpace(req.DisplayName) + if displayName == "" { + http.Error(w, "invalid display_name", http.StatusBadRequest) + return + } + if provisioner == nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + tenant, err := provisioner.ProvisionWorkspace(r.Context(), accountID, displayName) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(tenant) + } +} + +type updateTenantRequest struct { + DisplayName *string `json:"display_name,omitempty"` + Status *string `json:"status,omitempty"` + State *string `json:"state,omitempty"` +} + +func parseTenantState(s string) (registry.TenantState, bool) { + switch registry.TenantState(strings.TrimSpace(s)) { + case registry.TenantStateActive: + return registry.TenantStateActive, true + case registry.TenantStateSuspended: + return registry.TenantStateSuspended, true + default: + return "", false + } +} + +func loadTenantForAccount(reg *registry.TenantRegistry, accountID, tenantID string) (*registry.Tenant, error) { + t, err := reg.Get(tenantID) + if err != nil { + return nil, err + } + if t == nil { + return nil, nil + } + if strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID { + return nil, nil + } + return t, nil +} + +// HandleUpdateTenant updates display name and/or state. +// Route: PATCH /api/accounts/{account_id}/tenants/{tenant_id} +func HandleUpdateTenant(reg *registry.TenantRegistry) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPatch { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + tenantID := strings.TrimSpace(r.PathValue("tenant_id")) + if accountID == "" || tenantID == "" { + http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + tenant, err := loadTenantForAccount(reg, accountID, tenantID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if tenant == nil { + http.Error(w, "tenant not found", http.StatusNotFound) + return + } + + var req updateTenantRequest + if err := decodeJSON(w, r, &req); err != nil { + return + } + + if req.DisplayName != nil { + name := strings.TrimSpace(*req.DisplayName) + if name == "" { + http.Error(w, "invalid display_name", http.StatusBadRequest) + return + } + tenant.DisplayName = name + } + + stateVal := req.Status + if stateVal == nil { + stateVal = req.State + } + if stateVal != nil { + st, ok := parseTenantState(*stateVal) + if !ok { + http.Error(w, "invalid status", http.StatusBadRequest) + return + } + tenant.State = st + } + + if err := reg.Update(tenant); err != nil { + if isNotFoundErr(err) { + http.Error(w, "tenant not found", http.StatusNotFound) + return + } + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(tenant) + } +} + +// HandleDeleteTenant soft-deletes a tenant and deprovisions its container if Docker is available. +// Route: DELETE /api/accounts/{account_id}/tenants/{tenant_id} +func HandleDeleteTenant(reg *registry.TenantRegistry, provisioner WorkspaceProvisioner) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + tenantID := strings.TrimSpace(r.PathValue("tenant_id")) + if accountID == "" || tenantID == "" { + http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + tenant, err := loadTenantForAccount(reg, accountID, tenantID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if tenant == nil { + http.Error(w, "tenant not found", http.StatusNotFound) + return + } + + tenant.State = registry.TenantStateDeleting + if err := reg.Update(tenant); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + if provisioner != nil { + if err := provisioner.DeprovisionWorkspaceContainer(r.Context(), tenant); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + } + + tenant.ContainerID = "" + tenant.State = registry.TenantStateDeleted + if err := reg.Update(tenant); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) + } +} diff --git a/internal/cloudcp/account/tenant_handlers_test.go b/internal/cloudcp/account/tenant_handlers_test.go new file mode 100644 index 000000000..fedafaef5 --- /dev/null +++ b/internal/cloudcp/account/tenant_handlers_test.go @@ -0,0 +1,220 @@ +package account + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" + cpstripe "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/stripe" +) + +func newTestTenantMux(reg *registry.TenantRegistry, tenantsDir string) (*http.ServeMux, *cpstripe.Provisioner) { + mux := http.NewServeMux() + provisioner := cpstripe.NewProvisioner(reg, tenantsDir, nil, nil, "https://cloud.example.com") + + listTenants := HandleListTenants(reg) + createTenant := HandleCreateTenant(reg, provisioner) + updateTenant := HandleUpdateTenant(reg) + deleteTenant := HandleDeleteTenant(reg, provisioner) + + collection := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + listTenants(w, r) + case http.MethodPost: + createTenant(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + tenant := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPatch: + updateTenant(w, r) + case http.MethodDelete: + deleteTenant(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + + mux.Handle("/api/accounts/{account_id}/tenants", admin.AdminKeyMiddleware("secret-key", collection)) + mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}", admin.AdminKeyMiddleware("secret-key", tenant)) + return mux, provisioner +} + +func TestCreateWorkspace(t *testing.T) { + reg := newTestRegistry(t) + tenantsDir := t.TempDir() + mux, _ := newTestTenantMux(reg, tenantsDir) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil { + t.Fatal(err) + } + + body := `{"display_name":"Acme Dental"}` + req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/tenants", bytes.NewBufferString(body)) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String()) + } + + var got registry.Tenant + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode response: %v", err) + } + if got.AccountID != accountID { + t.Fatalf("account_id = %q, want %q", got.AccountID, accountID) + } + if got.DisplayName != "Acme Dental" { + t.Fatalf("display_name = %q, want %q", got.DisplayName, "Acme Dental") + } + + keyPath := filepath.Join(tenantsDir, got.ID, "secrets", "handoff.key") + info, err := os.Stat(keyPath) + if err != nil { + t.Fatalf("handoff.key missing: %v", err) + } + if info.Mode().Perm() != 0o600 { + t.Fatalf("handoff.key perms = %o, want %o", info.Mode().Perm(), 0o600) + } + b, err := os.ReadFile(keyPath) + if err != nil { + t.Fatalf("read handoff.key: %v", err) + } + if len(b) != 32 { + t.Fatalf("handoff.key size = %d, want 32", len(b)) + } +} + +func TestListWorkspaces(t *testing.T) { + reg := newTestRegistry(t) + tenantsDir := t.TempDir() + mux, provisioner := newTestTenantMux(reg, tenantsDir) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil { + t.Fatal(err) + } + + t1, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client One") + if err != nil { + t.Fatal(err) + } + t2, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client Two") + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/accounts/"+accountID+"/tenants", nil) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String()) + } + + var got []*registry.Tenant + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(got) != 2 { + t.Fatalf("expected 2 tenants, got %d (%+v)", len(got), got) + } + + ids := map[string]bool{} + for _, tt := range got { + if tt.AccountID != accountID { + t.Fatalf("tenant account_id = %q, want %q", tt.AccountID, accountID) + } + ids[tt.ID] = true + } + if !ids[t1.ID] || !ids[t2.ID] { + t.Fatalf("missing ids: got=%v want=%q,%q", ids, t1.ID, t2.ID) + } +} + +func TestDeleteWorkspace(t *testing.T) { + reg := newTestRegistry(t) + tenantsDir := t.TempDir() + mux, provisioner := newTestTenantMux(reg, tenantsDir) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil { + t.Fatal(err) + } + + tenant, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client") + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/tenants/"+tenant.ID, nil) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusNoContent, rec.Body.String()) + } + + t2, err := reg.Get(tenant.ID) + if err != nil { + t.Fatal(err) + } + if t2 == nil { + t.Fatal("expected tenant to exist") + } + if t2.State != registry.TenantStateDeleted { + t.Fatalf("state = %q, want %q", t2.State, registry.TenantStateDeleted) + } +} + +func TestTenantBelongsToAccount(t *testing.T) { + reg := newTestRegistry(t) + tenantsDir := t.TempDir() + mux, provisioner := newTestTenantMux(reg, tenantsDir) + + account1, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + account2, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: account1, Kind: registry.AccountKindMSP, DisplayName: "A1"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: account2, Kind: registry.AccountKindMSP, DisplayName: "A2"}); err != nil { + t.Fatal(err) + } + + tenant, err := provisioner.ProvisionWorkspace(context.Background(), account1, "Client") + if err != nil { + t.Fatal(err) + } + + body := `{"display_name":"New Name"}` + req := httptest.NewRequest(http.MethodPatch, "/api/accounts/"+account2+"/tenants/"+tenant.ID, bytes.NewBufferString(body)) + rec := doRequest(t, mux, req) + + if rec.Code != http.StatusNotFound && rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want 404/403 (body=%q)", rec.Code, rec.Body.String()) + } +} diff --git a/internal/cloudcp/handoff/handler.go b/internal/cloudcp/handoff/handler.go new file mode 100644 index 000000000..ec15842c8 --- /dev/null +++ b/internal/cloudcp/handoff/handler.go @@ -0,0 +1,165 @@ +package handoff + +import ( + "html/template" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +const ( + handoffKeyRelPath = "secrets/handoff.key" +) + +var handoffHTMLTemplate = template.Must(template.New("handoff").Parse(` + +
+ +
+ + +`)) + +type handoffHTMLData struct { + TenantID string + BaseDomain string + Token string + GeneratedAt time.Time +} + +// HandleHandoff mints a tenant handoff token and returns an auto-submit HTML page. +// Route (wiring happens elsewhere): POST /api/accounts/{account_id}/tenants/{tenant_id}/handoff +// +// Auth: admin-key for now. User identity is provided by X-User-ID (temporary; session auth replaces this). +func HandleHandoff(reg *registry.TenantRegistry, tenantsDir string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + if reg == nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + accountID := strings.TrimSpace(r.PathValue("account_id")) + tenantID := strings.TrimSpace(r.PathValue("tenant_id")) + if accountID == "" || tenantID == "" { + http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + t, err := reg.Get(tenantID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if t == nil || strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID { + http.Error(w, "tenant not found", http.StatusNotFound) + return + } + + // Temporary request-scoped identity until control plane session auth exists. + userID := strings.TrimSpace(r.Header.Get("X-User-ID")) + if userID == "" { + userID = strings.TrimSpace(r.Header.Get("X-User-Id")) + } + if userID == "" { + http.Error(w, "missing user identity", http.StatusBadRequest) + return + } + + m, err := reg.GetMembership(accountID, userID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if m == nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + u, err := reg.GetUser(userID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if u == nil || strings.TrimSpace(u.Email) == "" { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + keyPath := filepath.Join(filepath.Clean(tenantsDir), tenantID, filepath.FromSlash(handoffKeyRelPath)) + secret, err := os.ReadFile(keyPath) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + now := time.Now().UTC() + token, err := MintHandoffToken(secret, HandoffClaims{ + TenantID: tenantID, + UserID: userID, + AccountID: accountID, + Email: u.Email, + Role: m.Role, + IssuedAt: now, + ExpiresAt: now.Add(defaultTTL), + }) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + baseDomain := deriveBaseDomain(r) + if baseDomain == "" { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _ = handoffHTMLTemplate.Execute(w, handoffHTMLData{ + TenantID: tenantID, + BaseDomain: baseDomain, + Token: token, + GeneratedAt: now, + }) + } +} + +func deriveBaseDomain(r *http.Request) string { + if v := strings.TrimSpace(os.Getenv("CP_BASE_DOMAIN")); v != "" { + return v + } + host := strings.TrimSpace(r.Host) + if host == "" { + return "" + } + // Strip port if present. + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } else { + // net.SplitHostPort errors if there's no port; that's fine. + if strings.Count(host, ":") > 1 { + // IPv6 without port isn't valid for our use case. + return "" + } + } + return host +} diff --git a/internal/cloudcp/handoff/handler_test.go b/internal/cloudcp/handoff/handler_test.go new file mode 100644 index 000000000..4ca6e75fc --- /dev/null +++ b/internal/cloudcp/handoff/handler_test.go @@ -0,0 +1,116 @@ +package handoff + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "regexp" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +func newTestRegistry(t *testing.T) *registry.TenantRegistry { + t.Helper() + dir := t.TempDir() + reg, err := registry.NewTenantRegistry(dir) + if err != nil { + t.Fatalf("NewTenantRegistry: %v", err) + } + t.Cleanup(func() { _ = reg.Close() }) + return reg +} + +func TestHandoffHandler(t *testing.T) { + reg := newTestRegistry(t) + tenantsDir := t.TempDir() + + accountID := "a_TEST" + tenantID := "t-TEST" + userID := "u_TEST" + + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(®istry.User{ID: userID, Email: "tech@example.com"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: userID, Role: registry.MemberRoleTech}); err != nil { + t.Fatal(err) + } + if err := reg.Create(®istry.Tenant{ID: tenantID, AccountID: accountID, DisplayName: "Client", State: registry.TenantStateActive}); err != nil { + t.Fatal(err) + } + + secret := []byte("0123456789abcdef0123456789abcdef") + keyPath := filepath.Join(tenantsDir, tenantID, "secrets", "handoff.key") + if err := os.MkdirAll(filepath.Dir(keyPath), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(keyPath, secret, 0o600); err != nil { + t.Fatal(err) + } + + mux := http.NewServeMux() + h := HandleHandoff(reg, tenantsDir) + mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}/handoff", admin.AdminKeyMiddleware("secret-key", h)) + + req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/tenants/"+tenantID+"/handoff", nil) + req.Host = "cloud.example.com" + req.Header.Set("X-Admin-Key", "secret-key") + req.Header.Set("X-User-ID", userID) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String()) + } + body := rec.Body.String() + wantAction := "https://" + tenantID + ".cloud.example.com/api/cloud/handoff/exchange" + if !regexp.MustCompile(regexp.QuoteMeta(wantAction)).MatchString(body) { + t.Fatalf("missing form action %q in body", wantAction) + } + + re := regexp.MustCompile(`name="token" value="([^"]+)"`) + m := re.FindStringSubmatch(body) + if len(m) != 2 { + t.Fatalf("failed to extract token from HTML") + } + tokenStr := m[1] + + var got jwtHandoffClaims + parsed, err := jwt.ParseWithClaims( + tokenStr, + &got, + func(t *jwt.Token) (any, error) { return secret, nil }, + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithIssuer(issuer), + jwt.WithAudience(tenantID), + ) + if err != nil { + t.Fatalf("ParseWithClaims: %v", err) + } + if !parsed.Valid { + t.Fatalf("token valid = false, want true") + } + if got.Subject != userID { + t.Fatalf("sub = %q, want %q", got.Subject, userID) + } + if got.AccountID != accountID { + t.Fatalf("account_id = %q, want %q", got.AccountID, accountID) + } + if got.Email != "tech@example.com" { + t.Fatalf("email = %q, want %q", got.Email, "tech@example.com") + } + if got.Role != registry.MemberRoleTech { + t.Fatalf("role = %q, want %q", got.Role, registry.MemberRoleTech) + } + if got.ExpiresAt == nil || time.Until(got.ExpiresAt.Time) > 60*time.Second+2*time.Second { + t.Fatalf("exp looks wrong: %v", got.ExpiresAt) + } +} diff --git a/internal/cloudcp/handoff/handoff.go b/internal/cloudcp/handoff/handoff.go new file mode 100644 index 000000000..a7963bad8 --- /dev/null +++ b/internal/cloudcp/handoff/handoff.go @@ -0,0 +1,131 @@ +package handoff + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +const ( + issuer = "pulse-cloud-control-plane" + defaultTTL = 60 * time.Second + handoffKeyLen = 32 +) + +// HandoffClaims are the logical claims that the control plane wants to assert +// when handing off an authenticated account user into a tenant container. +type HandoffClaims struct { + TenantID string + UserID string + AccountID string + Email string + Role registry.MemberRole + + IssuedAt time.Time + ExpiresAt time.Time + JTI string +} + +type jwtHandoffClaims struct { + AccountID string `json:"account_id"` + Email string `json:"email"` + Role registry.MemberRole `json:"role"` + jwt.RegisteredClaims +} + +// GenerateHandoffKey returns 32 cryptographically random bytes suitable for HS256 signing. +// Intended for writing to /data/tenants//secrets/handoff.key (0600). +func GenerateHandoffKey() ([]byte, error) { + key := make([]byte, handoffKeyLen) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return nil, fmt.Errorf("generate handoff key: %w", err) + } + return key, nil +} + +// MintHandoffToken mints a short-lived HS256 JWT signed with the per-tenant handoff.key. +// +// JWT registered claims: +// - iss: pulse-cloud-control-plane +// - aud: +// - sub: +// - iat, exp, jti +// +// Custom claims: +// - account_id, email, role +func MintHandoffToken(secret []byte, claims HandoffClaims) (string, error) { + if len(secret) == 0 { + return "", fmt.Errorf("secret is required") + } + claims.TenantID = sanitizeID(claims.TenantID) + claims.UserID = sanitizeID(claims.UserID) + claims.AccountID = sanitizeID(claims.AccountID) + if claims.TenantID == "" || claims.UserID == "" || claims.AccountID == "" { + return "", fmt.Errorf("tenantID, userID, and accountID are required") + } + if claims.Email == "" { + return "", fmt.Errorf("email is required") + } + + now := time.Now().UTC() + if claims.IssuedAt.IsZero() { + claims.IssuedAt = now + } + claims.IssuedAt = claims.IssuedAt.UTC() + + if claims.ExpiresAt.IsZero() { + claims.ExpiresAt = claims.IssuedAt.Add(defaultTTL) + } + claims.ExpiresAt = claims.ExpiresAt.UTC() + if !claims.ExpiresAt.After(claims.IssuedAt) { + return "", fmt.Errorf("expiresAt must be after issuedAt") + } + + if claims.JTI == "" { + jti, err := randomJTI128() + if err != nil { + return "", err + } + claims.JTI = jti + } + + jc := jwtHandoffClaims{ + AccountID: claims.AccountID, + Email: claims.Email, + Role: claims.Role, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuer, + Subject: claims.UserID, + Audience: jwt.ClaimStrings{claims.TenantID}, + IssuedAt: jwt.NewNumericDate(claims.IssuedAt), + ExpiresAt: jwt.NewNumericDate(claims.ExpiresAt), + ID: claims.JTI, + }, + } + + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jc) + signed, err := tok.SignedString(secret) + if err != nil { + return "", fmt.Errorf("sign jwt: %w", err) + } + return signed, nil +} + +func randomJTI128() (string, error) { + b := make([]byte, 16) // 128-bit + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", fmt.Errorf("generate jti: %w", err) + } + return hex.EncodeToString(b), nil +} + +func sanitizeID(s string) string { + // IDs are generated by the registry; trimming prevents surprising mismatches. + return strings.TrimSpace(s) +} diff --git a/internal/cloudcp/handoff/handoff_test.go b/internal/cloudcp/handoff/handoff_test.go new file mode 100644 index 000000000..42e692487 --- /dev/null +++ b/internal/cloudcp/handoff/handoff_test.go @@ -0,0 +1,167 @@ +package handoff + +import ( + "errors" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +func TestMintAndVerify(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") // 32 bytes + now := time.Now().UTC() + + tokenStr, err := MintHandoffToken(secret, HandoffClaims{ + TenantID: "t-TESTTENANT", + UserID: "u_TESTUSER", + AccountID: "a_TESTACCOUNT", + Email: "tech@example.com", + Role: registry.MemberRoleTech, + IssuedAt: now, + ExpiresAt: now.Add(60 * time.Second), + }) + if err != nil { + t.Fatalf("MintHandoffToken: %v", err) + } + + var got jwtHandoffClaims + parsed, err := jwt.ParseWithClaims( + tokenStr, + &got, + func(t *jwt.Token) (any, error) { return secret, nil }, + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithIssuer(issuer), + jwt.WithAudience("t-TESTTENANT"), + ) + if err != nil { + t.Fatalf("ParseWithClaims: %v", err) + } + if !parsed.Valid { + t.Fatalf("token valid = false, want true") + } + + if got.Subject != "u_TESTUSER" { + t.Fatalf("sub = %q, want %q", got.Subject, "u_TESTUSER") + } + if got.AccountID != "a_TESTACCOUNT" { + t.Fatalf("account_id = %q, want %q", got.AccountID, "a_TESTACCOUNT") + } + if got.Email != "tech@example.com" { + t.Fatalf("email = %q, want %q", got.Email, "tech@example.com") + } + if got.Role != registry.MemberRoleTech { + t.Fatalf("role = %q, want %q", got.Role, registry.MemberRoleTech) + } + if got.ID == "" { + t.Fatalf("jti empty") + } + if got.IssuedAt == nil || got.ExpiresAt == nil { + t.Fatalf("missing iat/exp") + } + if got.ExpiresAt.Time.Sub(got.IssuedAt.Time) != 60*time.Second { + t.Fatalf("exp-iat = %v, want %v", got.ExpiresAt.Time.Sub(got.IssuedAt.Time), 60*time.Second) + } +} + +func TestExpiredToken(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") + past := time.Now().UTC().Add(-2 * time.Minute) + + tokenStr, err := MintHandoffToken(secret, HandoffClaims{ + TenantID: "t-EXPIRED", + UserID: "u_USER", + AccountID: "a_ACCOUNT", + Email: "x@example.com", + Role: registry.MemberRoleReadOnly, + IssuedAt: past, + ExpiresAt: past.Add(30 * time.Second), + JTI: "deadbeefdeadbeefdeadbeefdeadbeef", + }) + if err != nil { + t.Fatalf("MintHandoffToken: %v", err) + } + + var got jwtHandoffClaims + _, err = jwt.ParseWithClaims( + tokenStr, + &got, + func(t *jwt.Token) (any, error) { return secret, nil }, + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithIssuer(issuer), + jwt.WithAudience("t-EXPIRED"), + ) + if err == nil { + t.Fatalf("expected expired token error") + } + if !errors.Is(err, jwt.ErrTokenExpired) { + // Leave room for wrapped validation errors. + t.Fatalf("error = %v, want ErrTokenExpired", err) + } +} + +func TestWrongSecret(t *testing.T) { + secretA := []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + secretB := []byte("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + now := time.Now().UTC() + + tokenStr, err := MintHandoffToken(secretA, HandoffClaims{ + TenantID: "t-TENANT", + UserID: "u_USER", + AccountID: "a_ACCOUNT", + Email: "x@example.com", + Role: registry.MemberRoleAdmin, + IssuedAt: now, + ExpiresAt: now.Add(60 * time.Second), + JTI: "0123456789abcdef0123456789abcdef", + }) + if err != nil { + t.Fatalf("MintHandoffToken: %v", err) + } + + var got jwtHandoffClaims + _, err = jwt.ParseWithClaims( + tokenStr, + &got, + func(t *jwt.Token) (any, error) { return secretB, nil }, + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithIssuer(issuer), + jwt.WithAudience("t-TENANT"), + ) + if err == nil { + t.Fatalf("expected signature verification failure") + } +} + +func TestWrongAudience(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") + now := time.Now().UTC() + + tokenStr, err := MintHandoffToken(secret, HandoffClaims{ + TenantID: "t-A", + UserID: "u_USER", + AccountID: "a_ACCOUNT", + Email: "x@example.com", + Role: registry.MemberRoleOwner, + IssuedAt: now, + ExpiresAt: now.Add(60 * time.Second), + JTI: "0123456789abcdef0123456789abcdef", + }) + if err != nil { + t.Fatalf("MintHandoffToken: %v", err) + } + + var got jwtHandoffClaims + _, err = jwt.ParseWithClaims( + tokenStr, + &got, + func(t *jwt.Token) (any, error) { return secret, nil }, + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithIssuer(issuer), + jwt.WithAudience("t-B"), + ) + if err == nil { + t.Fatalf("expected wrong audience error") + } +} diff --git a/internal/cloudcp/portal/handlers.go b/internal/cloudcp/portal/handlers.go new file mode 100644 index 000000000..208f079b9 --- /dev/null +++ b/internal/cloudcp/portal/handlers.go @@ -0,0 +1,202 @@ +package portal + +import ( + "encoding/json" + "net/http" + "strings" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +type accountInfo struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + Kind registry.AccountKind `json:"kind"` +} + +type workspaceSummaryItem struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + State registry.TenantState `json:"state"` + HealthCheckOK bool `json:"health_check_ok"` + LastHealthCheck *time.Time `json:"last_health_check"` + CreatedAt time.Time `json:"created_at"` +} + +type dashboardSummary struct { + Total int `json:"total"` + Active int `json:"active"` + Healthy int `json:"healthy"` + Unhealthy int `json:"unhealthy"` + Suspended int `json:"suspended"` +} + +type dashboardResponse struct { + Account accountInfo `json:"account"` + Workspaces []workspaceSummaryItem `json:"workspaces"` + Summary dashboardSummary `json:"summary"` +} + +type workspaceDetailResponse struct { + Account accountInfo `json:"account"` + Workspace *registry.Tenant `json:"workspace"` +} + +func accountIDFromRequest(r *http.Request) string { + if r == nil { + return "" + } + if v := strings.TrimSpace(r.URL.Query().Get("account_id")); v != "" { + return v + } + // Convenience for future callers; spec says query param is fine for now. + if v := strings.TrimSpace(r.Header.Get("X-Account-ID")); v != "" { + return v + } + if v := strings.TrimSpace(r.Header.Get("X-Account-Id")); v != "" { + return v + } + return "" +} + +// HandlePortalDashboard returns a portal-oriented dashboard response for an account. +// Route: GET /api/portal/dashboard?account_id=... +// +// Auth: admin-key for now (session auth in M-4). +func HandlePortalDashboard(reg *registry.TenantRegistry) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + if reg == nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + accountID := accountIDFromRequest(r) + if accountID == "" { + http.Error(w, "missing account_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + tenants, err := reg.ListByAccountID(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if tenants == nil { + tenants = []*registry.Tenant{} + } + + resp := dashboardResponse{ + Account: accountInfo{ + ID: a.ID, + DisplayName: a.DisplayName, + Kind: a.Kind, + }, + Workspaces: make([]workspaceSummaryItem, 0, len(tenants)), + } + + for _, t := range tenants { + if t == nil { + continue + } + + resp.Workspaces = append(resp.Workspaces, workspaceSummaryItem{ + ID: t.ID, + DisplayName: t.DisplayName, + State: t.State, + HealthCheckOK: t.HealthCheckOK, + LastHealthCheck: t.LastHealthCheck, + CreatedAt: t.CreatedAt, + }) + + resp.Summary.Total++ + + switch t.State { + case registry.TenantStateActive: + resp.Summary.Active++ + if t.HealthCheckOK { + resp.Summary.Healthy++ + } else { + resp.Summary.Unhealthy++ + } + case registry.TenantStateSuspended: + resp.Summary.Suspended++ + } + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + } +} + +// HandlePortalWorkspaceDetail returns a portal-oriented detail response for a single tenant. +// Route: GET /api/portal/workspaces/{tenant_id}?account_id=... +// +// Auth: admin-key for now (session auth in M-4). +func HandlePortalWorkspaceDetail(reg *registry.TenantRegistry) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + if reg == nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + accountID := accountIDFromRequest(r) + tenantID := strings.TrimSpace(r.PathValue("tenant_id")) + if accountID == "" || tenantID == "" { + http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest) + return + } + + a, err := reg.GetAccount(accountID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if a == nil { + http.Error(w, "account not found", http.StatusNotFound) + return + } + + t, err := reg.Get(tenantID) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + if t == nil || strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID { + http.Error(w, "tenant not found", http.StatusNotFound) + return + } + + resp := workspaceDetailResponse{ + Account: accountInfo{ + ID: a.ID, + DisplayName: a.DisplayName, + Kind: a.Kind, + }, + Workspace: t, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + } +} diff --git a/internal/cloudcp/portal/handlers_test.go b/internal/cloudcp/portal/handlers_test.go new file mode 100644 index 000000000..35ce93198 --- /dev/null +++ b/internal/cloudcp/portal/handlers_test.go @@ -0,0 +1,310 @@ +package portal + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "sort" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" +) + +func newTestRegistry(t *testing.T) *registry.TenantRegistry { + t.Helper() + dir := t.TempDir() + reg, err := registry.NewTenantRegistry(dir) + if err != nil { + t.Fatalf("NewTenantRegistry: %v", err) + } + t.Cleanup(func() { _ = reg.Close() }) + return reg +} + +func newTestMux(reg *registry.TenantRegistry) *http.ServeMux { + mux := http.NewServeMux() + mux.Handle("/api/portal/dashboard", admin.AdminKeyMiddleware("secret-key", HandlePortalDashboard(reg))) + mux.Handle("/api/portal/workspaces/{tenant_id}", admin.AdminKeyMiddleware("secret-key", HandlePortalWorkspaceDetail(reg))) + return mux +} + +func doRequest(t *testing.T, h http.Handler, req *http.Request) *httptest.ResponseRecorder { + t.Helper() + req.Header.Set("X-Admin-Key", "secret-key") + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + return rec +} + +type dashboardResp struct { + Account struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + Kind registry.AccountKind `json:"kind"` + } `json:"account"` + Workspaces []struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + State registry.TenantState `json:"state"` + HealthCheckOK bool `json:"health_check_ok"` + LastHealthCheck *time.Time `json:"last_health_check"` + CreatedAt time.Time `json:"created_at"` + } `json:"workspaces"` + Summary struct { + Total int `json:"total"` + Active int `json:"active"` + Healthy int `json:"healthy"` + Unhealthy int `json:"unhealthy"` + Suspended int `json:"suspended"` + } `json:"summary"` +} + +func TestPortalDashboard(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Example MSP"}); err != nil { + t.Fatal(err) + } + + tenantActiveID, err := registry.GenerateTenantID() + if err != nil { + t.Fatal(err) + } + tenantSuspendedID, err := registry.GenerateTenantID() + if err != nil { + t.Fatal(err) + } + + created1 := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC) + created2 := time.Date(2026, 2, 10, 11, 0, 0, 0, time.UTC) + lastCheck := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC) + + if err := reg.Create(®istry.Tenant{ + ID: tenantActiveID, + AccountID: accountID, + DisplayName: "Acme Dental", + State: registry.TenantStateActive, + CreatedAt: created1, + LastHealthCheck: &lastCheck, + HealthCheckOK: true, + }); err != nil { + t.Fatal(err) + } + if err := reg.Create(®istry.Tenant{ + ID: tenantSuspendedID, + AccountID: accountID, + DisplayName: "Suspended Workspace", + State: registry.TenantStateSuspended, + CreatedAt: created2, + HealthCheckOK: false, + }); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/portal/dashboard?account_id="+accountID, nil) + rec := doRequest(t, mux, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp dashboardResp + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v (body=%q)", err, rec.Body.String()) + } + + if resp.Account.ID != accountID { + t.Fatalf("account.id = %q, want %q", resp.Account.ID, accountID) + } + if resp.Account.DisplayName != "Example MSP" { + t.Fatalf("account.display_name = %q, want %q", resp.Account.DisplayName, "Example MSP") + } + if resp.Account.Kind != registry.AccountKindMSP { + t.Fatalf("account.kind = %q, want %q", resp.Account.Kind, registry.AccountKindMSP) + } + + if len(resp.Workspaces) != 2 { + t.Fatalf("workspaces len = %d, want %d", len(resp.Workspaces), 2) + } + + // Make assertions order-independent. + sort.Slice(resp.Workspaces, func(i, j int) bool { return resp.Workspaces[i].ID < resp.Workspaces[j].ID }) + wsByID := map[string]dashboardRespWorkspace{} + + // local helper type for easier indexing + for _, ws := range resp.Workspaces { + wsByID[ws.ID] = dashboardRespWorkspace{ + ID: ws.ID, + DisplayName: ws.DisplayName, + State: ws.State, + HealthCheckOK: ws.HealthCheckOK, + LastHealthCheck: ws.LastHealthCheck, + CreatedAt: ws.CreatedAt, + } + } + + active := wsByID[tenantActiveID] + if active.ID == "" { + t.Fatalf("missing active workspace id %q", tenantActiveID) + } + if active.DisplayName != "Acme Dental" { + t.Fatalf("active.display_name = %q, want %q", active.DisplayName, "Acme Dental") + } + if active.State != registry.TenantStateActive { + t.Fatalf("active.state = %q, want %q", active.State, registry.TenantStateActive) + } + if !active.HealthCheckOK { + t.Fatalf("active.health_check_ok = false, want true") + } + if active.LastHealthCheck == nil || !active.LastHealthCheck.Equal(lastCheck) { + t.Fatalf("active.last_health_check = %v, want %v", active.LastHealthCheck, lastCheck) + } + if !active.CreatedAt.Equal(created1) { + t.Fatalf("active.created_at = %v, want %v", active.CreatedAt, created1) + } + + susp := wsByID[tenantSuspendedID] + if susp.ID == "" { + t.Fatalf("missing suspended workspace id %q", tenantSuspendedID) + } + if susp.State != registry.TenantStateSuspended { + t.Fatalf("suspended.state = %q, want %q", susp.State, registry.TenantStateSuspended) + } + + if resp.Summary.Total != 2 { + t.Fatalf("summary.total = %d, want %d", resp.Summary.Total, 2) + } + if resp.Summary.Active != 1 { + t.Fatalf("summary.active = %d, want %d", resp.Summary.Active, 1) + } + if resp.Summary.Healthy != 1 { + t.Fatalf("summary.healthy = %d, want %d", resp.Summary.Healthy, 1) + } + if resp.Summary.Unhealthy != 0 { + t.Fatalf("summary.unhealthy = %d, want %d", resp.Summary.Unhealthy, 0) + } + if resp.Summary.Suspended != 1 { + t.Fatalf("summary.suspended = %d, want %d", resp.Summary.Suspended, 1) + } +} + +type dashboardRespWorkspace struct { + ID string + DisplayName string + State registry.TenantState + HealthCheckOK bool + LastHealthCheck *time.Time + CreatedAt time.Time +} + +func TestPortalDashboardEmpty(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Empty MSP"}); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/portal/dashboard?account_id="+accountID, nil) + rec := doRequest(t, mux, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp dashboardResp + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v (body=%q)", err, rec.Body.String()) + } + + if len(resp.Workspaces) != 0 { + t.Fatalf("workspaces len = %d, want %d", len(resp.Workspaces), 0) + } + if resp.Summary.Total != 0 || resp.Summary.Active != 0 || resp.Summary.Healthy != 0 || resp.Summary.Unhealthy != 0 || resp.Summary.Suspended != 0 { + t.Fatalf("summary = %+v, want all zeros", resp.Summary) + } +} + +func TestPortalWorkspaceDetail(t *testing.T) { + reg := newTestRegistry(t) + mux := newTestMux(reg) + + accountID, err := registry.GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Example MSP"}); err != nil { + t.Fatal(err) + } + + tenantID, err := registry.GenerateTenantID() + if err != nil { + t.Fatal(err) + } + + created := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC) + lastCheck := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC) + if err := reg.Create(®istry.Tenant{ + ID: tenantID, + AccountID: accountID, + DisplayName: "Acme Dental", + State: registry.TenantStateActive, + CreatedAt: created, + LastHealthCheck: &lastCheck, + HealthCheckOK: true, + }); err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/portal/workspaces/"+tenantID+"?account_id="+accountID, nil) + rec := doRequest(t, mux, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp struct { + Account struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + Kind registry.AccountKind `json:"kind"` + } `json:"account"` + Workspace registry.Tenant `json:"workspace"` + } + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v (body=%q)", err, rec.Body.String()) + } + + if resp.Account.ID != accountID { + t.Fatalf("account.id = %q, want %q", resp.Account.ID, accountID) + } + if resp.Workspace.ID != tenantID { + t.Fatalf("workspace.id = %q, want %q", resp.Workspace.ID, tenantID) + } + if resp.Workspace.AccountID != accountID { + t.Fatalf("workspace.account_id = %q, want %q", resp.Workspace.AccountID, accountID) + } + if resp.Workspace.DisplayName != "Acme Dental" { + t.Fatalf("workspace.display_name = %q, want %q", resp.Workspace.DisplayName, "Acme Dental") + } + if resp.Workspace.State != registry.TenantStateActive { + t.Fatalf("workspace.state = %q, want %q", resp.Workspace.State, registry.TenantStateActive) + } + if !resp.Workspace.HealthCheckOK { + t.Fatalf("workspace.health_check_ok = false, want true") + } + if resp.Workspace.LastHealthCheck == nil || !resp.Workspace.LastHealthCheck.Equal(lastCheck) { + t.Fatalf("workspace.last_health_check = %v, want %v", resp.Workspace.LastHealthCheck, lastCheck) + } + if !resp.Workspace.CreatedAt.Equal(created) { + t.Fatalf("workspace.created_at = %v, want %v", resp.Workspace.CreatedAt, created) + } +} diff --git a/internal/cloudcp/registry/models.go b/internal/cloudcp/registry/models.go index 026d930b0..c35ea142b 100644 --- a/internal/cloudcp/registry/models.go +++ b/internal/cloudcp/registry/models.go @@ -7,6 +7,47 @@ import ( "time" ) +type AccountKind string + +const ( + AccountKindIndividual AccountKind = "individual" + AccountKindMSP AccountKind = "msp" +) + +// Account represents an account record in the control plane registry. +type Account struct { + ID string `json:"id"` + Kind AccountKind `json:"kind"` + DisplayName string `json:"display_name"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// User represents a user record in the control plane registry. +type User struct { + ID string `json:"id"` + Email string `json:"email"` + CreatedAt time.Time `json:"created_at"` + LastLoginAt *time.Time `json:"last_login_at,omitempty"` +} + +type MemberRole string + +const ( + MemberRoleOwner MemberRole = "owner" + MemberRoleAdmin MemberRole = "admin" + MemberRoleTech MemberRole = "tech" + MemberRoleReadOnly MemberRole = "read_only" +) + +// AccountMembership represents a mapping between an account and a user. +type AccountMembership struct { + AccountID string `json:"account_id"` + UserID string `json:"user_id"` + Role MemberRole `json:"role"` + CreatedAt time.Time `json:"created_at"` +} + // TenantState represents the lifecycle state of a tenant. type TenantState string @@ -15,12 +56,14 @@ const ( TenantStateActive TenantState = "active" TenantStateSuspended TenantState = "suspended" TenantStateCanceled TenantState = "canceled" + TenantStateDeleting TenantState = "deleting" TenantStateDeleted TenantState = "deleted" ) // Tenant represents a Cloud tenant record in the registry. type Tenant struct { ID string `json:"id"` + AccountID string `json:"account_id"` Email string `json:"email"` DisplayName string `json:"display_name"` State TenantState `json:"state"` @@ -37,6 +80,20 @@ type Tenant struct { HealthCheckOK bool `json:"health_check_ok"` } +// StripeAccount maps a control-plane account to a single Stripe customer + +// subscription for consolidated (MSP-style) billing. +type StripeAccount struct { + AccountID string `json:"account_id"` + StripeCustomerID string `json:"stripe_customer_id"` + StripeSubscriptionID string `json:"stripe_subscription_id"` + StripeSubItemWorkspacesID string `json:"stripe_sub_item_workspaces_id"` + PlanVersion string `json:"plan_version"` + SubscriptionState string `json:"subscription_state"` // trial, active, past_due, canceled + TrialEndsAt *int64 `json:"trial_ends_at"` + CurrentPeriodEnd *int64 `json:"current_period_end"` + UpdatedAt int64 `json:"updated_at"` +} + // crockfordBase32 is the Crockford base32 alphabet (excludes I, L, O, U). const crockfordBase32 = "0123456789ABCDEFGHJKMNPQRSTVWXYZ" @@ -54,3 +111,33 @@ func GenerateTenantID() (string, error) { } return sb.String(), nil } + +// GenerateAccountID returns an account ID of the form "a_" followed by 10 random +// Crockford base32 characters (50 bits of entropy). +func GenerateAccountID() (string, error) { + b := make([]byte, 10) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate account id: %w", err) + } + var sb strings.Builder + sb.WriteString("a_") + for _, v := range b { + sb.WriteByte(crockfordBase32[int(v)%len(crockfordBase32)]) + } + return sb.String(), nil +} + +// GenerateUserID returns a user ID of the form "u_" followed by 10 random +// Crockford base32 characters (50 bits of entropy). +func GenerateUserID() (string, error) { + b := make([]byte, 10) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate user id: %w", err) + } + var sb strings.Builder + sb.WriteString("u_") + for _, v := range b { + sb.WriteByte(crockfordBase32[int(v)%len(crockfordBase32)]) + } + return sb.String(), nil +} diff --git a/internal/cloudcp/registry/registry.go b/internal/cloudcp/registry/registry.go index d7d0a95f7..d5e92056d 100644 --- a/internal/cloudcp/registry/registry.go +++ b/internal/cloudcp/registry/registry.go @@ -6,6 +6,7 @@ import ( "net/url" "os" "path/filepath" + "strings" "time" _ "modernc.org/sqlite" @@ -51,6 +52,7 @@ func (r *TenantRegistry) initSchema() error { schema := ` CREATE TABLE IF NOT EXISTS tenants ( id TEXT PRIMARY KEY, + account_id TEXT NOT NULL DEFAULT '', email TEXT NOT NULL DEFAULT '', display_name TEXT NOT NULL DEFAULT '', state TEXT NOT NULL DEFAULT 'provisioning', @@ -68,13 +70,106 @@ func (r *TenantRegistry) initSchema() error { ); CREATE INDEX IF NOT EXISTS idx_tenants_state ON tenants(state); CREATE INDEX IF NOT EXISTS idx_tenants_stripe_customer_id ON tenants(stripe_customer_id); + + CREATE TABLE IF NOT EXISTS accounts ( + id TEXT PRIMARY KEY, + kind TEXT NOT NULL DEFAULT 'individual', + display_name TEXT NOT NULL DEFAULT '', + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS stripe_accounts ( + account_id TEXT PRIMARY KEY, + stripe_customer_id TEXT NOT NULL UNIQUE, + stripe_subscription_id TEXT, + stripe_sub_item_workspaces_id TEXT, + plan_version TEXT NOT NULL DEFAULT '', + subscription_state TEXT NOT NULL DEFAULT 'trial', + trial_ends_at INTEGER, + current_period_end INTEGER, + updated_at INTEGER NOT NULL, + FOREIGN KEY (account_id) REFERENCES accounts(id) + ); + CREATE INDEX IF NOT EXISTS idx_stripe_accounts_customer ON stripe_accounts(stripe_customer_id); + + CREATE TABLE IF NOT EXISTS stripe_events ( + stripe_event_id TEXT PRIMARY KEY, + event_type TEXT NOT NULL, + received_at INTEGER NOT NULL, + processed_at INTEGER, + processing_error TEXT + ); + + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + email TEXT NOT NULL UNIQUE, + created_at INTEGER NOT NULL, + last_login_at INTEGER + ); + + CREATE TABLE IF NOT EXISTS account_memberships ( + account_id TEXT NOT NULL, + user_id TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'tech', + created_at INTEGER NOT NULL, + PRIMARY KEY (account_id, user_id), + FOREIGN KEY (account_id) REFERENCES accounts(id), + FOREIGN KEY (user_id) REFERENCES users(id) + ); + CREATE INDEX IF NOT EXISTS idx_memberships_user_id ON account_memberships(user_id); ` if _, err := r.db.Exec(schema); err != nil { return fmt.Errorf("init tenant registry schema: %w", err) } + + // Migration: add account_id to tenants if not present. + // (SQLite makes it awkward to add FK constraints via ALTER TABLE, and FK + // enforcement is off by default; this keeps the change backwards-compatible.) + hasAccountID, err := r.tenantsHasColumn("account_id") + if err != nil { + return err + } + if !hasAccountID { + if _, err := r.db.Exec(`ALTER TABLE tenants ADD COLUMN account_id TEXT NOT NULL DEFAULT ''`); err != nil { + return fmt.Errorf("migrate tenants: add account_id: %w", err) + } + } + if _, err := r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_tenants_account_id ON tenants(account_id)`); err != nil { + return fmt.Errorf("init tenant registry schema: create idx_tenants_account_id: %w", err) + } return nil } +func (r *TenantRegistry) tenantsHasColumn(name string) (bool, error) { + rows, err := r.db.Query(`PRAGMA table_info(tenants)`) + if err != nil { + return false, fmt.Errorf("pragma table_info(tenants): %w", err) + } + defer rows.Close() + + for rows.Next() { + var ( + cid int + colName string + colType string + notNull int + dflt sql.NullString + pk int + ) + if err := rows.Scan(&cid, &colName, &colType, ¬Null, &dflt, &pk); err != nil { + return false, fmt.Errorf("scan table_info(tenants): %w", err) + } + if colName == name { + return true, nil + } + } + if err := rows.Err(); err != nil { + return false, fmt.Errorf("iterate table_info(tenants): %w", err) + } + return false, nil +} + // Ping checks database connectivity (used for readiness probes). func (r *TenantRegistry) Ping() error { return r.db.Ping() @@ -101,12 +196,12 @@ func (r *TenantRegistry) Create(t *Tenant) error { _, err := r.db.Exec(` INSERT INTO tenants ( - id, email, display_name, state, + id, account_id, email, display_name, state, stripe_customer_id, stripe_subscription_id, stripe_price_id, plan_version, container_id, current_image_digest, desired_image_digest, created_at, updated_at, last_health_check, health_check_ok - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - t.ID, t.Email, t.DisplayName, string(t.State), + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + t.ID, t.AccountID, t.Email, t.DisplayName, string(t.State), t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID, t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest, t.CreatedAt.Unix(), t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK), @@ -120,7 +215,7 @@ func (r *TenantRegistry) Create(t *Tenant) error { // Get retrieves a tenant by ID. func (r *TenantRegistry) Get(id string) (*Tenant, error) { row := r.db.QueryRow(`SELECT - id, email, display_name, state, + id, account_id, email, display_name, state, stripe_customer_id, stripe_subscription_id, stripe_price_id, plan_version, container_id, current_image_digest, desired_image_digest, created_at, updated_at, last_health_check, health_check_ok @@ -131,7 +226,7 @@ func (r *TenantRegistry) Get(id string) (*Tenant, error) { // GetByStripeCustomerID retrieves a tenant by Stripe customer ID. func (r *TenantRegistry) GetByStripeCustomerID(customerID string) (*Tenant, error) { row := r.db.QueryRow(`SELECT - id, email, display_name, state, + id, account_id, email, display_name, state, stripe_customer_id, stripe_subscription_id, stripe_price_id, plan_version, container_id, current_image_digest, desired_image_digest, created_at, updated_at, last_health_check, health_check_ok @@ -148,12 +243,12 @@ func (r *TenantRegistry) Update(t *Tenant) error { res, err := r.db.Exec(` UPDATE tenants SET - email = ?, display_name = ?, state = ?, + account_id = ?, email = ?, display_name = ?, state = ?, stripe_customer_id = ?, stripe_subscription_id = ?, stripe_price_id = ?, plan_version = ?, container_id = ?, current_image_digest = ?, desired_image_digest = ?, updated_at = ?, last_health_check = ?, health_check_ok = ? WHERE id = ?`, - t.Email, t.DisplayName, string(t.State), + t.AccountID, t.Email, t.DisplayName, string(t.State), t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID, t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest, t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK), @@ -172,7 +267,7 @@ func (r *TenantRegistry) Update(t *Tenant) error { // List returns all tenants. func (r *TenantRegistry) List() ([]*Tenant, error) { rows, err := r.db.Query(`SELECT - id, email, display_name, state, + id, account_id, email, display_name, state, stripe_customer_id, stripe_subscription_id, stripe_price_id, plan_version, container_id, current_image_digest, desired_image_digest, created_at, updated_at, last_health_check, health_check_ok @@ -187,7 +282,7 @@ func (r *TenantRegistry) List() ([]*Tenant, error) { // ListByState returns all tenants matching the given state. func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) { rows, err := r.db.Query(`SELECT - id, email, display_name, state, + id, account_id, email, display_name, state, stripe_customer_id, stripe_subscription_id, stripe_price_id, plan_version, container_id, current_image_digest, desired_image_digest, created_at, updated_at, last_health_check, health_check_ok @@ -199,6 +294,21 @@ func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) { return scanTenants(rows) } +// ListByAccountID returns all tenants belonging to the given account ID. +func (r *TenantRegistry) ListByAccountID(accountID string) ([]*Tenant, error) { + rows, err := r.db.Query(`SELECT + id, account_id, email, display_name, state, + stripe_customer_id, stripe_subscription_id, stripe_price_id, + plan_version, container_id, current_image_digest, desired_image_digest, + created_at, updated_at, last_health_check, health_check_ok + FROM tenants WHERE account_id = ? ORDER BY created_at DESC`, accountID) + if err != nil { + return nil, fmt.Errorf("list tenants by account id: %w", err) + } + defer rows.Close() + return scanTenants(rows) +} + // CountByState returns a map of state -> count. func (r *TenantRegistry) CountByState() (map[TenantState]int, error) { rows, err := r.db.Query(`SELECT state, COUNT(*) FROM tenants GROUP BY state`) @@ -244,7 +354,7 @@ func scanTenant(s scanner) (*Tenant, error) { var healthOK int err := s.Scan( - &t.ID, &t.Email, &t.DisplayName, &state, + &t.ID, &t.AccountID, &t.Email, &t.DisplayName, &state, &t.StripeCustomerID, &t.StripeSubscriptionID, &t.StripePriceID, &t.PlanVersion, &t.ContainerID, &t.CurrentImageDigest, &t.DesiredImageDigest, &createdAt, &updatedAt, &lastHealthCheck, &healthOK, @@ -279,6 +389,474 @@ func scanTenants(rows *sql.Rows) ([]*Tenant, error) { return tenants, rows.Err() } +// CreateAccount inserts a new account record. +func (r *TenantRegistry) CreateAccount(a *Account) error { + if a == nil { + return fmt.Errorf("account is nil") + } + now := time.Now().UTC() + if a.CreatedAt.IsZero() { + a.CreatedAt = now + } + a.UpdatedAt = now + + kind := string(a.Kind) + if kind == "" { + kind = string(AccountKindIndividual) + } + + _, err := r.db.Exec(` + INSERT INTO accounts ( + id, kind, display_name, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?)`, + a.ID, kind, a.DisplayName, a.CreatedAt.Unix(), a.UpdatedAt.Unix(), + ) + if err != nil { + return fmt.Errorf("create account: %w", err) + } + a.Kind = AccountKind(kind) + return nil +} + +// GetAccount retrieves an account by ID. +func (r *TenantRegistry) GetAccount(id string) (*Account, error) { + row := r.db.QueryRow(`SELECT + id, kind, display_name, created_at, updated_at + FROM accounts WHERE id = ?`, id) + return scanAccount(row) +} + +// UpdateAccount modifies an existing account record. +func (r *TenantRegistry) UpdateAccount(a *Account) error { + if a == nil { + return fmt.Errorf("account is nil") + } + a.UpdatedAt = time.Now().UTC() + + kind := string(a.Kind) + if kind == "" { + kind = string(AccountKindIndividual) + } + + res, err := r.db.Exec(` + UPDATE accounts SET + kind = ?, display_name = ?, updated_at = ? + WHERE id = ?`, + kind, a.DisplayName, a.UpdatedAt.Unix(), + a.ID, + ) + if err != nil { + return fmt.Errorf("update account: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return fmt.Errorf("account %q not found", a.ID) + } + a.Kind = AccountKind(kind) + return nil +} + +// ListAccounts returns all accounts. +func (r *TenantRegistry) ListAccounts() ([]*Account, error) { + rows, err := r.db.Query(`SELECT + id, kind, display_name, created_at, updated_at + FROM accounts ORDER BY created_at DESC`) + if err != nil { + return nil, fmt.Errorf("list accounts: %w", err) + } + defer rows.Close() + return scanAccounts(rows) +} + +// CreateUser inserts a new user record. +func (r *TenantRegistry) CreateUser(u *User) error { + if u == nil { + return fmt.Errorf("user is nil") + } + now := time.Now().UTC() + if u.CreatedAt.IsZero() { + u.CreatedAt = now + } + + _, err := r.db.Exec(` + INSERT INTO users ( + id, email, created_at, last_login_at + ) VALUES (?, ?, ?, ?)`, + u.ID, u.Email, u.CreatedAt.Unix(), nullableTimeUnix(u.LastLoginAt), + ) + if err != nil { + return fmt.Errorf("create user: %w", err) + } + return nil +} + +// GetUser retrieves a user by ID. +func (r *TenantRegistry) GetUser(id string) (*User, error) { + row := r.db.QueryRow(`SELECT + id, email, created_at, last_login_at + FROM users WHERE id = ?`, id) + return scanUser(row) +} + +// GetUserByEmail retrieves a user by email. +func (r *TenantRegistry) GetUserByEmail(email string) (*User, error) { + row := r.db.QueryRow(`SELECT + id, email, created_at, last_login_at + FROM users WHERE email = ?`, email) + return scanUser(row) +} + +// UpdateUserLastLogin sets last_login_at for the given user ID to the current time. +func (r *TenantRegistry) UpdateUserLastLogin(id string) error { + now := time.Now().UTC() + res, err := r.db.Exec(`UPDATE users SET last_login_at = ? WHERE id = ?`, now.Unix(), id) + if err != nil { + return fmt.Errorf("update user last login: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return fmt.Errorf("user %q not found", id) + } + return nil +} + +// CreateMembership inserts a new membership record. +func (r *TenantRegistry) CreateMembership(m *AccountMembership) error { + if m == nil { + return fmt.Errorf("membership is nil") + } + now := time.Now().UTC() + if m.CreatedAt.IsZero() { + m.CreatedAt = now + } + role := string(m.Role) + if role == "" { + role = string(MemberRoleTech) + } + + _, err := r.db.Exec(` + INSERT INTO account_memberships ( + account_id, user_id, role, created_at + ) VALUES (?, ?, ?, ?)`, + m.AccountID, m.UserID, role, m.CreatedAt.Unix(), + ) + if err != nil { + return fmt.Errorf("create membership: %w", err) + } + m.Role = MemberRole(role) + return nil +} + +// GetMembership retrieves a membership record by account ID and user ID. +func (r *TenantRegistry) GetMembership(accountID, userID string) (*AccountMembership, error) { + row := r.db.QueryRow(`SELECT + account_id, user_id, role, created_at + FROM account_memberships + WHERE account_id = ? AND user_id = ?`, accountID, userID) + return scanMembership(row) +} + +// ListMembersByAccount returns all membership records for a given account ID. +func (r *TenantRegistry) ListMembersByAccount(accountID string) ([]*AccountMembership, error) { + rows, err := r.db.Query(`SELECT + account_id, user_id, role, created_at + FROM account_memberships + WHERE account_id = ? + ORDER BY created_at DESC`, accountID) + if err != nil { + return nil, fmt.Errorf("list members by account: %w", err) + } + defer rows.Close() + return scanMemberships(rows) +} + +// ListAccountsByUser returns account IDs for all accounts the given user belongs to. +func (r *TenantRegistry) ListAccountsByUser(userID string) ([]string, error) { + rows, err := r.db.Query(`SELECT account_id FROM account_memberships WHERE user_id = ? ORDER BY created_at DESC`, userID) + if err != nil { + return nil, fmt.Errorf("list accounts by user: %w", err) + } + defer rows.Close() + + var accountIDs []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan account id: %w", err) + } + accountIDs = append(accountIDs, id) + } + return accountIDs, rows.Err() +} + +// UpdateMembershipRole updates a membership role. +func (r *TenantRegistry) UpdateMembershipRole(accountID, userID string, role MemberRole) error { + res, err := r.db.Exec(`UPDATE account_memberships SET role = ? WHERE account_id = ? AND user_id = ?`, string(role), accountID, userID) + if err != nil { + return fmt.Errorf("update membership role: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return fmt.Errorf("membership (%q, %q) not found", accountID, userID) + } + return nil +} + +// DeleteMembership deletes a membership record. +func (r *TenantRegistry) DeleteMembership(accountID, userID string) error { + res, err := r.db.Exec(`DELETE FROM account_memberships WHERE account_id = ? AND user_id = ?`, accountID, userID) + if err != nil { + return fmt.Errorf("delete membership: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return fmt.Errorf("membership (%q, %q) not found", accountID, userID) + } + return nil +} + +// CreateStripeAccount inserts a new StripeAccount mapping row. +func (r *TenantRegistry) CreateStripeAccount(sa *StripeAccount) error { + if sa == nil { + return fmt.Errorf("stripe account is nil") + } + sa.AccountID = strings.TrimSpace(sa.AccountID) + sa.StripeCustomerID = strings.TrimSpace(sa.StripeCustomerID) + sa.StripeSubscriptionID = strings.TrimSpace(sa.StripeSubscriptionID) + sa.StripeSubItemWorkspacesID = strings.TrimSpace(sa.StripeSubItemWorkspacesID) + sa.PlanVersion = strings.TrimSpace(sa.PlanVersion) + sa.SubscriptionState = strings.TrimSpace(sa.SubscriptionState) + + if sa.AccountID == "" { + return fmt.Errorf("missing account id") + } + if sa.StripeCustomerID == "" { + return fmt.Errorf("missing stripe customer id") + } + if sa.SubscriptionState == "" { + sa.SubscriptionState = "trial" + } + if sa.UpdatedAt == 0 { + sa.UpdatedAt = time.Now().UTC().Unix() + } + + _, err := r.db.Exec(` + INSERT INTO stripe_accounts ( + account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id, + plan_version, subscription_state, trial_ends_at, current_period_end, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + sa.AccountID, + sa.StripeCustomerID, + nullableString(sa.StripeSubscriptionID), + nullableString(sa.StripeSubItemWorkspacesID), + sa.PlanVersion, + sa.SubscriptionState, + nullableInt64Ptr(sa.TrialEndsAt), + nullableInt64Ptr(sa.CurrentPeriodEnd), + sa.UpdatedAt, + ) + if err != nil { + return fmt.Errorf("create stripe account: %w", err) + } + return nil +} + +// GetStripeAccount retrieves the StripeAccount row by account ID. +func (r *TenantRegistry) GetStripeAccount(accountID string) (*StripeAccount, error) { + row := r.db.QueryRow(`SELECT + account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id, + plan_version, subscription_state, trial_ends_at, current_period_end, updated_at + FROM stripe_accounts WHERE account_id = ?`, strings.TrimSpace(accountID)) + return scanStripeAccount(row) +} + +// GetStripeAccountByCustomerID retrieves the StripeAccount row by Stripe customer ID. +func (r *TenantRegistry) GetStripeAccountByCustomerID(customerID string) (*StripeAccount, error) { + row := r.db.QueryRow(`SELECT + account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id, + plan_version, subscription_state, trial_ends_at, current_period_end, updated_at + FROM stripe_accounts WHERE stripe_customer_id = ?`, strings.TrimSpace(customerID)) + return scanStripeAccount(row) +} + +// UpdateStripeAccount modifies an existing StripeAccount row. +func (r *TenantRegistry) UpdateStripeAccount(sa *StripeAccount) error { + if sa == nil { + return fmt.Errorf("stripe account is nil") + } + sa.AccountID = strings.TrimSpace(sa.AccountID) + sa.StripeCustomerID = strings.TrimSpace(sa.StripeCustomerID) + sa.StripeSubscriptionID = strings.TrimSpace(sa.StripeSubscriptionID) + sa.StripeSubItemWorkspacesID = strings.TrimSpace(sa.StripeSubItemWorkspacesID) + sa.PlanVersion = strings.TrimSpace(sa.PlanVersion) + sa.SubscriptionState = strings.TrimSpace(sa.SubscriptionState) + + if sa.AccountID == "" { + return fmt.Errorf("missing account id") + } + if sa.StripeCustomerID == "" { + return fmt.Errorf("missing stripe customer id") + } + if sa.SubscriptionState == "" { + sa.SubscriptionState = "trial" + } + + sa.UpdatedAt = time.Now().UTC().Unix() + + res, err := r.db.Exec(` + UPDATE stripe_accounts SET + stripe_customer_id = ?, stripe_subscription_id = ?, stripe_sub_item_workspaces_id = ?, + plan_version = ?, subscription_state = ?, trial_ends_at = ?, current_period_end = ?, updated_at = ? + WHERE account_id = ?`, + sa.StripeCustomerID, + nullableString(sa.StripeSubscriptionID), + nullableString(sa.StripeSubItemWorkspacesID), + sa.PlanVersion, + sa.SubscriptionState, + nullableInt64Ptr(sa.TrialEndsAt), + nullableInt64Ptr(sa.CurrentPeriodEnd), + sa.UpdatedAt, + sa.AccountID, + ) + if err != nil { + return fmt.Errorf("update stripe account: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return fmt.Errorf("stripe account %q not found", sa.AccountID) + } + return nil +} + +// RecordStripeEvent inserts a webhook event record and returns true if the +// event was already recorded (duplicate Stripe delivery). +func (r *TenantRegistry) RecordStripeEvent(eventID, eventType string) (alreadyProcessed bool, err error) { + eventID = strings.TrimSpace(eventID) + eventType = strings.TrimSpace(eventType) + if eventID == "" { + return false, fmt.Errorf("missing stripe event id") + } + if eventType == "" { + return false, fmt.Errorf("missing stripe event type") + } + + // INSERT OR IGNORE avoids driver-specific error parsing for duplicates. + res, err := r.db.Exec(` + INSERT OR IGNORE INTO stripe_events ( + stripe_event_id, event_type, received_at, processed_at, processing_error + ) VALUES (?, ?, ?, NULL, NULL)`, + eventID, eventType, time.Now().UTC().Unix(), + ) + if err != nil { + return false, fmt.Errorf("record stripe event: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return true, nil + } + return false, nil +} + +// MarkStripeEventProcessed marks a previously recorded event as processed. +// processingError is stored (nullable) for troubleshooting. +func (r *TenantRegistry) MarkStripeEventProcessed(eventID string, processingError string) error { + eventID = strings.TrimSpace(eventID) + if eventID == "" { + return fmt.Errorf("missing stripe event id") + } + processingError = strings.TrimSpace(processingError) + + res, err := r.db.Exec(` + UPDATE stripe_events SET + processed_at = ?, processing_error = ? + WHERE stripe_event_id = ?`, + time.Now().UTC().Unix(), + nullableString(processingError), + eventID, + ) + if err != nil { + return fmt.Errorf("mark stripe event processed: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return fmt.Errorf("stripe event %q not found", eventID) + } + return nil +} + +func scanAccount(s scanner) (*Account, error) { + var a Account + var kind string + var createdAt, updatedAt int64 + if err := s.Scan(&a.ID, &kind, &a.DisplayName, &createdAt, &updatedAt); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("scan account: %w", err) + } + a.Kind = AccountKind(kind) + a.CreatedAt = time.Unix(createdAt, 0).UTC() + a.UpdatedAt = time.Unix(updatedAt, 0).UTC() + return &a, nil +} + +func scanAccounts(rows *sql.Rows) ([]*Account, error) { + var accounts []*Account + for rows.Next() { + a, err := scanAccount(rows) + if err != nil { + return nil, err + } + accounts = append(accounts, a) + } + return accounts, rows.Err() +} + +func scanUser(s scanner) (*User, error) { + var u User + var createdAt int64 + var lastLogin sql.NullInt64 + if err := s.Scan(&u.ID, &u.Email, &createdAt, &lastLogin); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("scan user: %w", err) + } + u.CreatedAt = time.Unix(createdAt, 0).UTC() + if lastLogin.Valid { + ts := time.Unix(lastLogin.Int64, 0).UTC() + u.LastLoginAt = &ts + } + return &u, nil +} + +func scanMembership(s scanner) (*AccountMembership, error) { + var m AccountMembership + var role string + var createdAt int64 + if err := s.Scan(&m.AccountID, &m.UserID, &role, &createdAt); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("scan membership: %w", err) + } + m.Role = MemberRole(role) + m.CreatedAt = time.Unix(createdAt, 0).UTC() + return &m, nil +} + +func scanMemberships(rows *sql.Rows) ([]*AccountMembership, error) { + var memberships []*AccountMembership + for rows.Next() { + m, err := scanMembership(rows) + if err != nil { + return nil, err + } + memberships = append(memberships, m) + } + return memberships, rows.Err() +} + func nullableTimeUnix(t *time.Time) any { if t == nil { return nil @@ -286,9 +864,60 @@ func nullableTimeUnix(t *time.Time) any { return t.Unix() } +func nullableInt64Ptr(v *int64) any { + if v == nil { + return nil + } + return *v +} + +func nullableString(s string) any { + if strings.TrimSpace(s) == "" { + return nil + } + return strings.TrimSpace(s) +} + func boolToInt(b bool) int { if b { return 1 } return 0 } + +func scanStripeAccount(s scanner) (*StripeAccount, error) { + var sa StripeAccount + var subID, subItemID sql.NullString + var trialEnds, periodEnd sql.NullInt64 + if err := s.Scan( + &sa.AccountID, + &sa.StripeCustomerID, + &subID, + &subItemID, + &sa.PlanVersion, + &sa.SubscriptionState, + &trialEnds, + &periodEnd, + &sa.UpdatedAt, + ); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("scan stripe account: %w", err) + } + if subID.Valid { + sa.StripeSubscriptionID = subID.String + } + if subItemID.Valid { + sa.StripeSubItemWorkspacesID = subItemID.String + } + if trialEnds.Valid { + v := trialEnds.Int64 + sa.TrialEndsAt = &v + } + if periodEnd.Valid { + v := periodEnd.Int64 + sa.CurrentPeriodEnd = &v + } + return &sa, nil +} diff --git a/internal/cloudcp/registry/registry_test.go b/internal/cloudcp/registry/registry_test.go index 010e90ed5..976770831 100644 --- a/internal/cloudcp/registry/registry_test.go +++ b/internal/cloudcp/registry/registry_test.go @@ -59,6 +59,32 @@ func TestGenerateTenantID_CrockfordCharset(t *testing.T) { } } +func TestGenerateAccountID(t *testing.T) { + id, err := GenerateAccountID() + if err != nil { + t.Fatalf("GenerateAccountID: %v", err) + } + if !strings.HasPrefix(id, "a_") { + t.Errorf("expected prefix a_, got %q", id) + } + if len(id) != 12 { // "a_" + 10 chars + t.Errorf("expected length 12, got %d (%q)", len(id), id) + } +} + +func TestGenerateUserID(t *testing.T) { + id, err := GenerateUserID() + if err != nil { + t.Fatalf("GenerateUserID: %v", err) + } + if !strings.HasPrefix(id, "u_") { + t.Errorf("expected prefix u_, got %q", id) + } + if len(id) != 12 { // "u_" + 10 chars + t.Errorf("expected length 12, got %d (%q)", len(id), id) + } +} + func TestCRUD(t *testing.T) { reg := newTestRegistry(t) @@ -137,6 +163,226 @@ func TestCRUD(t *testing.T) { } } +func TestAccountCRUD(t *testing.T) { + reg := newTestRegistry(t) + + accountID, err := GenerateAccountID() + if err != nil { + t.Fatal(err) + } + a := &Account{ + ID: accountID, + Kind: AccountKindMSP, + DisplayName: "Test MSP", + } + + // Create + if err := reg.CreateAccount(a); err != nil { + t.Fatalf("CreateAccount: %v", err) + } + if a.CreatedAt.IsZero() { + t.Error("CreatedAt should be set") + } + if a.UpdatedAt.IsZero() { + t.Error("UpdatedAt should be set") + } + + // Get + got, err := reg.GetAccount(accountID) + if err != nil { + t.Fatalf("GetAccount: %v", err) + } + if got == nil { + t.Fatal("GetAccount returned nil") + } + if got.Kind != AccountKindMSP { + t.Errorf("Kind = %q, want %q", got.Kind, AccountKindMSP) + } + if got.DisplayName != "Test MSP" { + t.Errorf("DisplayName = %q, want %q", got.DisplayName, "Test MSP") + } + + // Update + got.DisplayName = "Renamed MSP" + if err := reg.UpdateAccount(got); err != nil { + t.Fatalf("UpdateAccount: %v", err) + } + got2, err := reg.GetAccount(accountID) + if err != nil { + t.Fatalf("GetAccount after update: %v", err) + } + if got2.DisplayName != "Renamed MSP" { + t.Errorf("DisplayName after update = %q, want %q", got2.DisplayName, "Renamed MSP") + } + + // List + accounts, err := reg.ListAccounts() + if err != nil { + t.Fatalf("ListAccounts: %v", err) + } + if len(accounts) != 1 { + t.Fatalf("expected 1 account, got %d", len(accounts)) + } + if accounts[0].ID != accountID { + t.Errorf("accounts[0].ID = %q, want %q", accounts[0].ID, accountID) + } +} + +func TestUserCRUD(t *testing.T) { + reg := newTestRegistry(t) + + userID, err := GenerateUserID() + if err != nil { + t.Fatal(err) + } + u := &User{ + ID: userID, + Email: "user@example.com", + } + + // Create + if err := reg.CreateUser(u); err != nil { + t.Fatalf("CreateUser: %v", err) + } + if u.CreatedAt.IsZero() { + t.Error("CreatedAt should be set") + } + + // Get by ID + got, err := reg.GetUser(userID) + if err != nil { + t.Fatalf("GetUser: %v", err) + } + if got == nil { + t.Fatal("GetUser returned nil") + } + if got.Email != "user@example.com" { + t.Errorf("Email = %q, want %q", got.Email, "user@example.com") + } + if got.LastLoginAt != nil { + t.Errorf("LastLoginAt = %v, want nil", got.LastLoginAt) + } + + // Get by email + got2, err := reg.GetUserByEmail("user@example.com") + if err != nil { + t.Fatalf("GetUserByEmail: %v", err) + } + if got2 == nil || got2.ID != userID { + t.Fatalf("GetUserByEmail returned %+v, want id=%q", got2, userID) + } + + // Update last login + before := time.Now().UTC() + if err := reg.UpdateUserLastLogin(userID); err != nil { + t.Fatalf("UpdateUserLastLogin: %v", err) + } + after := time.Now().UTC() + + got3, err := reg.GetUser(userID) + if err != nil { + t.Fatalf("GetUser after last login update: %v", err) + } + if got3.LastLoginAt == nil { + t.Fatal("LastLoginAt should be set") + } + ll := *got3.LastLoginAt + if ll.Before(before.Add(-2*time.Second)) || ll.After(after.Add(2*time.Second)) { + t.Fatalf("LastLoginAt=%s out of expected range [%s, %s]", ll, before, after) + } +} + +func TestMembershipCRUD(t *testing.T) { + reg := newTestRegistry(t) + + accountID, err := GenerateAccountID() + if err != nil { + t.Fatal(err) + } + userID, err := GenerateUserID() + if err != nil { + t.Fatal(err) + } + + if err := reg.CreateAccount(&Account{ID: accountID, Kind: AccountKindMSP, DisplayName: "Account"}); err != nil { + t.Fatal(err) + } + if err := reg.CreateUser(&User{ID: userID, Email: "member@example.com"}); err != nil { + t.Fatal(err) + } + + m := &AccountMembership{ + AccountID: accountID, + UserID: userID, + Role: MemberRoleOwner, + } + + // Create + if err := reg.CreateMembership(m); err != nil { + t.Fatalf("CreateMembership: %v", err) + } + if m.CreatedAt.IsZero() { + t.Error("CreatedAt should be set") + } + + // Get + got, err := reg.GetMembership(accountID, userID) + if err != nil { + t.Fatalf("GetMembership: %v", err) + } + if got == nil { + t.Fatal("GetMembership returned nil") + } + if got.Role != MemberRoleOwner { + t.Errorf("Role = %q, want %q", got.Role, MemberRoleOwner) + } + + // List by account + members, err := reg.ListMembersByAccount(accountID) + if err != nil { + t.Fatalf("ListMembersByAccount: %v", err) + } + if len(members) != 1 { + t.Fatalf("expected 1 member, got %d", len(members)) + } + if members[0].UserID != userID { + t.Errorf("members[0].UserID = %q, want %q", members[0].UserID, userID) + } + + // List accounts by user + accounts, err := reg.ListAccountsByUser(userID) + if err != nil { + t.Fatalf("ListAccountsByUser: %v", err) + } + if len(accounts) != 1 || accounts[0] != accountID { + t.Fatalf("accounts=%v, want [%q]", accounts, accountID) + } + + // Update role + if err := reg.UpdateMembershipRole(accountID, userID, MemberRoleAdmin); err != nil { + t.Fatalf("UpdateMembershipRole: %v", err) + } + got2, err := reg.GetMembership(accountID, userID) + if err != nil { + t.Fatalf("GetMembership after role update: %v", err) + } + if got2.Role != MemberRoleAdmin { + t.Errorf("Role after update = %q, want %q", got2.Role, MemberRoleAdmin) + } + + // Delete + if err := reg.DeleteMembership(accountID, userID); err != nil { + t.Fatalf("DeleteMembership: %v", err) + } + got3, err := reg.GetMembership(accountID, userID) + if err != nil { + t.Fatalf("GetMembership after delete: %v", err) + } + if got3 != nil { + t.Fatalf("expected nil membership after delete, got %+v", got3) + } +} + func TestList(t *testing.T) { reg := newTestRegistry(t) @@ -199,6 +445,46 @@ func TestListByState(t *testing.T) { } } +func TestListByAccountID(t *testing.T) { + reg := newTestRegistry(t) + + accountID, err := GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(&Account{ID: accountID, Kind: AccountKindMSP, DisplayName: "Account"}); err != nil { + t.Fatal(err) + } + + if err := reg.Create(&Tenant{ID: "t-ACCNT0001", AccountID: accountID, State: TenantStateActive}); err != nil { + t.Fatal(err) + } + if err := reg.Create(&Tenant{ID: "t-ACCNT0002", AccountID: accountID, State: TenantStateActive}); err != nil { + t.Fatal(err) + } + if err := reg.Create(&Tenant{ID: "t-ACCNT0003", State: TenantStateActive}); err != nil { + t.Fatal(err) + } + + tenants, err := reg.ListByAccountID(accountID) + if err != nil { + t.Fatalf("ListByAccountID: %v", err) + } + if len(tenants) != 2 { + t.Fatalf("expected 2 tenants, got %d", len(tenants)) + } + seen := make(map[string]bool) + for _, tnt := range tenants { + seen[tnt.ID] = true + if tnt.AccountID != accountID { + t.Errorf("tenant %s AccountID=%q, want %q", tnt.ID, tnt.AccountID, accountID) + } + } + if !seen["t-ACCNT0001"] || !seen["t-ACCNT0002"] { + t.Fatalf("expected tenants t-ACCNT0001 and t-ACCNT0002, got %+v", tenants) + } +} + func TestCountByState(t *testing.T) { reg := newTestRegistry(t) @@ -280,3 +566,107 @@ func TestNewTenantRegistry_InvalidDir(t *testing.T) { } } } + +func TestStripeAccountCRUD(t *testing.T) { + reg := newTestRegistry(t) + + accountID, err := GenerateAccountID() + if err != nil { + t.Fatal(err) + } + if err := reg.CreateAccount(&Account{ + ID: accountID, + Kind: AccountKindMSP, + DisplayName: "Test MSP", + }); err != nil { + t.Fatal(err) + } + + trialEnds := time.Now().UTC().Add(7 * 24 * time.Hour).Unix() + periodEnd := time.Now().UTC().Add(30 * 24 * time.Hour).Unix() + + sa := &StripeAccount{ + AccountID: accountID, + StripeCustomerID: "cus_test_123", + StripeSubscriptionID: "sub_test_123", + StripeSubItemWorkspacesID: "si_workspaces_123", + PlanVersion: "msp_hosted_v1", + SubscriptionState: "trial", + TrialEndsAt: &trialEnds, + CurrentPeriodEnd: &periodEnd, + } + + // Create + if err := reg.CreateStripeAccount(sa); err != nil { + t.Fatalf("CreateStripeAccount: %v", err) + } + + // Get by account id + got, err := reg.GetStripeAccount(accountID) + if err != nil { + t.Fatalf("GetStripeAccount: %v", err) + } + if got == nil { + t.Fatal("GetStripeAccount returned nil") + } + if got.StripeCustomerID != "cus_test_123" { + t.Errorf("StripeCustomerID = %q, want %q", got.StripeCustomerID, "cus_test_123") + } + if got.StripeSubscriptionID != "sub_test_123" { + t.Errorf("StripeSubscriptionID = %q, want %q", got.StripeSubscriptionID, "sub_test_123") + } + + // Get by customer id + got2, err := reg.GetStripeAccountByCustomerID("cus_test_123") + if err != nil { + t.Fatalf("GetStripeAccountByCustomerID: %v", err) + } + if got2 == nil || got2.AccountID != accountID { + t.Fatalf("expected accountID %q, got %#v", accountID, got2) + } + + // Update + got2.SubscriptionState = "active" + got2.PlanVersion = "msp_hosted_v2" + got2.StripeSubscriptionID = "sub_test_456" + if err := reg.UpdateStripeAccount(got2); err != nil { + t.Fatalf("UpdateStripeAccount: %v", err) + } + + got3, err := reg.GetStripeAccount(accountID) + if err != nil { + t.Fatalf("GetStripeAccount after update: %v", err) + } + if got3.SubscriptionState != "active" { + t.Errorf("SubscriptionState = %q, want %q", got3.SubscriptionState, "active") + } + if got3.PlanVersion != "msp_hosted_v2" { + t.Errorf("PlanVersion = %q, want %q", got3.PlanVersion, "msp_hosted_v2") + } + if got3.StripeSubscriptionID != "sub_test_456" { + t.Errorf("StripeSubscriptionID = %q, want %q", got3.StripeSubscriptionID, "sub_test_456") + } + if got3.UpdatedAt == 0 { + t.Error("UpdatedAt should be set") + } +} + +func TestStripeEventIdempotency(t *testing.T) { + reg := newTestRegistry(t) + + already, err := reg.RecordStripeEvent("evt_test_123", "customer.subscription.updated") + if err != nil { + t.Fatalf("RecordStripeEvent: %v", err) + } + if already { + t.Fatalf("expected alreadyProcessed=false on first insert") + } + + already2, err := reg.RecordStripeEvent("evt_test_123", "customer.subscription.updated") + if err != nil { + t.Fatalf("RecordStripeEvent duplicate: %v", err) + } + if !already2 { + t.Fatalf("expected alreadyProcessed=true on duplicate insert") + } +} diff --git a/internal/cloudcp/routes.go b/internal/cloudcp/routes.go index 425663cf4..e55b66048 100644 --- a/internal/cloudcp/routes.go +++ b/internal/cloudcp/routes.go @@ -7,6 +7,8 @@ import ( "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin" cpauth "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/auth" "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/docker" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/handoff" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/portal" "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" cpstripe "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/stripe" ) @@ -99,4 +101,12 @@ func RegisterRoutes(mux *http.ServeMux, deps *Deps) { mux.Handle("/api/accounts/{account_id}/tenants", admin.AdminKeyMiddleware(deps.Config.AdminKey, tenantsCollection)) mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}", admin.AdminKeyMiddleware(deps.Config.AdminKey, tenant)) + + // Tenant switching handoff (admin-key authenticated for now; session auth in M-4) + handoffHandler := handoff.HandleHandoff(deps.Registry, deps.Config.TenantsDir()) + mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}/handoff", admin.AdminKeyMiddleware(deps.Config.AdminKey, handoffHandler)) + + // MSP portal API (admin-key authenticated for now; session auth in M-4) + mux.Handle("/api/portal/dashboard", admin.AdminKeyMiddleware(deps.Config.AdminKey, portal.HandlePortalDashboard(deps.Registry))) + mux.Handle("/api/portal/workspaces/{tenant_id}", admin.AdminKeyMiddleware(deps.Config.AdminKey, portal.HandlePortalWorkspaceDetail(deps.Registry))) } diff --git a/internal/cloudcp/stripe/webhook.go b/internal/cloudcp/stripe/webhook.go index 96cd2fc4f..a91e3fbab 100644 --- a/internal/cloudcp/stripe/webhook.go +++ b/internal/cloudcp/stripe/webhook.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" + "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry" "github.com/rs/zerolog/log" stripelib "github.com/stripe/stripe-go/v82" "github.com/stripe/stripe-go/v82/webhook" @@ -98,14 +99,14 @@ func (h *WebhookHandler) handleEvent(r *http.Request, event *stripelib.Event) er if err := json.Unmarshal(event.Data.Raw, &sub); err != nil { return fmt.Errorf("decode subscription: %w", err) } - return h.provisioner.HandleSubscriptionUpdated(r.Context(), sub) + return h.routeSubscriptionUpdated(r, sub) case "customer.subscription.deleted": var sub Subscription if err := json.Unmarshal(event.Data.Raw, &sub); err != nil { return fmt.Errorf("decode subscription: %w", err) } - return h.provisioner.HandleSubscriptionDeleted(r.Context(), sub) + return h.routeSubscriptionDeleted(r, sub) default: log.Info(). @@ -116,6 +117,46 @@ func (h *WebhookHandler) handleEvent(r *http.Request, event *stripelib.Event) er } } +func (h *WebhookHandler) routeSubscriptionUpdated(r *http.Request, sub Subscription) error { + customerID := strings.TrimSpace(sub.Customer) + if customerID != "" { + sa, err := h.provisioner.registry.GetStripeAccountByCustomerID(customerID) + if err != nil { + return fmt.Errorf("lookup stripe account by customer: %w", err) + } + if sa != nil { + acct, err := h.provisioner.registry.GetAccount(sa.AccountID) + if err != nil { + return fmt.Errorf("lookup account: %w", err) + } + if acct != nil && acct.Kind == registry.AccountKindMSP { + return h.provisioner.HandleMSPSubscriptionUpdated(r.Context(), sub) + } + } + } + return h.provisioner.HandleSubscriptionUpdated(r.Context(), sub) +} + +func (h *WebhookHandler) routeSubscriptionDeleted(r *http.Request, sub Subscription) error { + customerID := strings.TrimSpace(sub.Customer) + if customerID != "" { + sa, err := h.provisioner.registry.GetStripeAccountByCustomerID(customerID) + if err != nil { + return fmt.Errorf("lookup stripe account by customer: %w", err) + } + if sa != nil { + acct, err := h.provisioner.registry.GetAccount(sa.AccountID) + if err != nil { + return fmt.Errorf("lookup account: %w", err) + } + if acct != nil && acct.Kind == registry.AccountKindMSP { + return h.provisioner.HandleMSPSubscriptionDeleted(r.Context(), sub) + } + } + } + return h.provisioner.HandleSubscriptionDeleted(r.Context(), sub) +} + // CheckoutSession is a minimal representation of a Stripe checkout.session event. type CheckoutSession struct { ID string `json:"id"`