feat(cloud): implement Hosted MSP account layer (M-1 through M-6)

Add the full Hosted MSP portal backend to the control plane. MSP accounts
own multiple container-per-tenant workspaces with consolidated billing,
account-level RBAC, and signed JWT handoff for tenant switching.

- M-1: Account model + schema (accounts, users, account_memberships tables,
  tenant.account_id FK, ID generators, full CRUD)
- M-2: Account-level RBAC handlers (invite/update/remove members, last-owner
  protection)
- M-3: Portal dashboard API (workspace list with fleet health summaries,
  workspace detail)
- M-4: Tenant handoff auth (JWT HS256 minting with per-tenant secret,
  auto-submit HTML POST binding, tenant-side exchange endpoint with JTI
  replay protection)
- M-5: Workspace provisioning from portal (create/list/update/delete
  workspaces, handoff.key generation, billing.json, Docker lifecycle)
- M-6: Consolidated billing (stripe_accounts table, MSP-aware webhook
  routing, subscription state propagation across account tenants)
This commit is contained in:
rcourtman
2026-02-10 22:10:24 +00:00
parent 463e4eff50
commit 0df04de4bf
18 changed files with 3684 additions and 12 deletions

1
go.mod
View File

@@ -59,6 +59,7 @@ require (
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.7.0 // indirect

2
go.sum
View File

@@ -76,6 +76,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I=

View File

@@ -0,0 +1,233 @@
package api
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
_ "modernc.org/sqlite"
)
const (
cloudHandoffIssuer = "pulse-cloud-control-plane"
)
type cloudHandoffClaims struct {
AccountID string `json:"account_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
type jtiReplayStore struct {
once sync.Once
db *sql.DB
mu sync.Mutex
configDir string
initErr error
}
func (s *jtiReplayStore) init() {
s.once.Do(func() {
dir := filepath.Clean(s.configDir)
if strings.TrimSpace(dir) == "" {
s.initErr = fmt.Errorf("configDir is required")
return
}
if err := os.MkdirAll(dir, 0o755); err != nil {
s.initErr = fmt.Errorf("create config dir: %w", err)
return
}
dbPath := filepath.Join(dir, "handoff_jti.db")
dsn := dbPath + "?" + url.Values{
"_pragma": []string{
"busy_timeout(30000)",
"journal_mode(WAL)",
"synchronous(NORMAL)",
},
}.Encode()
db, err := sql.Open("sqlite", dsn)
if err != nil {
s.initErr = fmt.Errorf("open handoff jti db: %w", err)
return
}
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(0)
schema := `
CREATE TABLE IF NOT EXISTS handoff_jti (
jti TEXT PRIMARY KEY,
expires_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_handoff_jti_expires_at ON handoff_jti(expires_at);
`
if _, err := db.Exec(schema); err != nil {
_ = db.Close()
s.initErr = fmt.Errorf("init handoff jti schema: %w", err)
return
}
s.db = db
})
}
func (s *jtiReplayStore) checkAndStore(jti string, expiresAt time.Time) (stored bool, err error) {
s.init()
if s.initErr != nil {
return false, s.initErr
}
if s.db == nil {
return false, fmt.Errorf("handoff jti store not initialized")
}
jti = strings.TrimSpace(jti)
if jti == "" {
return false, fmt.Errorf("jti is required")
}
expiresAt = expiresAt.UTC()
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now().UTC().Unix()
if _, err := s.db.Exec(`DELETE FROM handoff_jti WHERE expires_at <= ?`, now); err != nil {
return false, fmt.Errorf("cleanup handoff jti: %w", err)
}
_, err = s.db.Exec(`INSERT INTO handoff_jti (jti, expires_at) VALUES (?, ?)`, jti, expiresAt.Unix())
if err != nil {
if isSQLiteUniqueViolation(err) {
return false, nil
}
return false, fmt.Errorf("store handoff jti: %w", err)
}
return true, nil
}
func isSQLiteUniqueViolation(err error) bool {
if err == nil {
return false
}
s := err.Error()
return strings.Contains(s, "UNIQUE constraint failed") || strings.Contains(s, "constraint failed")
}
func tenantIDFromRequest(r *http.Request) string {
if v := strings.TrimSpace(os.Getenv("PULSE_TENANT_ID")); v != "" {
return v
}
host := strings.TrimSpace(r.Host)
if host == "" {
return ""
}
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
// Host is expected to be "<tenant-id>.<baseDomain>".
if i := strings.IndexByte(host, '.'); i > 0 {
return host[:i]
}
return host
}
// HandleHandoffExchange verifies a control-plane-minted handoff JWT and records its jti to prevent replay.
//
// Route (wiring happens elsewhere): POST /api/cloud/handoff/exchange
//
// This minimal implementation returns success JSON containing user info derived from the token.
// Wiring into RBAC/session minting is intentionally deferred.
func HandleHandoffExchange(configDir string) http.HandlerFunc {
configDir = filepath.Clean(configDir)
keyPath := filepath.Join(configDir, "secrets", "handoff.key")
replay := &jtiReplayStore{configDir: configDir}
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
key, err := os.ReadFile(keyPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
http.Error(w, "not found", http.StatusNotFound)
return
}
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
tenantID := tenantIDFromRequest(r)
if tenantID == "" {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
tokenStr := strings.TrimSpace(r.FormValue("token"))
if tokenStr == "" {
http.Error(w, "missing token", http.StatusBadRequest)
return
}
var claims cloudHandoffClaims
parsed, err := jwt.ParseWithClaims(
tokenStr,
&claims,
func(t *jwt.Token) (any, error) { return key, nil },
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
jwt.WithIssuer(cloudHandoffIssuer),
jwt.WithAudience(tenantID),
)
if err != nil || parsed == nil || !parsed.Valid {
http.Error(w, "invalid token", http.StatusUnauthorized)
return
}
if claims.ExpiresAt == nil {
http.Error(w, "invalid token", http.StatusUnauthorized)
return
}
ok, err := replay.checkAndStore(claims.ID, claims.ExpiresAt.Time)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if !ok {
http.Error(w, "replayed token", http.StatusUnauthorized)
return
}
resp := map[string]any{
"ok": true,
"tenant_id": tenantID,
"account_id": claims.AccountID,
"user_id": claims.Subject,
"email": claims.Email,
"role": claims.Role,
"jti": claims.ID,
"exp": claims.ExpiresAt.Time.UTC().Format(time.RFC3339),
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}
}

View File

@@ -0,0 +1,346 @@
package account
import (
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
type memberResponse struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Role registry.MemberRole `json:"role"`
CreatedAt time.Time `json:"created_at"`
}
// HandleListMembers returns an authenticated handler that lists all members of an account.
func HandleListMembers(reg *registry.TenantRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
if accountID == "" {
http.Error(w, "missing account_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
memberships, err := reg.ListMembersByAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if memberships == nil {
memberships = []*registry.AccountMembership{}
}
resp := make([]memberResponse, 0, len(memberships))
for _, m := range memberships {
u, err := reg.GetUser(m.UserID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if u == nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
resp = append(resp, memberResponse{
UserID: m.UserID,
Email: u.Email,
Role: m.Role,
CreatedAt: m.CreatedAt,
})
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}
}
type inviteMemberRequest struct {
Email string `json:"email"`
Role string `json:"role"`
}
// HandleInviteMember returns an authenticated handler that invites a user to an account.
func HandleInviteMember(reg *registry.TenantRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
if accountID == "" {
http.Error(w, "missing account_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
var req inviteMemberRequest
if err := decodeJSON(w, r, &req); err != nil {
return
}
email := normalizeEmail(req.Email)
if email == "" {
http.Error(w, "invalid email", http.StatusBadRequest)
return
}
role, ok := parseMemberRole(req.Role)
if !ok {
http.Error(w, "invalid role", http.StatusBadRequest)
return
}
u, err := reg.GetUserByEmail(email)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if u == nil {
userID, err := registry.GenerateUserID()
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
u = &registry.User{
ID: userID,
Email: email,
}
if err := reg.CreateUser(u); err != nil {
// If a concurrent request created the user, fall back to lookup.
u2, gerr := reg.GetUserByEmail(email)
if gerr != nil || u2 == nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
u = u2
}
}
if err := reg.CreateMembership(&registry.AccountMembership{
AccountID: accountID,
UserID: u.ID,
Role: role,
}); err != nil {
if isUniqueViolation(err) {
http.Error(w, "membership already exists", http.StatusConflict)
return
}
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusCreated)
}
}
type updateMemberRoleRequest struct {
Role string `json:"role"`
}
// HandleUpdateMemberRole returns an authenticated handler that updates a member's role.
func HandleUpdateMemberRole(reg *registry.TenantRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPatch {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
userID := strings.TrimSpace(r.PathValue("user_id"))
if accountID == "" || userID == "" {
http.Error(w, "missing account_id or user_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
var req updateMemberRoleRequest
if err := decodeJSON(w, r, &req); err != nil {
return
}
role, ok := parseMemberRole(req.Role)
if !ok {
http.Error(w, "invalid role", http.StatusBadRequest)
return
}
if err := reg.UpdateMembershipRole(accountID, userID, role); err != nil {
if isNotFoundErr(err) {
http.Error(w, "membership not found", http.StatusNotFound)
return
}
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
}
// HandleRemoveMember returns an authenticated handler that removes a user from an account.
func HandleRemoveMember(reg *registry.TenantRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
userID := strings.TrimSpace(r.PathValue("user_id"))
if accountID == "" || userID == "" {
http.Error(w, "missing account_id or user_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
m, err := reg.GetMembership(accountID, userID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if m == nil {
http.Error(w, "membership not found", http.StatusNotFound)
return
}
if m.Role == registry.MemberRoleOwner {
memberships, err := reg.ListMembersByAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
owners := 0
for _, mm := range memberships {
if mm.Role == registry.MemberRoleOwner {
owners++
}
}
if owners <= 1 {
http.Error(w, "cannot remove last owner", http.StatusConflict)
return
}
}
if err := reg.DeleteMembership(accountID, userID); err != nil {
if isNotFoundErr(err) {
http.Error(w, "membership not found", http.StatusNotFound)
return
}
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}
func normalizeEmail(s string) string {
s = strings.TrimSpace(s)
s = strings.ToLower(s)
// Minimal sanity; deeper validation comes later with session auth flows.
if s == "" || !strings.Contains(s, "@") {
return ""
}
return s
}
func parseMemberRole(s string) (registry.MemberRole, bool) {
switch registry.MemberRole(strings.TrimSpace(s)) {
case registry.MemberRoleOwner:
return registry.MemberRoleOwner, true
case registry.MemberRoleAdmin:
return registry.MemberRoleAdmin, true
case registry.MemberRoleTech:
return registry.MemberRoleTech, true
case registry.MemberRoleReadOnly:
return registry.MemberRoleReadOnly, true
default:
return "", false
}
}
func decodeJSON(w http.ResponseWriter, r *http.Request, dst any) error {
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(dst); err != nil {
http.Error(w, "invalid JSON body", http.StatusBadRequest)
return err
}
if err := dec.Decode(&struct{}{}); err != io.EOF {
if err == nil {
http.Error(w, "invalid JSON body", http.StatusBadRequest)
return errors.New("multiple JSON values")
}
http.Error(w, "invalid JSON body", http.StatusBadRequest)
return err
}
return nil
}
func isNotFoundErr(err error) bool {
if err == nil {
return false
}
// Registry uses fmt.Errorf("... not found") (no sentinel errors yet).
return strings.Contains(err.Error(), "not found")
}
func isUniqueViolation(err error) bool {
if err == nil {
return false
}
// modernc.org/sqlite returns strings containing "UNIQUE constraint failed".
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "unique constraint failed")
}

View File

@@ -0,0 +1,343 @@
package account
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"sort"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
func newTestRegistry(t *testing.T) *registry.TenantRegistry {
t.Helper()
dir := t.TempDir()
reg, err := registry.NewTenantRegistry(dir)
if err != nil {
t.Fatalf("NewTenantRegistry: %v", err)
}
t.Cleanup(func() { _ = reg.Close() })
return reg
}
func newTestMux(reg *registry.TenantRegistry) *http.ServeMux {
mux := http.NewServeMux()
listMembers := HandleListMembers(reg)
inviteMember := HandleInviteMember(reg)
updateRole := HandleUpdateMemberRole(reg)
removeMember := HandleRemoveMember(reg)
collection := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
listMembers(w, r)
case http.MethodPost:
inviteMember(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
member := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPatch:
updateRole(w, r)
case http.MethodDelete:
removeMember(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
mux.Handle("/api/accounts/{account_id}/members", admin.AdminKeyMiddleware("secret-key", collection))
mux.Handle("/api/accounts/{account_id}/members/{user_id}", admin.AdminKeyMiddleware("secret-key", member))
return mux
}
func doRequest(t *testing.T, h http.Handler, req *http.Request) *httptest.ResponseRecorder {
t.Helper()
req.Header.Set("X-Admin-Key", "secret-key")
rec := httptest.NewRecorder()
h.ServeHTTP(rec, req)
return rec
}
func TestInviteMember(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
t.Fatal(err)
}
body := `{"email":"tech@msp.com","role":"tech"}`
req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/members", bytes.NewBufferString(body))
rec := doRequest(t, mux, req)
if rec.Code != http.StatusCreated {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String())
}
u, err := reg.GetUserByEmail("tech@msp.com")
if err != nil {
t.Fatal(err)
}
if u == nil {
t.Fatal("expected user to be created")
}
m, err := reg.GetMembership(accountID, u.ID)
if err != nil {
t.Fatal(err)
}
if m == nil {
t.Fatal("expected membership to be created")
}
if m.Role != registry.MemberRoleTech {
t.Fatalf("role = %q, want %q", m.Role, registry.MemberRoleTech)
}
}
func TestInviteExistingUser(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
t.Fatal(err)
}
userID, err := registry.GenerateUserID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&registry.User{ID: userID, Email: "existing@msp.com"}); err != nil {
t.Fatal(err)
}
body := `{"email":"existing@msp.com","role":"tech"}`
req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/members", bytes.NewBufferString(body))
rec := doRequest(t, mux, req)
if rec.Code != http.StatusCreated {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String())
}
u, err := reg.GetUserByEmail("existing@msp.com")
if err != nil {
t.Fatal(err)
}
if u == nil || u.ID != userID {
t.Fatalf("user = %+v, want id=%q", u, userID)
}
m, err := reg.GetMembership(accountID, userID)
if err != nil {
t.Fatal(err)
}
if m == nil {
t.Fatal("expected membership to be created")
}
}
func TestListMembers(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
t.Fatal(err)
}
u1ID, err := registry.GenerateUserID()
if err != nil {
t.Fatal(err)
}
u2ID, err := registry.GenerateUserID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&registry.User{ID: u1ID, Email: "owner@msp.com"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&registry.User{ID: u2ID, Email: "tech@msp.com"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateMembership(&registry.AccountMembership{AccountID: accountID, UserID: u1ID, Role: registry.MemberRoleOwner}); err != nil {
t.Fatal(err)
}
if err := reg.CreateMembership(&registry.AccountMembership{AccountID: accountID, UserID: u2ID, Role: registry.MemberRoleTech}); err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodGet, "/api/accounts/"+accountID+"/members", nil)
rec := doRequest(t, mux, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
}
var got []struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Role registry.MemberRole `json:"role"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
t.Fatalf("decode response: %v", err)
}
if len(got) != 2 {
t.Fatalf("expected 2 members, got %d (%+v)", len(got), got)
}
sort.Slice(got, func(i, j int) bool { return got[i].Email < got[j].Email })
if got[0].Email != "owner@msp.com" || got[0].Role != registry.MemberRoleOwner {
t.Fatalf("member[0]=%+v, want owner@msp.com owner", got[0])
}
if got[1].Email != "tech@msp.com" || got[1].Role != registry.MemberRoleTech {
t.Fatalf("member[1]=%+v, want tech@msp.com tech", got[1])
}
}
func TestUpdateMemberRole(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
t.Fatal(err)
}
userID, err := registry.GenerateUserID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&registry.User{ID: userID, Email: "tech@msp.com"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateMembership(&registry.AccountMembership{AccountID: accountID, UserID: userID, Role: registry.MemberRoleTech}); err != nil {
t.Fatal(err)
}
body := `{"role":"admin"}`
req := httptest.NewRequest(http.MethodPatch, "/api/accounts/"+accountID+"/members/"+userID, bytes.NewBufferString(body))
rec := doRequest(t, mux, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
}
m, err := reg.GetMembership(accountID, userID)
if err != nil {
t.Fatal(err)
}
if m == nil {
t.Fatal("expected membership to exist")
}
if m.Role != registry.MemberRoleAdmin {
t.Fatalf("role = %q, want %q", m.Role, registry.MemberRoleAdmin)
}
}
func TestRemoveMember(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
t.Fatal(err)
}
ownerID, err := registry.GenerateUserID()
if err != nil {
t.Fatal(err)
}
techID, err := registry.GenerateUserID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&registry.User{ID: ownerID, Email: "owner@msp.com"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&registry.User{ID: techID, Email: "tech@msp.com"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateMembership(&registry.AccountMembership{AccountID: accountID, UserID: ownerID, Role: registry.MemberRoleOwner}); err != nil {
t.Fatal(err)
}
if err := reg.CreateMembership(&registry.AccountMembership{AccountID: accountID, UserID: techID, Role: registry.MemberRoleTech}); err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/members/"+techID, nil)
rec := doRequest(t, mux, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusNoContent, rec.Body.String())
}
m, err := reg.GetMembership(accountID, techID)
if err != nil {
t.Fatal(err)
}
if m != nil {
t.Fatalf("expected membership to be deleted, got %+v", m)
}
}
func TestCannotRemoveLastOwner(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
t.Fatal(err)
}
ownerID, err := registry.GenerateUserID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&registry.User{ID: ownerID, Email: "owner@msp.com"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateMembership(&registry.AccountMembership{AccountID: accountID, UserID: ownerID, Role: registry.MemberRoleOwner}); err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/members/"+ownerID, nil)
rec := doRequest(t, mux, req)
if rec.Code != http.StatusConflict && rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d or %d (body=%q)", rec.Code, http.StatusConflict, http.StatusBadRequest, rec.Body.String())
}
m, err := reg.GetMembership(accountID, ownerID)
if err != nil {
t.Fatal(err)
}
if m == nil {
t.Fatal("expected owner membership to remain")
}
}

View File

@@ -0,0 +1,279 @@
package account
import (
"context"
"encoding/json"
"net/http"
"strings"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
// WorkspaceProvisioner is the minimal interface needed by the MSP portal tenant handlers.
// Implemented by internal/cloudcp/stripe.Provisioner.
type WorkspaceProvisioner interface {
ProvisionWorkspace(ctx context.Context, accountID, displayName string) (*registry.Tenant, error)
DeprovisionWorkspaceContainer(ctx context.Context, tenant *registry.Tenant) error
}
// HandleListTenants lists all tenants for an account.
// Route: GET /api/accounts/{account_id}/tenants
func HandleListTenants(reg *registry.TenantRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
if accountID == "" {
http.Error(w, "missing account_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
tenants, err := reg.ListByAccountID(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if tenants == nil {
tenants = []*registry.Tenant{}
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(tenants)
}
}
type createTenantRequest struct {
DisplayName string `json:"display_name"`
}
// HandleCreateTenant creates a new tenant under an account.
// Route: POST /api/accounts/{account_id}/tenants
func HandleCreateTenant(reg *registry.TenantRegistry, provisioner WorkspaceProvisioner) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
if accountID == "" {
http.Error(w, "missing account_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
var req createTenantRequest
if err := decodeJSON(w, r, &req); err != nil {
return
}
displayName := strings.TrimSpace(req.DisplayName)
if displayName == "" {
http.Error(w, "invalid display_name", http.StatusBadRequest)
return
}
if provisioner == nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
tenant, err := provisioner.ProvisionWorkspace(r.Context(), accountID, displayName)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
_ = json.NewEncoder(w).Encode(tenant)
}
}
type updateTenantRequest struct {
DisplayName *string `json:"display_name,omitempty"`
Status *string `json:"status,omitempty"`
State *string `json:"state,omitempty"`
}
func parseTenantState(s string) (registry.TenantState, bool) {
switch registry.TenantState(strings.TrimSpace(s)) {
case registry.TenantStateActive:
return registry.TenantStateActive, true
case registry.TenantStateSuspended:
return registry.TenantStateSuspended, true
default:
return "", false
}
}
func loadTenantForAccount(reg *registry.TenantRegistry, accountID, tenantID string) (*registry.Tenant, error) {
t, err := reg.Get(tenantID)
if err != nil {
return nil, err
}
if t == nil {
return nil, nil
}
if strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID {
return nil, nil
}
return t, nil
}
// HandleUpdateTenant updates display name and/or state.
// Route: PATCH /api/accounts/{account_id}/tenants/{tenant_id}
func HandleUpdateTenant(reg *registry.TenantRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPatch {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
tenantID := strings.TrimSpace(r.PathValue("tenant_id"))
if accountID == "" || tenantID == "" {
http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
tenant, err := loadTenantForAccount(reg, accountID, tenantID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if tenant == nil {
http.Error(w, "tenant not found", http.StatusNotFound)
return
}
var req updateTenantRequest
if err := decodeJSON(w, r, &req); err != nil {
return
}
if req.DisplayName != nil {
name := strings.TrimSpace(*req.DisplayName)
if name == "" {
http.Error(w, "invalid display_name", http.StatusBadRequest)
return
}
tenant.DisplayName = name
}
stateVal := req.Status
if stateVal == nil {
stateVal = req.State
}
if stateVal != nil {
st, ok := parseTenantState(*stateVal)
if !ok {
http.Error(w, "invalid status", http.StatusBadRequest)
return
}
tenant.State = st
}
if err := reg.Update(tenant); err != nil {
if isNotFoundErr(err) {
http.Error(w, "tenant not found", http.StatusNotFound)
return
}
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(tenant)
}
}
// HandleDeleteTenant soft-deletes a tenant and deprovisions its container if Docker is available.
// Route: DELETE /api/accounts/{account_id}/tenants/{tenant_id}
func HandleDeleteTenant(reg *registry.TenantRegistry, provisioner WorkspaceProvisioner) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
tenantID := strings.TrimSpace(r.PathValue("tenant_id"))
if accountID == "" || tenantID == "" {
http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
tenant, err := loadTenantForAccount(reg, accountID, tenantID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if tenant == nil {
http.Error(w, "tenant not found", http.StatusNotFound)
return
}
tenant.State = registry.TenantStateDeleting
if err := reg.Update(tenant); err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if provisioner != nil {
if err := provisioner.DeprovisionWorkspaceContainer(r.Context(), tenant); err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
}
tenant.ContainerID = ""
tenant.State = registry.TenantStateDeleted
if err := reg.Update(tenant); err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}

View File

@@ -0,0 +1,220 @@
package account
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
cpstripe "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/stripe"
)
func newTestTenantMux(reg *registry.TenantRegistry, tenantsDir string) (*http.ServeMux, *cpstripe.Provisioner) {
mux := http.NewServeMux()
provisioner := cpstripe.NewProvisioner(reg, tenantsDir, nil, nil, "https://cloud.example.com")
listTenants := HandleListTenants(reg)
createTenant := HandleCreateTenant(reg, provisioner)
updateTenant := HandleUpdateTenant(reg)
deleteTenant := HandleDeleteTenant(reg, provisioner)
collection := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
listTenants(w, r)
case http.MethodPost:
createTenant(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
tenant := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPatch:
updateTenant(w, r)
case http.MethodDelete:
deleteTenant(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
mux.Handle("/api/accounts/{account_id}/tenants", admin.AdminKeyMiddleware("secret-key", collection))
mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}", admin.AdminKeyMiddleware("secret-key", tenant))
return mux, provisioner
}
func TestCreateWorkspace(t *testing.T) {
reg := newTestRegistry(t)
tenantsDir := t.TempDir()
mux, _ := newTestTenantMux(reg, tenantsDir)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil {
t.Fatal(err)
}
body := `{"display_name":"Acme Dental"}`
req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/tenants", bytes.NewBufferString(body))
rec := doRequest(t, mux, req)
if rec.Code != http.StatusCreated {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String())
}
var got registry.Tenant
if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
t.Fatalf("decode response: %v", err)
}
if got.AccountID != accountID {
t.Fatalf("account_id = %q, want %q", got.AccountID, accountID)
}
if got.DisplayName != "Acme Dental" {
t.Fatalf("display_name = %q, want %q", got.DisplayName, "Acme Dental")
}
keyPath := filepath.Join(tenantsDir, got.ID, "secrets", "handoff.key")
info, err := os.Stat(keyPath)
if err != nil {
t.Fatalf("handoff.key missing: %v", err)
}
if info.Mode().Perm() != 0o600 {
t.Fatalf("handoff.key perms = %o, want %o", info.Mode().Perm(), 0o600)
}
b, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("read handoff.key: %v", err)
}
if len(b) != 32 {
t.Fatalf("handoff.key size = %d, want 32", len(b))
}
}
func TestListWorkspaces(t *testing.T) {
reg := newTestRegistry(t)
tenantsDir := t.TempDir()
mux, provisioner := newTestTenantMux(reg, tenantsDir)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil {
t.Fatal(err)
}
t1, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client One")
if err != nil {
t.Fatal(err)
}
t2, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client Two")
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodGet, "/api/accounts/"+accountID+"/tenants", nil)
rec := doRequest(t, mux, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
}
var got []*registry.Tenant
if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
t.Fatalf("decode response: %v", err)
}
if len(got) != 2 {
t.Fatalf("expected 2 tenants, got %d (%+v)", len(got), got)
}
ids := map[string]bool{}
for _, tt := range got {
if tt.AccountID != accountID {
t.Fatalf("tenant account_id = %q, want %q", tt.AccountID, accountID)
}
ids[tt.ID] = true
}
if !ids[t1.ID] || !ids[t2.ID] {
t.Fatalf("missing ids: got=%v want=%q,%q", ids, t1.ID, t2.ID)
}
}
func TestDeleteWorkspace(t *testing.T) {
reg := newTestRegistry(t)
tenantsDir := t.TempDir()
mux, provisioner := newTestTenantMux(reg, tenantsDir)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil {
t.Fatal(err)
}
tenant, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client")
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/tenants/"+tenant.ID, nil)
rec := doRequest(t, mux, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusNoContent, rec.Body.String())
}
t2, err := reg.Get(tenant.ID)
if err != nil {
t.Fatal(err)
}
if t2 == nil {
t.Fatal("expected tenant to exist")
}
if t2.State != registry.TenantStateDeleted {
t.Fatalf("state = %q, want %q", t2.State, registry.TenantStateDeleted)
}
}
func TestTenantBelongsToAccount(t *testing.T) {
reg := newTestRegistry(t)
tenantsDir := t.TempDir()
mux, provisioner := newTestTenantMux(reg, tenantsDir)
account1, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
account2, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: account1, Kind: registry.AccountKindMSP, DisplayName: "A1"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: account2, Kind: registry.AccountKindMSP, DisplayName: "A2"}); err != nil {
t.Fatal(err)
}
tenant, err := provisioner.ProvisionWorkspace(context.Background(), account1, "Client")
if err != nil {
t.Fatal(err)
}
body := `{"display_name":"New Name"}`
req := httptest.NewRequest(http.MethodPatch, "/api/accounts/"+account2+"/tenants/"+tenant.ID, bytes.NewBufferString(body))
rec := doRequest(t, mux, req)
if rec.Code != http.StatusNotFound && rec.Code != http.StatusForbidden {
t.Fatalf("status = %d, want 404/403 (body=%q)", rec.Code, rec.Body.String())
}
}

View File

@@ -0,0 +1,165 @@
package handoff
import (
"html/template"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
const (
handoffKeyRelPath = "secrets/handoff.key"
)
var handoffHTMLTemplate = template.Must(template.New("handoff").Parse(`<!DOCTYPE html>
<html><body>
<form method="POST" action="https://{{.TenantID}}.{{.BaseDomain}}/api/cloud/handoff/exchange">
<input type="hidden" name="token" value="{{.Token}}" />
</form>
<script>document.forms[0].submit()</script>
</body></html>
`))
type handoffHTMLData struct {
TenantID string
BaseDomain string
Token string
GeneratedAt time.Time
}
// HandleHandoff mints a tenant handoff token and returns an auto-submit HTML page.
// Route (wiring happens elsewhere): POST /api/accounts/{account_id}/tenants/{tenant_id}/handoff
//
// Auth: admin-key for now. User identity is provided by X-User-ID (temporary; session auth replaces this).
func HandleHandoff(reg *registry.TenantRegistry, tenantsDir string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if reg == nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
accountID := strings.TrimSpace(r.PathValue("account_id"))
tenantID := strings.TrimSpace(r.PathValue("tenant_id"))
if accountID == "" || tenantID == "" {
http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
t, err := reg.Get(tenantID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if t == nil || strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID {
http.Error(w, "tenant not found", http.StatusNotFound)
return
}
// Temporary request-scoped identity until control plane session auth exists.
userID := strings.TrimSpace(r.Header.Get("X-User-ID"))
if userID == "" {
userID = strings.TrimSpace(r.Header.Get("X-User-Id"))
}
if userID == "" {
http.Error(w, "missing user identity", http.StatusBadRequest)
return
}
m, err := reg.GetMembership(accountID, userID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if m == nil {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
u, err := reg.GetUser(userID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if u == nil || strings.TrimSpace(u.Email) == "" {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
keyPath := filepath.Join(filepath.Clean(tenantsDir), tenantID, filepath.FromSlash(handoffKeyRelPath))
secret, err := os.ReadFile(keyPath)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
now := time.Now().UTC()
token, err := MintHandoffToken(secret, HandoffClaims{
TenantID: tenantID,
UserID: userID,
AccountID: accountID,
Email: u.Email,
Role: m.Role,
IssuedAt: now,
ExpiresAt: now.Add(defaultTTL),
})
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
baseDomain := deriveBaseDomain(r)
if baseDomain == "" {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_ = handoffHTMLTemplate.Execute(w, handoffHTMLData{
TenantID: tenantID,
BaseDomain: baseDomain,
Token: token,
GeneratedAt: now,
})
}
}
func deriveBaseDomain(r *http.Request) string {
if v := strings.TrimSpace(os.Getenv("CP_BASE_DOMAIN")); v != "" {
return v
}
host := strings.TrimSpace(r.Host)
if host == "" {
return ""
}
// Strip port if present.
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
} else {
// net.SplitHostPort errors if there's no port; that's fine.
if strings.Count(host, ":") > 1 {
// IPv6 without port isn't valid for our use case.
return ""
}
}
return host
}

View File

@@ -0,0 +1,116 @@
package handoff
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
func newTestRegistry(t *testing.T) *registry.TenantRegistry {
t.Helper()
dir := t.TempDir()
reg, err := registry.NewTenantRegistry(dir)
if err != nil {
t.Fatalf("NewTenantRegistry: %v", err)
}
t.Cleanup(func() { _ = reg.Close() })
return reg
}
func TestHandoffHandler(t *testing.T) {
reg := newTestRegistry(t)
tenantsDir := t.TempDir()
accountID := "a_TEST"
tenantID := "t-TEST"
userID := "u_TEST"
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&registry.User{ID: userID, Email: "tech@example.com"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateMembership(&registry.AccountMembership{AccountID: accountID, UserID: userID, Role: registry.MemberRoleTech}); err != nil {
t.Fatal(err)
}
if err := reg.Create(&registry.Tenant{ID: tenantID, AccountID: accountID, DisplayName: "Client", State: registry.TenantStateActive}); err != nil {
t.Fatal(err)
}
secret := []byte("0123456789abcdef0123456789abcdef")
keyPath := filepath.Join(tenantsDir, tenantID, "secrets", "handoff.key")
if err := os.MkdirAll(filepath.Dir(keyPath), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(keyPath, secret, 0o600); err != nil {
t.Fatal(err)
}
mux := http.NewServeMux()
h := HandleHandoff(reg, tenantsDir)
mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}/handoff", admin.AdminKeyMiddleware("secret-key", h))
req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/tenants/"+tenantID+"/handoff", nil)
req.Host = "cloud.example.com"
req.Header.Set("X-Admin-Key", "secret-key")
req.Header.Set("X-User-ID", userID)
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
}
body := rec.Body.String()
wantAction := "https://" + tenantID + ".cloud.example.com/api/cloud/handoff/exchange"
if !regexp.MustCompile(regexp.QuoteMeta(wantAction)).MatchString(body) {
t.Fatalf("missing form action %q in body", wantAction)
}
re := regexp.MustCompile(`name="token" value="([^"]+)"`)
m := re.FindStringSubmatch(body)
if len(m) != 2 {
t.Fatalf("failed to extract token from HTML")
}
tokenStr := m[1]
var got jwtHandoffClaims
parsed, err := jwt.ParseWithClaims(
tokenStr,
&got,
func(t *jwt.Token) (any, error) { return secret, nil },
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
jwt.WithIssuer(issuer),
jwt.WithAudience(tenantID),
)
if err != nil {
t.Fatalf("ParseWithClaims: %v", err)
}
if !parsed.Valid {
t.Fatalf("token valid = false, want true")
}
if got.Subject != userID {
t.Fatalf("sub = %q, want %q", got.Subject, userID)
}
if got.AccountID != accountID {
t.Fatalf("account_id = %q, want %q", got.AccountID, accountID)
}
if got.Email != "tech@example.com" {
t.Fatalf("email = %q, want %q", got.Email, "tech@example.com")
}
if got.Role != registry.MemberRoleTech {
t.Fatalf("role = %q, want %q", got.Role, registry.MemberRoleTech)
}
if got.ExpiresAt == nil || time.Until(got.ExpiresAt.Time) > 60*time.Second+2*time.Second {
t.Fatalf("exp looks wrong: %v", got.ExpiresAt)
}
}

View File

@@ -0,0 +1,131 @@
package handoff
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
const (
issuer = "pulse-cloud-control-plane"
defaultTTL = 60 * time.Second
handoffKeyLen = 32
)
// HandoffClaims are the logical claims that the control plane wants to assert
// when handing off an authenticated account user into a tenant container.
type HandoffClaims struct {
TenantID string
UserID string
AccountID string
Email string
Role registry.MemberRole
IssuedAt time.Time
ExpiresAt time.Time
JTI string
}
type jwtHandoffClaims struct {
AccountID string `json:"account_id"`
Email string `json:"email"`
Role registry.MemberRole `json:"role"`
jwt.RegisteredClaims
}
// GenerateHandoffKey returns 32 cryptographically random bytes suitable for HS256 signing.
// Intended for writing to /data/tenants/<tenant_id>/secrets/handoff.key (0600).
func GenerateHandoffKey() ([]byte, error) {
key := make([]byte, handoffKeyLen)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, fmt.Errorf("generate handoff key: %w", err)
}
return key, nil
}
// MintHandoffToken mints a short-lived HS256 JWT signed with the per-tenant handoff.key.
//
// JWT registered claims:
// - iss: pulse-cloud-control-plane
// - aud: <tenant-id>
// - sub: <user-id>
// - iat, exp, jti
//
// Custom claims:
// - account_id, email, role
func MintHandoffToken(secret []byte, claims HandoffClaims) (string, error) {
if len(secret) == 0 {
return "", fmt.Errorf("secret is required")
}
claims.TenantID = sanitizeID(claims.TenantID)
claims.UserID = sanitizeID(claims.UserID)
claims.AccountID = sanitizeID(claims.AccountID)
if claims.TenantID == "" || claims.UserID == "" || claims.AccountID == "" {
return "", fmt.Errorf("tenantID, userID, and accountID are required")
}
if claims.Email == "" {
return "", fmt.Errorf("email is required")
}
now := time.Now().UTC()
if claims.IssuedAt.IsZero() {
claims.IssuedAt = now
}
claims.IssuedAt = claims.IssuedAt.UTC()
if claims.ExpiresAt.IsZero() {
claims.ExpiresAt = claims.IssuedAt.Add(defaultTTL)
}
claims.ExpiresAt = claims.ExpiresAt.UTC()
if !claims.ExpiresAt.After(claims.IssuedAt) {
return "", fmt.Errorf("expiresAt must be after issuedAt")
}
if claims.JTI == "" {
jti, err := randomJTI128()
if err != nil {
return "", err
}
claims.JTI = jti
}
jc := jwtHandoffClaims{
AccountID: claims.AccountID,
Email: claims.Email,
Role: claims.Role,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: issuer,
Subject: claims.UserID,
Audience: jwt.ClaimStrings{claims.TenantID},
IssuedAt: jwt.NewNumericDate(claims.IssuedAt),
ExpiresAt: jwt.NewNumericDate(claims.ExpiresAt),
ID: claims.JTI,
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jc)
signed, err := tok.SignedString(secret)
if err != nil {
return "", fmt.Errorf("sign jwt: %w", err)
}
return signed, nil
}
func randomJTI128() (string, error) {
b := make([]byte, 16) // 128-bit
if _, err := io.ReadFull(rand.Reader, b); err != nil {
return "", fmt.Errorf("generate jti: %w", err)
}
return hex.EncodeToString(b), nil
}
func sanitizeID(s string) string {
// IDs are generated by the registry; trimming prevents surprising mismatches.
return strings.TrimSpace(s)
}

View File

@@ -0,0 +1,167 @@
package handoff
import (
"errors"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
func TestMintAndVerify(t *testing.T) {
secret := []byte("0123456789abcdef0123456789abcdef") // 32 bytes
now := time.Now().UTC()
tokenStr, err := MintHandoffToken(secret, HandoffClaims{
TenantID: "t-TESTTENANT",
UserID: "u_TESTUSER",
AccountID: "a_TESTACCOUNT",
Email: "tech@example.com",
Role: registry.MemberRoleTech,
IssuedAt: now,
ExpiresAt: now.Add(60 * time.Second),
})
if err != nil {
t.Fatalf("MintHandoffToken: %v", err)
}
var got jwtHandoffClaims
parsed, err := jwt.ParseWithClaims(
tokenStr,
&got,
func(t *jwt.Token) (any, error) { return secret, nil },
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
jwt.WithIssuer(issuer),
jwt.WithAudience("t-TESTTENANT"),
)
if err != nil {
t.Fatalf("ParseWithClaims: %v", err)
}
if !parsed.Valid {
t.Fatalf("token valid = false, want true")
}
if got.Subject != "u_TESTUSER" {
t.Fatalf("sub = %q, want %q", got.Subject, "u_TESTUSER")
}
if got.AccountID != "a_TESTACCOUNT" {
t.Fatalf("account_id = %q, want %q", got.AccountID, "a_TESTACCOUNT")
}
if got.Email != "tech@example.com" {
t.Fatalf("email = %q, want %q", got.Email, "tech@example.com")
}
if got.Role != registry.MemberRoleTech {
t.Fatalf("role = %q, want %q", got.Role, registry.MemberRoleTech)
}
if got.ID == "" {
t.Fatalf("jti empty")
}
if got.IssuedAt == nil || got.ExpiresAt == nil {
t.Fatalf("missing iat/exp")
}
if got.ExpiresAt.Time.Sub(got.IssuedAt.Time) != 60*time.Second {
t.Fatalf("exp-iat = %v, want %v", got.ExpiresAt.Time.Sub(got.IssuedAt.Time), 60*time.Second)
}
}
func TestExpiredToken(t *testing.T) {
secret := []byte("0123456789abcdef0123456789abcdef")
past := time.Now().UTC().Add(-2 * time.Minute)
tokenStr, err := MintHandoffToken(secret, HandoffClaims{
TenantID: "t-EXPIRED",
UserID: "u_USER",
AccountID: "a_ACCOUNT",
Email: "x@example.com",
Role: registry.MemberRoleReadOnly,
IssuedAt: past,
ExpiresAt: past.Add(30 * time.Second),
JTI: "deadbeefdeadbeefdeadbeefdeadbeef",
})
if err != nil {
t.Fatalf("MintHandoffToken: %v", err)
}
var got jwtHandoffClaims
_, err = jwt.ParseWithClaims(
tokenStr,
&got,
func(t *jwt.Token) (any, error) { return secret, nil },
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
jwt.WithIssuer(issuer),
jwt.WithAudience("t-EXPIRED"),
)
if err == nil {
t.Fatalf("expected expired token error")
}
if !errors.Is(err, jwt.ErrTokenExpired) {
// Leave room for wrapped validation errors.
t.Fatalf("error = %v, want ErrTokenExpired", err)
}
}
func TestWrongSecret(t *testing.T) {
secretA := []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
secretB := []byte("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
now := time.Now().UTC()
tokenStr, err := MintHandoffToken(secretA, HandoffClaims{
TenantID: "t-TENANT",
UserID: "u_USER",
AccountID: "a_ACCOUNT",
Email: "x@example.com",
Role: registry.MemberRoleAdmin,
IssuedAt: now,
ExpiresAt: now.Add(60 * time.Second),
JTI: "0123456789abcdef0123456789abcdef",
})
if err != nil {
t.Fatalf("MintHandoffToken: %v", err)
}
var got jwtHandoffClaims
_, err = jwt.ParseWithClaims(
tokenStr,
&got,
func(t *jwt.Token) (any, error) { return secretB, nil },
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
jwt.WithIssuer(issuer),
jwt.WithAudience("t-TENANT"),
)
if err == nil {
t.Fatalf("expected signature verification failure")
}
}
func TestWrongAudience(t *testing.T) {
secret := []byte("0123456789abcdef0123456789abcdef")
now := time.Now().UTC()
tokenStr, err := MintHandoffToken(secret, HandoffClaims{
TenantID: "t-A",
UserID: "u_USER",
AccountID: "a_ACCOUNT",
Email: "x@example.com",
Role: registry.MemberRoleOwner,
IssuedAt: now,
ExpiresAt: now.Add(60 * time.Second),
JTI: "0123456789abcdef0123456789abcdef",
})
if err != nil {
t.Fatalf("MintHandoffToken: %v", err)
}
var got jwtHandoffClaims
_, err = jwt.ParseWithClaims(
tokenStr,
&got,
func(t *jwt.Token) (any, error) { return secret, nil },
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
jwt.WithIssuer(issuer),
jwt.WithAudience("t-B"),
)
if err == nil {
t.Fatalf("expected wrong audience error")
}
}

View File

@@ -0,0 +1,202 @@
package portal
import (
"encoding/json"
"net/http"
"strings"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
type accountInfo struct {
ID string `json:"id"`
DisplayName string `json:"display_name"`
Kind registry.AccountKind `json:"kind"`
}
type workspaceSummaryItem struct {
ID string `json:"id"`
DisplayName string `json:"display_name"`
State registry.TenantState `json:"state"`
HealthCheckOK bool `json:"health_check_ok"`
LastHealthCheck *time.Time `json:"last_health_check"`
CreatedAt time.Time `json:"created_at"`
}
type dashboardSummary struct {
Total int `json:"total"`
Active int `json:"active"`
Healthy int `json:"healthy"`
Unhealthy int `json:"unhealthy"`
Suspended int `json:"suspended"`
}
type dashboardResponse struct {
Account accountInfo `json:"account"`
Workspaces []workspaceSummaryItem `json:"workspaces"`
Summary dashboardSummary `json:"summary"`
}
type workspaceDetailResponse struct {
Account accountInfo `json:"account"`
Workspace *registry.Tenant `json:"workspace"`
}
func accountIDFromRequest(r *http.Request) string {
if r == nil {
return ""
}
if v := strings.TrimSpace(r.URL.Query().Get("account_id")); v != "" {
return v
}
// Convenience for future callers; spec says query param is fine for now.
if v := strings.TrimSpace(r.Header.Get("X-Account-ID")); v != "" {
return v
}
if v := strings.TrimSpace(r.Header.Get("X-Account-Id")); v != "" {
return v
}
return ""
}
// HandlePortalDashboard returns a portal-oriented dashboard response for an account.
// Route: GET /api/portal/dashboard?account_id=...
//
// Auth: admin-key for now (session auth in M-4).
func HandlePortalDashboard(reg *registry.TenantRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if reg == nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
accountID := accountIDFromRequest(r)
if accountID == "" {
http.Error(w, "missing account_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
tenants, err := reg.ListByAccountID(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if tenants == nil {
tenants = []*registry.Tenant{}
}
resp := dashboardResponse{
Account: accountInfo{
ID: a.ID,
DisplayName: a.DisplayName,
Kind: a.Kind,
},
Workspaces: make([]workspaceSummaryItem, 0, len(tenants)),
}
for _, t := range tenants {
if t == nil {
continue
}
resp.Workspaces = append(resp.Workspaces, workspaceSummaryItem{
ID: t.ID,
DisplayName: t.DisplayName,
State: t.State,
HealthCheckOK: t.HealthCheckOK,
LastHealthCheck: t.LastHealthCheck,
CreatedAt: t.CreatedAt,
})
resp.Summary.Total++
switch t.State {
case registry.TenantStateActive:
resp.Summary.Active++
if t.HealthCheckOK {
resp.Summary.Healthy++
} else {
resp.Summary.Unhealthy++
}
case registry.TenantStateSuspended:
resp.Summary.Suspended++
}
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}
}
// HandlePortalWorkspaceDetail returns a portal-oriented detail response for a single tenant.
// Route: GET /api/portal/workspaces/{tenant_id}?account_id=...
//
// Auth: admin-key for now (session auth in M-4).
func HandlePortalWorkspaceDetail(reg *registry.TenantRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if reg == nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
accountID := accountIDFromRequest(r)
tenantID := strings.TrimSpace(r.PathValue("tenant_id"))
if accountID == "" || tenantID == "" {
http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest)
return
}
a, err := reg.GetAccount(accountID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if a == nil {
http.Error(w, "account not found", http.StatusNotFound)
return
}
t, err := reg.Get(tenantID)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if t == nil || strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID {
http.Error(w, "tenant not found", http.StatusNotFound)
return
}
resp := workspaceDetailResponse{
Account: accountInfo{
ID: a.ID,
DisplayName: a.DisplayName,
Kind: a.Kind,
},
Workspace: t,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}
}

View File

@@ -0,0 +1,310 @@
package portal
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sort"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
)
func newTestRegistry(t *testing.T) *registry.TenantRegistry {
t.Helper()
dir := t.TempDir()
reg, err := registry.NewTenantRegistry(dir)
if err != nil {
t.Fatalf("NewTenantRegistry: %v", err)
}
t.Cleanup(func() { _ = reg.Close() })
return reg
}
func newTestMux(reg *registry.TenantRegistry) *http.ServeMux {
mux := http.NewServeMux()
mux.Handle("/api/portal/dashboard", admin.AdminKeyMiddleware("secret-key", HandlePortalDashboard(reg)))
mux.Handle("/api/portal/workspaces/{tenant_id}", admin.AdminKeyMiddleware("secret-key", HandlePortalWorkspaceDetail(reg)))
return mux
}
func doRequest(t *testing.T, h http.Handler, req *http.Request) *httptest.ResponseRecorder {
t.Helper()
req.Header.Set("X-Admin-Key", "secret-key")
rec := httptest.NewRecorder()
h.ServeHTTP(rec, req)
return rec
}
type dashboardResp struct {
Account struct {
ID string `json:"id"`
DisplayName string `json:"display_name"`
Kind registry.AccountKind `json:"kind"`
} `json:"account"`
Workspaces []struct {
ID string `json:"id"`
DisplayName string `json:"display_name"`
State registry.TenantState `json:"state"`
HealthCheckOK bool `json:"health_check_ok"`
LastHealthCheck *time.Time `json:"last_health_check"`
CreatedAt time.Time `json:"created_at"`
} `json:"workspaces"`
Summary struct {
Total int `json:"total"`
Active int `json:"active"`
Healthy int `json:"healthy"`
Unhealthy int `json:"unhealthy"`
Suspended int `json:"suspended"`
} `json:"summary"`
}
func TestPortalDashboard(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Example MSP"}); err != nil {
t.Fatal(err)
}
tenantActiveID, err := registry.GenerateTenantID()
if err != nil {
t.Fatal(err)
}
tenantSuspendedID, err := registry.GenerateTenantID()
if err != nil {
t.Fatal(err)
}
created1 := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC)
created2 := time.Date(2026, 2, 10, 11, 0, 0, 0, time.UTC)
lastCheck := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC)
if err := reg.Create(&registry.Tenant{
ID: tenantActiveID,
AccountID: accountID,
DisplayName: "Acme Dental",
State: registry.TenantStateActive,
CreatedAt: created1,
LastHealthCheck: &lastCheck,
HealthCheckOK: true,
}); err != nil {
t.Fatal(err)
}
if err := reg.Create(&registry.Tenant{
ID: tenantSuspendedID,
AccountID: accountID,
DisplayName: "Suspended Workspace",
State: registry.TenantStateSuspended,
CreatedAt: created2,
HealthCheckOK: false,
}); err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodGet, "/api/portal/dashboard?account_id="+accountID, nil)
rec := doRequest(t, mux, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
}
var resp dashboardResp
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("decode: %v (body=%q)", err, rec.Body.String())
}
if resp.Account.ID != accountID {
t.Fatalf("account.id = %q, want %q", resp.Account.ID, accountID)
}
if resp.Account.DisplayName != "Example MSP" {
t.Fatalf("account.display_name = %q, want %q", resp.Account.DisplayName, "Example MSP")
}
if resp.Account.Kind != registry.AccountKindMSP {
t.Fatalf("account.kind = %q, want %q", resp.Account.Kind, registry.AccountKindMSP)
}
if len(resp.Workspaces) != 2 {
t.Fatalf("workspaces len = %d, want %d", len(resp.Workspaces), 2)
}
// Make assertions order-independent.
sort.Slice(resp.Workspaces, func(i, j int) bool { return resp.Workspaces[i].ID < resp.Workspaces[j].ID })
wsByID := map[string]dashboardRespWorkspace{}
// local helper type for easier indexing
for _, ws := range resp.Workspaces {
wsByID[ws.ID] = dashboardRespWorkspace{
ID: ws.ID,
DisplayName: ws.DisplayName,
State: ws.State,
HealthCheckOK: ws.HealthCheckOK,
LastHealthCheck: ws.LastHealthCheck,
CreatedAt: ws.CreatedAt,
}
}
active := wsByID[tenantActiveID]
if active.ID == "" {
t.Fatalf("missing active workspace id %q", tenantActiveID)
}
if active.DisplayName != "Acme Dental" {
t.Fatalf("active.display_name = %q, want %q", active.DisplayName, "Acme Dental")
}
if active.State != registry.TenantStateActive {
t.Fatalf("active.state = %q, want %q", active.State, registry.TenantStateActive)
}
if !active.HealthCheckOK {
t.Fatalf("active.health_check_ok = false, want true")
}
if active.LastHealthCheck == nil || !active.LastHealthCheck.Equal(lastCheck) {
t.Fatalf("active.last_health_check = %v, want %v", active.LastHealthCheck, lastCheck)
}
if !active.CreatedAt.Equal(created1) {
t.Fatalf("active.created_at = %v, want %v", active.CreatedAt, created1)
}
susp := wsByID[tenantSuspendedID]
if susp.ID == "" {
t.Fatalf("missing suspended workspace id %q", tenantSuspendedID)
}
if susp.State != registry.TenantStateSuspended {
t.Fatalf("suspended.state = %q, want %q", susp.State, registry.TenantStateSuspended)
}
if resp.Summary.Total != 2 {
t.Fatalf("summary.total = %d, want %d", resp.Summary.Total, 2)
}
if resp.Summary.Active != 1 {
t.Fatalf("summary.active = %d, want %d", resp.Summary.Active, 1)
}
if resp.Summary.Healthy != 1 {
t.Fatalf("summary.healthy = %d, want %d", resp.Summary.Healthy, 1)
}
if resp.Summary.Unhealthy != 0 {
t.Fatalf("summary.unhealthy = %d, want %d", resp.Summary.Unhealthy, 0)
}
if resp.Summary.Suspended != 1 {
t.Fatalf("summary.suspended = %d, want %d", resp.Summary.Suspended, 1)
}
}
type dashboardRespWorkspace struct {
ID string
DisplayName string
State registry.TenantState
HealthCheckOK bool
LastHealthCheck *time.Time
CreatedAt time.Time
}
func TestPortalDashboardEmpty(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Empty MSP"}); err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodGet, "/api/portal/dashboard?account_id="+accountID, nil)
rec := doRequest(t, mux, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
}
var resp dashboardResp
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("decode: %v (body=%q)", err, rec.Body.String())
}
if len(resp.Workspaces) != 0 {
t.Fatalf("workspaces len = %d, want %d", len(resp.Workspaces), 0)
}
if resp.Summary.Total != 0 || resp.Summary.Active != 0 || resp.Summary.Healthy != 0 || resp.Summary.Unhealthy != 0 || resp.Summary.Suspended != 0 {
t.Fatalf("summary = %+v, want all zeros", resp.Summary)
}
}
func TestPortalWorkspaceDetail(t *testing.T) {
reg := newTestRegistry(t)
mux := newTestMux(reg)
accountID, err := registry.GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&registry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Example MSP"}); err != nil {
t.Fatal(err)
}
tenantID, err := registry.GenerateTenantID()
if err != nil {
t.Fatal(err)
}
created := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC)
lastCheck := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC)
if err := reg.Create(&registry.Tenant{
ID: tenantID,
AccountID: accountID,
DisplayName: "Acme Dental",
State: registry.TenantStateActive,
CreatedAt: created,
LastHealthCheck: &lastCheck,
HealthCheckOK: true,
}); err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodGet, "/api/portal/workspaces/"+tenantID+"?account_id="+accountID, nil)
rec := doRequest(t, mux, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
}
var resp struct {
Account struct {
ID string `json:"id"`
DisplayName string `json:"display_name"`
Kind registry.AccountKind `json:"kind"`
} `json:"account"`
Workspace registry.Tenant `json:"workspace"`
}
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("decode: %v (body=%q)", err, rec.Body.String())
}
if resp.Account.ID != accountID {
t.Fatalf("account.id = %q, want %q", resp.Account.ID, accountID)
}
if resp.Workspace.ID != tenantID {
t.Fatalf("workspace.id = %q, want %q", resp.Workspace.ID, tenantID)
}
if resp.Workspace.AccountID != accountID {
t.Fatalf("workspace.account_id = %q, want %q", resp.Workspace.AccountID, accountID)
}
if resp.Workspace.DisplayName != "Acme Dental" {
t.Fatalf("workspace.display_name = %q, want %q", resp.Workspace.DisplayName, "Acme Dental")
}
if resp.Workspace.State != registry.TenantStateActive {
t.Fatalf("workspace.state = %q, want %q", resp.Workspace.State, registry.TenantStateActive)
}
if !resp.Workspace.HealthCheckOK {
t.Fatalf("workspace.health_check_ok = false, want true")
}
if resp.Workspace.LastHealthCheck == nil || !resp.Workspace.LastHealthCheck.Equal(lastCheck) {
t.Fatalf("workspace.last_health_check = %v, want %v", resp.Workspace.LastHealthCheck, lastCheck)
}
if !resp.Workspace.CreatedAt.Equal(created) {
t.Fatalf("workspace.created_at = %v, want %v", resp.Workspace.CreatedAt, created)
}
}

View File

@@ -7,6 +7,47 @@ import (
"time"
)
type AccountKind string
const (
AccountKindIndividual AccountKind = "individual"
AccountKindMSP AccountKind = "msp"
)
// Account represents an account record in the control plane registry.
type Account struct {
ID string `json:"id"`
Kind AccountKind `json:"kind"`
DisplayName string `json:"display_name"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// User represents a user record in the control plane registry.
type User struct {
ID string `json:"id"`
Email string `json:"email"`
CreatedAt time.Time `json:"created_at"`
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
}
type MemberRole string
const (
MemberRoleOwner MemberRole = "owner"
MemberRoleAdmin MemberRole = "admin"
MemberRoleTech MemberRole = "tech"
MemberRoleReadOnly MemberRole = "read_only"
)
// AccountMembership represents a mapping between an account and a user.
type AccountMembership struct {
AccountID string `json:"account_id"`
UserID string `json:"user_id"`
Role MemberRole `json:"role"`
CreatedAt time.Time `json:"created_at"`
}
// TenantState represents the lifecycle state of a tenant.
type TenantState string
@@ -15,12 +56,14 @@ const (
TenantStateActive TenantState = "active"
TenantStateSuspended TenantState = "suspended"
TenantStateCanceled TenantState = "canceled"
TenantStateDeleting TenantState = "deleting"
TenantStateDeleted TenantState = "deleted"
)
// Tenant represents a Cloud tenant record in the registry.
type Tenant struct {
ID string `json:"id"`
AccountID string `json:"account_id"`
Email string `json:"email"`
DisplayName string `json:"display_name"`
State TenantState `json:"state"`
@@ -37,6 +80,20 @@ type Tenant struct {
HealthCheckOK bool `json:"health_check_ok"`
}
// StripeAccount maps a control-plane account to a single Stripe customer +
// subscription for consolidated (MSP-style) billing.
type StripeAccount struct {
AccountID string `json:"account_id"`
StripeCustomerID string `json:"stripe_customer_id"`
StripeSubscriptionID string `json:"stripe_subscription_id"`
StripeSubItemWorkspacesID string `json:"stripe_sub_item_workspaces_id"`
PlanVersion string `json:"plan_version"`
SubscriptionState string `json:"subscription_state"` // trial, active, past_due, canceled
TrialEndsAt *int64 `json:"trial_ends_at"`
CurrentPeriodEnd *int64 `json:"current_period_end"`
UpdatedAt int64 `json:"updated_at"`
}
// crockfordBase32 is the Crockford base32 alphabet (excludes I, L, O, U).
const crockfordBase32 = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
@@ -54,3 +111,33 @@ func GenerateTenantID() (string, error) {
}
return sb.String(), nil
}
// GenerateAccountID returns an account ID of the form "a_" followed by 10 random
// Crockford base32 characters (50 bits of entropy).
func GenerateAccountID() (string, error) {
b := make([]byte, 10)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate account id: %w", err)
}
var sb strings.Builder
sb.WriteString("a_")
for _, v := range b {
sb.WriteByte(crockfordBase32[int(v)%len(crockfordBase32)])
}
return sb.String(), nil
}
// GenerateUserID returns a user ID of the form "u_" followed by 10 random
// Crockford base32 characters (50 bits of entropy).
func GenerateUserID() (string, error) {
b := make([]byte, 10)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate user id: %w", err)
}
var sb strings.Builder
sb.WriteString("u_")
for _, v := range b {
sb.WriteByte(crockfordBase32[int(v)%len(crockfordBase32)])
}
return sb.String(), nil
}

View File

@@ -6,6 +6,7 @@ import (
"net/url"
"os"
"path/filepath"
"strings"
"time"
_ "modernc.org/sqlite"
@@ -51,6 +52,7 @@ func (r *TenantRegistry) initSchema() error {
schema := `
CREATE TABLE IF NOT EXISTS tenants (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL DEFAULT '',
email TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT 'provisioning',
@@ -68,13 +70,106 @@ func (r *TenantRegistry) initSchema() error {
);
CREATE INDEX IF NOT EXISTS idx_tenants_state ON tenants(state);
CREATE INDEX IF NOT EXISTS idx_tenants_stripe_customer_id ON tenants(stripe_customer_id);
CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
kind TEXT NOT NULL DEFAULT 'individual',
display_name TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS stripe_accounts (
account_id TEXT PRIMARY KEY,
stripe_customer_id TEXT NOT NULL UNIQUE,
stripe_subscription_id TEXT,
stripe_sub_item_workspaces_id TEXT,
plan_version TEXT NOT NULL DEFAULT '',
subscription_state TEXT NOT NULL DEFAULT 'trial',
trial_ends_at INTEGER,
current_period_end INTEGER,
updated_at INTEGER NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id)
);
CREATE INDEX IF NOT EXISTS idx_stripe_accounts_customer ON stripe_accounts(stripe_customer_id);
CREATE TABLE IF NOT EXISTS stripe_events (
stripe_event_id TEXT PRIMARY KEY,
event_type TEXT NOT NULL,
received_at INTEGER NOT NULL,
processed_at INTEGER,
processing_error TEXT
);
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
email TEXT NOT NULL UNIQUE,
created_at INTEGER NOT NULL,
last_login_at INTEGER
);
CREATE TABLE IF NOT EXISTS account_memberships (
account_id TEXT NOT NULL,
user_id TEXT NOT NULL,
role TEXT NOT NULL DEFAULT 'tech',
created_at INTEGER NOT NULL,
PRIMARY KEY (account_id, user_id),
FOREIGN KEY (account_id) REFERENCES accounts(id),
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS idx_memberships_user_id ON account_memberships(user_id);
`
if _, err := r.db.Exec(schema); err != nil {
return fmt.Errorf("init tenant registry schema: %w", err)
}
// Migration: add account_id to tenants if not present.
// (SQLite makes it awkward to add FK constraints via ALTER TABLE, and FK
// enforcement is off by default; this keeps the change backwards-compatible.)
hasAccountID, err := r.tenantsHasColumn("account_id")
if err != nil {
return err
}
if !hasAccountID {
if _, err := r.db.Exec(`ALTER TABLE tenants ADD COLUMN account_id TEXT NOT NULL DEFAULT ''`); err != nil {
return fmt.Errorf("migrate tenants: add account_id: %w", err)
}
}
if _, err := r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_tenants_account_id ON tenants(account_id)`); err != nil {
return fmt.Errorf("init tenant registry schema: create idx_tenants_account_id: %w", err)
}
return nil
}
func (r *TenantRegistry) tenantsHasColumn(name string) (bool, error) {
rows, err := r.db.Query(`PRAGMA table_info(tenants)`)
if err != nil {
return false, fmt.Errorf("pragma table_info(tenants): %w", err)
}
defer rows.Close()
for rows.Next() {
var (
cid int
colName string
colType string
notNull int
dflt sql.NullString
pk int
)
if err := rows.Scan(&cid, &colName, &colType, &notNull, &dflt, &pk); err != nil {
return false, fmt.Errorf("scan table_info(tenants): %w", err)
}
if colName == name {
return true, nil
}
}
if err := rows.Err(); err != nil {
return false, fmt.Errorf("iterate table_info(tenants): %w", err)
}
return false, nil
}
// Ping checks database connectivity (used for readiness probes).
func (r *TenantRegistry) Ping() error {
return r.db.Ping()
@@ -101,12 +196,12 @@ func (r *TenantRegistry) Create(t *Tenant) error {
_, err := r.db.Exec(`
INSERT INTO tenants (
id, email, display_name, state,
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
t.ID, t.Email, t.DisplayName, string(t.State),
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
t.ID, t.AccountID, t.Email, t.DisplayName, string(t.State),
t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
t.CreatedAt.Unix(), t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
@@ -120,7 +215,7 @@ func (r *TenantRegistry) Create(t *Tenant) error {
// Get retrieves a tenant by ID.
func (r *TenantRegistry) Get(id string) (*Tenant, error) {
row := r.db.QueryRow(`SELECT
id, email, display_name, state,
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
@@ -131,7 +226,7 @@ func (r *TenantRegistry) Get(id string) (*Tenant, error) {
// GetByStripeCustomerID retrieves a tenant by Stripe customer ID.
func (r *TenantRegistry) GetByStripeCustomerID(customerID string) (*Tenant, error) {
row := r.db.QueryRow(`SELECT
id, email, display_name, state,
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
@@ -148,12 +243,12 @@ func (r *TenantRegistry) Update(t *Tenant) error {
res, err := r.db.Exec(`
UPDATE tenants SET
email = ?, display_name = ?, state = ?,
account_id = ?, email = ?, display_name = ?, state = ?,
stripe_customer_id = ?, stripe_subscription_id = ?, stripe_price_id = ?,
plan_version = ?, container_id = ?, current_image_digest = ?, desired_image_digest = ?,
updated_at = ?, last_health_check = ?, health_check_ok = ?
WHERE id = ?`,
t.Email, t.DisplayName, string(t.State),
t.AccountID, t.Email, t.DisplayName, string(t.State),
t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
@@ -172,7 +267,7 @@ func (r *TenantRegistry) Update(t *Tenant) error {
// List returns all tenants.
func (r *TenantRegistry) List() ([]*Tenant, error) {
rows, err := r.db.Query(`SELECT
id, email, display_name, state,
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
@@ -187,7 +282,7 @@ func (r *TenantRegistry) List() ([]*Tenant, error) {
// ListByState returns all tenants matching the given state.
func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) {
rows, err := r.db.Query(`SELECT
id, email, display_name, state,
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
@@ -199,6 +294,21 @@ func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) {
return scanTenants(rows)
}
// ListByAccountID returns all tenants belonging to the given account ID.
func (r *TenantRegistry) ListByAccountID(accountID string) ([]*Tenant, error) {
rows, err := r.db.Query(`SELECT
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
FROM tenants WHERE account_id = ? ORDER BY created_at DESC`, accountID)
if err != nil {
return nil, fmt.Errorf("list tenants by account id: %w", err)
}
defer rows.Close()
return scanTenants(rows)
}
// CountByState returns a map of state -> count.
func (r *TenantRegistry) CountByState() (map[TenantState]int, error) {
rows, err := r.db.Query(`SELECT state, COUNT(*) FROM tenants GROUP BY state`)
@@ -244,7 +354,7 @@ func scanTenant(s scanner) (*Tenant, error) {
var healthOK int
err := s.Scan(
&t.ID, &t.Email, &t.DisplayName, &state,
&t.ID, &t.AccountID, &t.Email, &t.DisplayName, &state,
&t.StripeCustomerID, &t.StripeSubscriptionID, &t.StripePriceID,
&t.PlanVersion, &t.ContainerID, &t.CurrentImageDigest, &t.DesiredImageDigest,
&createdAt, &updatedAt, &lastHealthCheck, &healthOK,
@@ -279,6 +389,474 @@ func scanTenants(rows *sql.Rows) ([]*Tenant, error) {
return tenants, rows.Err()
}
// CreateAccount inserts a new account record.
func (r *TenantRegistry) CreateAccount(a *Account) error {
if a == nil {
return fmt.Errorf("account is nil")
}
now := time.Now().UTC()
if a.CreatedAt.IsZero() {
a.CreatedAt = now
}
a.UpdatedAt = now
kind := string(a.Kind)
if kind == "" {
kind = string(AccountKindIndividual)
}
_, err := r.db.Exec(`
INSERT INTO accounts (
id, kind, display_name, created_at, updated_at
) VALUES (?, ?, ?, ?, ?)`,
a.ID, kind, a.DisplayName, a.CreatedAt.Unix(), a.UpdatedAt.Unix(),
)
if err != nil {
return fmt.Errorf("create account: %w", err)
}
a.Kind = AccountKind(kind)
return nil
}
// GetAccount retrieves an account by ID.
func (r *TenantRegistry) GetAccount(id string) (*Account, error) {
row := r.db.QueryRow(`SELECT
id, kind, display_name, created_at, updated_at
FROM accounts WHERE id = ?`, id)
return scanAccount(row)
}
// UpdateAccount modifies an existing account record.
func (r *TenantRegistry) UpdateAccount(a *Account) error {
if a == nil {
return fmt.Errorf("account is nil")
}
a.UpdatedAt = time.Now().UTC()
kind := string(a.Kind)
if kind == "" {
kind = string(AccountKindIndividual)
}
res, err := r.db.Exec(`
UPDATE accounts SET
kind = ?, display_name = ?, updated_at = ?
WHERE id = ?`,
kind, a.DisplayName, a.UpdatedAt.Unix(),
a.ID,
)
if err != nil {
return fmt.Errorf("update account: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return fmt.Errorf("account %q not found", a.ID)
}
a.Kind = AccountKind(kind)
return nil
}
// ListAccounts returns all accounts.
func (r *TenantRegistry) ListAccounts() ([]*Account, error) {
rows, err := r.db.Query(`SELECT
id, kind, display_name, created_at, updated_at
FROM accounts ORDER BY created_at DESC`)
if err != nil {
return nil, fmt.Errorf("list accounts: %w", err)
}
defer rows.Close()
return scanAccounts(rows)
}
// CreateUser inserts a new user record.
func (r *TenantRegistry) CreateUser(u *User) error {
if u == nil {
return fmt.Errorf("user is nil")
}
now := time.Now().UTC()
if u.CreatedAt.IsZero() {
u.CreatedAt = now
}
_, err := r.db.Exec(`
INSERT INTO users (
id, email, created_at, last_login_at
) VALUES (?, ?, ?, ?)`,
u.ID, u.Email, u.CreatedAt.Unix(), nullableTimeUnix(u.LastLoginAt),
)
if err != nil {
return fmt.Errorf("create user: %w", err)
}
return nil
}
// GetUser retrieves a user by ID.
func (r *TenantRegistry) GetUser(id string) (*User, error) {
row := r.db.QueryRow(`SELECT
id, email, created_at, last_login_at
FROM users WHERE id = ?`, id)
return scanUser(row)
}
// GetUserByEmail retrieves a user by email.
func (r *TenantRegistry) GetUserByEmail(email string) (*User, error) {
row := r.db.QueryRow(`SELECT
id, email, created_at, last_login_at
FROM users WHERE email = ?`, email)
return scanUser(row)
}
// UpdateUserLastLogin sets last_login_at for the given user ID to the current time.
func (r *TenantRegistry) UpdateUserLastLogin(id string) error {
now := time.Now().UTC()
res, err := r.db.Exec(`UPDATE users SET last_login_at = ? WHERE id = ?`, now.Unix(), id)
if err != nil {
return fmt.Errorf("update user last login: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return fmt.Errorf("user %q not found", id)
}
return nil
}
// CreateMembership inserts a new membership record.
func (r *TenantRegistry) CreateMembership(m *AccountMembership) error {
if m == nil {
return fmt.Errorf("membership is nil")
}
now := time.Now().UTC()
if m.CreatedAt.IsZero() {
m.CreatedAt = now
}
role := string(m.Role)
if role == "" {
role = string(MemberRoleTech)
}
_, err := r.db.Exec(`
INSERT INTO account_memberships (
account_id, user_id, role, created_at
) VALUES (?, ?, ?, ?)`,
m.AccountID, m.UserID, role, m.CreatedAt.Unix(),
)
if err != nil {
return fmt.Errorf("create membership: %w", err)
}
m.Role = MemberRole(role)
return nil
}
// GetMembership retrieves a membership record by account ID and user ID.
func (r *TenantRegistry) GetMembership(accountID, userID string) (*AccountMembership, error) {
row := r.db.QueryRow(`SELECT
account_id, user_id, role, created_at
FROM account_memberships
WHERE account_id = ? AND user_id = ?`, accountID, userID)
return scanMembership(row)
}
// ListMembersByAccount returns all membership records for a given account ID.
func (r *TenantRegistry) ListMembersByAccount(accountID string) ([]*AccountMembership, error) {
rows, err := r.db.Query(`SELECT
account_id, user_id, role, created_at
FROM account_memberships
WHERE account_id = ?
ORDER BY created_at DESC`, accountID)
if err != nil {
return nil, fmt.Errorf("list members by account: %w", err)
}
defer rows.Close()
return scanMemberships(rows)
}
// ListAccountsByUser returns account IDs for all accounts the given user belongs to.
func (r *TenantRegistry) ListAccountsByUser(userID string) ([]string, error) {
rows, err := r.db.Query(`SELECT account_id FROM account_memberships WHERE user_id = ? ORDER BY created_at DESC`, userID)
if err != nil {
return nil, fmt.Errorf("list accounts by user: %w", err)
}
defer rows.Close()
var accountIDs []string
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan account id: %w", err)
}
accountIDs = append(accountIDs, id)
}
return accountIDs, rows.Err()
}
// UpdateMembershipRole updates a membership role.
func (r *TenantRegistry) UpdateMembershipRole(accountID, userID string, role MemberRole) error {
res, err := r.db.Exec(`UPDATE account_memberships SET role = ? WHERE account_id = ? AND user_id = ?`, string(role), accountID, userID)
if err != nil {
return fmt.Errorf("update membership role: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return fmt.Errorf("membership (%q, %q) not found", accountID, userID)
}
return nil
}
// DeleteMembership deletes a membership record.
func (r *TenantRegistry) DeleteMembership(accountID, userID string) error {
res, err := r.db.Exec(`DELETE FROM account_memberships WHERE account_id = ? AND user_id = ?`, accountID, userID)
if err != nil {
return fmt.Errorf("delete membership: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return fmt.Errorf("membership (%q, %q) not found", accountID, userID)
}
return nil
}
// CreateStripeAccount inserts a new StripeAccount mapping row.
func (r *TenantRegistry) CreateStripeAccount(sa *StripeAccount) error {
if sa == nil {
return fmt.Errorf("stripe account is nil")
}
sa.AccountID = strings.TrimSpace(sa.AccountID)
sa.StripeCustomerID = strings.TrimSpace(sa.StripeCustomerID)
sa.StripeSubscriptionID = strings.TrimSpace(sa.StripeSubscriptionID)
sa.StripeSubItemWorkspacesID = strings.TrimSpace(sa.StripeSubItemWorkspacesID)
sa.PlanVersion = strings.TrimSpace(sa.PlanVersion)
sa.SubscriptionState = strings.TrimSpace(sa.SubscriptionState)
if sa.AccountID == "" {
return fmt.Errorf("missing account id")
}
if sa.StripeCustomerID == "" {
return fmt.Errorf("missing stripe customer id")
}
if sa.SubscriptionState == "" {
sa.SubscriptionState = "trial"
}
if sa.UpdatedAt == 0 {
sa.UpdatedAt = time.Now().UTC().Unix()
}
_, err := r.db.Exec(`
INSERT INTO stripe_accounts (
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
plan_version, subscription_state, trial_ends_at, current_period_end, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
sa.AccountID,
sa.StripeCustomerID,
nullableString(sa.StripeSubscriptionID),
nullableString(sa.StripeSubItemWorkspacesID),
sa.PlanVersion,
sa.SubscriptionState,
nullableInt64Ptr(sa.TrialEndsAt),
nullableInt64Ptr(sa.CurrentPeriodEnd),
sa.UpdatedAt,
)
if err != nil {
return fmt.Errorf("create stripe account: %w", err)
}
return nil
}
// GetStripeAccount retrieves the StripeAccount row by account ID.
func (r *TenantRegistry) GetStripeAccount(accountID string) (*StripeAccount, error) {
row := r.db.QueryRow(`SELECT
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
plan_version, subscription_state, trial_ends_at, current_period_end, updated_at
FROM stripe_accounts WHERE account_id = ?`, strings.TrimSpace(accountID))
return scanStripeAccount(row)
}
// GetStripeAccountByCustomerID retrieves the StripeAccount row by Stripe customer ID.
func (r *TenantRegistry) GetStripeAccountByCustomerID(customerID string) (*StripeAccount, error) {
row := r.db.QueryRow(`SELECT
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
plan_version, subscription_state, trial_ends_at, current_period_end, updated_at
FROM stripe_accounts WHERE stripe_customer_id = ?`, strings.TrimSpace(customerID))
return scanStripeAccount(row)
}
// UpdateStripeAccount modifies an existing StripeAccount row.
func (r *TenantRegistry) UpdateStripeAccount(sa *StripeAccount) error {
if sa == nil {
return fmt.Errorf("stripe account is nil")
}
sa.AccountID = strings.TrimSpace(sa.AccountID)
sa.StripeCustomerID = strings.TrimSpace(sa.StripeCustomerID)
sa.StripeSubscriptionID = strings.TrimSpace(sa.StripeSubscriptionID)
sa.StripeSubItemWorkspacesID = strings.TrimSpace(sa.StripeSubItemWorkspacesID)
sa.PlanVersion = strings.TrimSpace(sa.PlanVersion)
sa.SubscriptionState = strings.TrimSpace(sa.SubscriptionState)
if sa.AccountID == "" {
return fmt.Errorf("missing account id")
}
if sa.StripeCustomerID == "" {
return fmt.Errorf("missing stripe customer id")
}
if sa.SubscriptionState == "" {
sa.SubscriptionState = "trial"
}
sa.UpdatedAt = time.Now().UTC().Unix()
res, err := r.db.Exec(`
UPDATE stripe_accounts SET
stripe_customer_id = ?, stripe_subscription_id = ?, stripe_sub_item_workspaces_id = ?,
plan_version = ?, subscription_state = ?, trial_ends_at = ?, current_period_end = ?, updated_at = ?
WHERE account_id = ?`,
sa.StripeCustomerID,
nullableString(sa.StripeSubscriptionID),
nullableString(sa.StripeSubItemWorkspacesID),
sa.PlanVersion,
sa.SubscriptionState,
nullableInt64Ptr(sa.TrialEndsAt),
nullableInt64Ptr(sa.CurrentPeriodEnd),
sa.UpdatedAt,
sa.AccountID,
)
if err != nil {
return fmt.Errorf("update stripe account: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return fmt.Errorf("stripe account %q not found", sa.AccountID)
}
return nil
}
// RecordStripeEvent inserts a webhook event record and returns true if the
// event was already recorded (duplicate Stripe delivery).
func (r *TenantRegistry) RecordStripeEvent(eventID, eventType string) (alreadyProcessed bool, err error) {
eventID = strings.TrimSpace(eventID)
eventType = strings.TrimSpace(eventType)
if eventID == "" {
return false, fmt.Errorf("missing stripe event id")
}
if eventType == "" {
return false, fmt.Errorf("missing stripe event type")
}
// INSERT OR IGNORE avoids driver-specific error parsing for duplicates.
res, err := r.db.Exec(`
INSERT OR IGNORE INTO stripe_events (
stripe_event_id, event_type, received_at, processed_at, processing_error
) VALUES (?, ?, ?, NULL, NULL)`,
eventID, eventType, time.Now().UTC().Unix(),
)
if err != nil {
return false, fmt.Errorf("record stripe event: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return true, nil
}
return false, nil
}
// MarkStripeEventProcessed marks a previously recorded event as processed.
// processingError is stored (nullable) for troubleshooting.
func (r *TenantRegistry) MarkStripeEventProcessed(eventID string, processingError string) error {
eventID = strings.TrimSpace(eventID)
if eventID == "" {
return fmt.Errorf("missing stripe event id")
}
processingError = strings.TrimSpace(processingError)
res, err := r.db.Exec(`
UPDATE stripe_events SET
processed_at = ?, processing_error = ?
WHERE stripe_event_id = ?`,
time.Now().UTC().Unix(),
nullableString(processingError),
eventID,
)
if err != nil {
return fmt.Errorf("mark stripe event processed: %w", err)
}
affected, _ := res.RowsAffected()
if affected == 0 {
return fmt.Errorf("stripe event %q not found", eventID)
}
return nil
}
func scanAccount(s scanner) (*Account, error) {
var a Account
var kind string
var createdAt, updatedAt int64
if err := s.Scan(&a.ID, &kind, &a.DisplayName, &createdAt, &updatedAt); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("scan account: %w", err)
}
a.Kind = AccountKind(kind)
a.CreatedAt = time.Unix(createdAt, 0).UTC()
a.UpdatedAt = time.Unix(updatedAt, 0).UTC()
return &a, nil
}
func scanAccounts(rows *sql.Rows) ([]*Account, error) {
var accounts []*Account
for rows.Next() {
a, err := scanAccount(rows)
if err != nil {
return nil, err
}
accounts = append(accounts, a)
}
return accounts, rows.Err()
}
func scanUser(s scanner) (*User, error) {
var u User
var createdAt int64
var lastLogin sql.NullInt64
if err := s.Scan(&u.ID, &u.Email, &createdAt, &lastLogin); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("scan user: %w", err)
}
u.CreatedAt = time.Unix(createdAt, 0).UTC()
if lastLogin.Valid {
ts := time.Unix(lastLogin.Int64, 0).UTC()
u.LastLoginAt = &ts
}
return &u, nil
}
func scanMembership(s scanner) (*AccountMembership, error) {
var m AccountMembership
var role string
var createdAt int64
if err := s.Scan(&m.AccountID, &m.UserID, &role, &createdAt); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("scan membership: %w", err)
}
m.Role = MemberRole(role)
m.CreatedAt = time.Unix(createdAt, 0).UTC()
return &m, nil
}
func scanMemberships(rows *sql.Rows) ([]*AccountMembership, error) {
var memberships []*AccountMembership
for rows.Next() {
m, err := scanMembership(rows)
if err != nil {
return nil, err
}
memberships = append(memberships, m)
}
return memberships, rows.Err()
}
func nullableTimeUnix(t *time.Time) any {
if t == nil {
return nil
@@ -286,9 +864,60 @@ func nullableTimeUnix(t *time.Time) any {
return t.Unix()
}
func nullableInt64Ptr(v *int64) any {
if v == nil {
return nil
}
return *v
}
func nullableString(s string) any {
if strings.TrimSpace(s) == "" {
return nil
}
return strings.TrimSpace(s)
}
func boolToInt(b bool) int {
if b {
return 1
}
return 0
}
func scanStripeAccount(s scanner) (*StripeAccount, error) {
var sa StripeAccount
var subID, subItemID sql.NullString
var trialEnds, periodEnd sql.NullInt64
if err := s.Scan(
&sa.AccountID,
&sa.StripeCustomerID,
&subID,
&subItemID,
&sa.PlanVersion,
&sa.SubscriptionState,
&trialEnds,
&periodEnd,
&sa.UpdatedAt,
); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("scan stripe account: %w", err)
}
if subID.Valid {
sa.StripeSubscriptionID = subID.String
}
if subItemID.Valid {
sa.StripeSubItemWorkspacesID = subItemID.String
}
if trialEnds.Valid {
v := trialEnds.Int64
sa.TrialEndsAt = &v
}
if periodEnd.Valid {
v := periodEnd.Int64
sa.CurrentPeriodEnd = &v
}
return &sa, nil
}

View File

@@ -59,6 +59,32 @@ func TestGenerateTenantID_CrockfordCharset(t *testing.T) {
}
}
func TestGenerateAccountID(t *testing.T) {
id, err := GenerateAccountID()
if err != nil {
t.Fatalf("GenerateAccountID: %v", err)
}
if !strings.HasPrefix(id, "a_") {
t.Errorf("expected prefix a_, got %q", id)
}
if len(id) != 12 { // "a_" + 10 chars
t.Errorf("expected length 12, got %d (%q)", len(id), id)
}
}
func TestGenerateUserID(t *testing.T) {
id, err := GenerateUserID()
if err != nil {
t.Fatalf("GenerateUserID: %v", err)
}
if !strings.HasPrefix(id, "u_") {
t.Errorf("expected prefix u_, got %q", id)
}
if len(id) != 12 { // "u_" + 10 chars
t.Errorf("expected length 12, got %d (%q)", len(id), id)
}
}
func TestCRUD(t *testing.T) {
reg := newTestRegistry(t)
@@ -137,6 +163,226 @@ func TestCRUD(t *testing.T) {
}
}
func TestAccountCRUD(t *testing.T) {
reg := newTestRegistry(t)
accountID, err := GenerateAccountID()
if err != nil {
t.Fatal(err)
}
a := &Account{
ID: accountID,
Kind: AccountKindMSP,
DisplayName: "Test MSP",
}
// Create
if err := reg.CreateAccount(a); err != nil {
t.Fatalf("CreateAccount: %v", err)
}
if a.CreatedAt.IsZero() {
t.Error("CreatedAt should be set")
}
if a.UpdatedAt.IsZero() {
t.Error("UpdatedAt should be set")
}
// Get
got, err := reg.GetAccount(accountID)
if err != nil {
t.Fatalf("GetAccount: %v", err)
}
if got == nil {
t.Fatal("GetAccount returned nil")
}
if got.Kind != AccountKindMSP {
t.Errorf("Kind = %q, want %q", got.Kind, AccountKindMSP)
}
if got.DisplayName != "Test MSP" {
t.Errorf("DisplayName = %q, want %q", got.DisplayName, "Test MSP")
}
// Update
got.DisplayName = "Renamed MSP"
if err := reg.UpdateAccount(got); err != nil {
t.Fatalf("UpdateAccount: %v", err)
}
got2, err := reg.GetAccount(accountID)
if err != nil {
t.Fatalf("GetAccount after update: %v", err)
}
if got2.DisplayName != "Renamed MSP" {
t.Errorf("DisplayName after update = %q, want %q", got2.DisplayName, "Renamed MSP")
}
// List
accounts, err := reg.ListAccounts()
if err != nil {
t.Fatalf("ListAccounts: %v", err)
}
if len(accounts) != 1 {
t.Fatalf("expected 1 account, got %d", len(accounts))
}
if accounts[0].ID != accountID {
t.Errorf("accounts[0].ID = %q, want %q", accounts[0].ID, accountID)
}
}
func TestUserCRUD(t *testing.T) {
reg := newTestRegistry(t)
userID, err := GenerateUserID()
if err != nil {
t.Fatal(err)
}
u := &User{
ID: userID,
Email: "user@example.com",
}
// Create
if err := reg.CreateUser(u); err != nil {
t.Fatalf("CreateUser: %v", err)
}
if u.CreatedAt.IsZero() {
t.Error("CreatedAt should be set")
}
// Get by ID
got, err := reg.GetUser(userID)
if err != nil {
t.Fatalf("GetUser: %v", err)
}
if got == nil {
t.Fatal("GetUser returned nil")
}
if got.Email != "user@example.com" {
t.Errorf("Email = %q, want %q", got.Email, "user@example.com")
}
if got.LastLoginAt != nil {
t.Errorf("LastLoginAt = %v, want nil", got.LastLoginAt)
}
// Get by email
got2, err := reg.GetUserByEmail("user@example.com")
if err != nil {
t.Fatalf("GetUserByEmail: %v", err)
}
if got2 == nil || got2.ID != userID {
t.Fatalf("GetUserByEmail returned %+v, want id=%q", got2, userID)
}
// Update last login
before := time.Now().UTC()
if err := reg.UpdateUserLastLogin(userID); err != nil {
t.Fatalf("UpdateUserLastLogin: %v", err)
}
after := time.Now().UTC()
got3, err := reg.GetUser(userID)
if err != nil {
t.Fatalf("GetUser after last login update: %v", err)
}
if got3.LastLoginAt == nil {
t.Fatal("LastLoginAt should be set")
}
ll := *got3.LastLoginAt
if ll.Before(before.Add(-2*time.Second)) || ll.After(after.Add(2*time.Second)) {
t.Fatalf("LastLoginAt=%s out of expected range [%s, %s]", ll, before, after)
}
}
func TestMembershipCRUD(t *testing.T) {
reg := newTestRegistry(t)
accountID, err := GenerateAccountID()
if err != nil {
t.Fatal(err)
}
userID, err := GenerateUserID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&Account{ID: accountID, Kind: AccountKindMSP, DisplayName: "Account"}); err != nil {
t.Fatal(err)
}
if err := reg.CreateUser(&User{ID: userID, Email: "member@example.com"}); err != nil {
t.Fatal(err)
}
m := &AccountMembership{
AccountID: accountID,
UserID: userID,
Role: MemberRoleOwner,
}
// Create
if err := reg.CreateMembership(m); err != nil {
t.Fatalf("CreateMembership: %v", err)
}
if m.CreatedAt.IsZero() {
t.Error("CreatedAt should be set")
}
// Get
got, err := reg.GetMembership(accountID, userID)
if err != nil {
t.Fatalf("GetMembership: %v", err)
}
if got == nil {
t.Fatal("GetMembership returned nil")
}
if got.Role != MemberRoleOwner {
t.Errorf("Role = %q, want %q", got.Role, MemberRoleOwner)
}
// List by account
members, err := reg.ListMembersByAccount(accountID)
if err != nil {
t.Fatalf("ListMembersByAccount: %v", err)
}
if len(members) != 1 {
t.Fatalf("expected 1 member, got %d", len(members))
}
if members[0].UserID != userID {
t.Errorf("members[0].UserID = %q, want %q", members[0].UserID, userID)
}
// List accounts by user
accounts, err := reg.ListAccountsByUser(userID)
if err != nil {
t.Fatalf("ListAccountsByUser: %v", err)
}
if len(accounts) != 1 || accounts[0] != accountID {
t.Fatalf("accounts=%v, want [%q]", accounts, accountID)
}
// Update role
if err := reg.UpdateMembershipRole(accountID, userID, MemberRoleAdmin); err != nil {
t.Fatalf("UpdateMembershipRole: %v", err)
}
got2, err := reg.GetMembership(accountID, userID)
if err != nil {
t.Fatalf("GetMembership after role update: %v", err)
}
if got2.Role != MemberRoleAdmin {
t.Errorf("Role after update = %q, want %q", got2.Role, MemberRoleAdmin)
}
// Delete
if err := reg.DeleteMembership(accountID, userID); err != nil {
t.Fatalf("DeleteMembership: %v", err)
}
got3, err := reg.GetMembership(accountID, userID)
if err != nil {
t.Fatalf("GetMembership after delete: %v", err)
}
if got3 != nil {
t.Fatalf("expected nil membership after delete, got %+v", got3)
}
}
func TestList(t *testing.T) {
reg := newTestRegistry(t)
@@ -199,6 +445,46 @@ func TestListByState(t *testing.T) {
}
}
func TestListByAccountID(t *testing.T) {
reg := newTestRegistry(t)
accountID, err := GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&Account{ID: accountID, Kind: AccountKindMSP, DisplayName: "Account"}); err != nil {
t.Fatal(err)
}
if err := reg.Create(&Tenant{ID: "t-ACCNT0001", AccountID: accountID, State: TenantStateActive}); err != nil {
t.Fatal(err)
}
if err := reg.Create(&Tenant{ID: "t-ACCNT0002", AccountID: accountID, State: TenantStateActive}); err != nil {
t.Fatal(err)
}
if err := reg.Create(&Tenant{ID: "t-ACCNT0003", State: TenantStateActive}); err != nil {
t.Fatal(err)
}
tenants, err := reg.ListByAccountID(accountID)
if err != nil {
t.Fatalf("ListByAccountID: %v", err)
}
if len(tenants) != 2 {
t.Fatalf("expected 2 tenants, got %d", len(tenants))
}
seen := make(map[string]bool)
for _, tnt := range tenants {
seen[tnt.ID] = true
if tnt.AccountID != accountID {
t.Errorf("tenant %s AccountID=%q, want %q", tnt.ID, tnt.AccountID, accountID)
}
}
if !seen["t-ACCNT0001"] || !seen["t-ACCNT0002"] {
t.Fatalf("expected tenants t-ACCNT0001 and t-ACCNT0002, got %+v", tenants)
}
}
func TestCountByState(t *testing.T) {
reg := newTestRegistry(t)
@@ -280,3 +566,107 @@ func TestNewTenantRegistry_InvalidDir(t *testing.T) {
}
}
}
func TestStripeAccountCRUD(t *testing.T) {
reg := newTestRegistry(t)
accountID, err := GenerateAccountID()
if err != nil {
t.Fatal(err)
}
if err := reg.CreateAccount(&Account{
ID: accountID,
Kind: AccountKindMSP,
DisplayName: "Test MSP",
}); err != nil {
t.Fatal(err)
}
trialEnds := time.Now().UTC().Add(7 * 24 * time.Hour).Unix()
periodEnd := time.Now().UTC().Add(30 * 24 * time.Hour).Unix()
sa := &StripeAccount{
AccountID: accountID,
StripeCustomerID: "cus_test_123",
StripeSubscriptionID: "sub_test_123",
StripeSubItemWorkspacesID: "si_workspaces_123",
PlanVersion: "msp_hosted_v1",
SubscriptionState: "trial",
TrialEndsAt: &trialEnds,
CurrentPeriodEnd: &periodEnd,
}
// Create
if err := reg.CreateStripeAccount(sa); err != nil {
t.Fatalf("CreateStripeAccount: %v", err)
}
// Get by account id
got, err := reg.GetStripeAccount(accountID)
if err != nil {
t.Fatalf("GetStripeAccount: %v", err)
}
if got == nil {
t.Fatal("GetStripeAccount returned nil")
}
if got.StripeCustomerID != "cus_test_123" {
t.Errorf("StripeCustomerID = %q, want %q", got.StripeCustomerID, "cus_test_123")
}
if got.StripeSubscriptionID != "sub_test_123" {
t.Errorf("StripeSubscriptionID = %q, want %q", got.StripeSubscriptionID, "sub_test_123")
}
// Get by customer id
got2, err := reg.GetStripeAccountByCustomerID("cus_test_123")
if err != nil {
t.Fatalf("GetStripeAccountByCustomerID: %v", err)
}
if got2 == nil || got2.AccountID != accountID {
t.Fatalf("expected accountID %q, got %#v", accountID, got2)
}
// Update
got2.SubscriptionState = "active"
got2.PlanVersion = "msp_hosted_v2"
got2.StripeSubscriptionID = "sub_test_456"
if err := reg.UpdateStripeAccount(got2); err != nil {
t.Fatalf("UpdateStripeAccount: %v", err)
}
got3, err := reg.GetStripeAccount(accountID)
if err != nil {
t.Fatalf("GetStripeAccount after update: %v", err)
}
if got3.SubscriptionState != "active" {
t.Errorf("SubscriptionState = %q, want %q", got3.SubscriptionState, "active")
}
if got3.PlanVersion != "msp_hosted_v2" {
t.Errorf("PlanVersion = %q, want %q", got3.PlanVersion, "msp_hosted_v2")
}
if got3.StripeSubscriptionID != "sub_test_456" {
t.Errorf("StripeSubscriptionID = %q, want %q", got3.StripeSubscriptionID, "sub_test_456")
}
if got3.UpdatedAt == 0 {
t.Error("UpdatedAt should be set")
}
}
func TestStripeEventIdempotency(t *testing.T) {
reg := newTestRegistry(t)
already, err := reg.RecordStripeEvent("evt_test_123", "customer.subscription.updated")
if err != nil {
t.Fatalf("RecordStripeEvent: %v", err)
}
if already {
t.Fatalf("expected alreadyProcessed=false on first insert")
}
already2, err := reg.RecordStripeEvent("evt_test_123", "customer.subscription.updated")
if err != nil {
t.Fatalf("RecordStripeEvent duplicate: %v", err)
}
if !already2 {
t.Fatalf("expected alreadyProcessed=true on duplicate insert")
}
}

View File

@@ -7,6 +7,8 @@ import (
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
cpauth "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/auth"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/docker"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/handoff"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/portal"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
cpstripe "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/stripe"
)
@@ -99,4 +101,12 @@ func RegisterRoutes(mux *http.ServeMux, deps *Deps) {
mux.Handle("/api/accounts/{account_id}/tenants", admin.AdminKeyMiddleware(deps.Config.AdminKey, tenantsCollection))
mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}", admin.AdminKeyMiddleware(deps.Config.AdminKey, tenant))
// Tenant switching handoff (admin-key authenticated for now; session auth in M-4)
handoffHandler := handoff.HandleHandoff(deps.Registry, deps.Config.TenantsDir())
mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}/handoff", admin.AdminKeyMiddleware(deps.Config.AdminKey, handoffHandler))
// MSP portal API (admin-key authenticated for now; session auth in M-4)
mux.Handle("/api/portal/dashboard", admin.AdminKeyMiddleware(deps.Config.AdminKey, portal.HandlePortalDashboard(deps.Registry)))
mux.Handle("/api/portal/workspaces/{tenant_id}", admin.AdminKeyMiddleware(deps.Config.AdminKey, portal.HandlePortalWorkspaceDetail(deps.Registry)))
}

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"strings"
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
"github.com/rs/zerolog/log"
stripelib "github.com/stripe/stripe-go/v82"
"github.com/stripe/stripe-go/v82/webhook"
@@ -98,14 +99,14 @@ func (h *WebhookHandler) handleEvent(r *http.Request, event *stripelib.Event) er
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
return fmt.Errorf("decode subscription: %w", err)
}
return h.provisioner.HandleSubscriptionUpdated(r.Context(), sub)
return h.routeSubscriptionUpdated(r, sub)
case "customer.subscription.deleted":
var sub Subscription
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
return fmt.Errorf("decode subscription: %w", err)
}
return h.provisioner.HandleSubscriptionDeleted(r.Context(), sub)
return h.routeSubscriptionDeleted(r, sub)
default:
log.Info().
@@ -116,6 +117,46 @@ func (h *WebhookHandler) handleEvent(r *http.Request, event *stripelib.Event) er
}
}
func (h *WebhookHandler) routeSubscriptionUpdated(r *http.Request, sub Subscription) error {
customerID := strings.TrimSpace(sub.Customer)
if customerID != "" {
sa, err := h.provisioner.registry.GetStripeAccountByCustomerID(customerID)
if err != nil {
return fmt.Errorf("lookup stripe account by customer: %w", err)
}
if sa != nil {
acct, err := h.provisioner.registry.GetAccount(sa.AccountID)
if err != nil {
return fmt.Errorf("lookup account: %w", err)
}
if acct != nil && acct.Kind == registry.AccountKindMSP {
return h.provisioner.HandleMSPSubscriptionUpdated(r.Context(), sub)
}
}
}
return h.provisioner.HandleSubscriptionUpdated(r.Context(), sub)
}
func (h *WebhookHandler) routeSubscriptionDeleted(r *http.Request, sub Subscription) error {
customerID := strings.TrimSpace(sub.Customer)
if customerID != "" {
sa, err := h.provisioner.registry.GetStripeAccountByCustomerID(customerID)
if err != nil {
return fmt.Errorf("lookup stripe account by customer: %w", err)
}
if sa != nil {
acct, err := h.provisioner.registry.GetAccount(sa.AccountID)
if err != nil {
return fmt.Errorf("lookup account: %w", err)
}
if acct != nil && acct.Kind == registry.AccountKindMSP {
return h.provisioner.HandleMSPSubscriptionDeleted(r.Context(), sub)
}
}
}
return h.provisioner.HandleSubscriptionDeleted(r.Context(), sub)
}
// CheckoutSession is a minimal representation of a Stripe checkout.session event.
type CheckoutSession struct {
ID string `json:"id"`