mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-18 00:17:39 +01:00
feat(cloud): implement Hosted MSP account layer (M-1 through M-6)
Add the full Hosted MSP portal backend to the control plane. MSP accounts own multiple container-per-tenant workspaces with consolidated billing, account-level RBAC, and signed JWT handoff for tenant switching. - M-1: Account model + schema (accounts, users, account_memberships tables, tenant.account_id FK, ID generators, full CRUD) - M-2: Account-level RBAC handlers (invite/update/remove members, last-owner protection) - M-3: Portal dashboard API (workspace list with fleet health summaries, workspace detail) - M-4: Tenant handoff auth (JWT HS256 minting with per-tenant secret, auto-submit HTML POST binding, tenant-side exchange endpoint with JTI replay protection) - M-5: Workspace provisioning from portal (create/list/update/delete workspaces, handoff.key generation, billing.json, Docker lifecycle) - M-6: Consolidated billing (stripe_accounts table, MSP-aware webhook routing, subscription state propagation across account tenants)
This commit is contained in:
1
go.mod
1
go.mod
@@ -59,6 +59,7 @@ require (
|
||||
github.com/go-openapi/jsonreference v0.20.2 // indirect
|
||||
github.com/go-openapi/swag v0.23.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/gnostic-models v0.6.8 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -76,6 +76,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I=
|
||||
|
||||
233
internal/api/cloud_handoff_handlers.go
Normal file
233
internal/api/cloud_handoff_handlers.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
cloudHandoffIssuer = "pulse-cloud-control-plane"
|
||||
)
|
||||
|
||||
type cloudHandoffClaims struct {
|
||||
AccountID string `json:"account_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type jtiReplayStore struct {
|
||||
once sync.Once
|
||||
db *sql.DB
|
||||
mu sync.Mutex
|
||||
|
||||
configDir string
|
||||
initErr error
|
||||
}
|
||||
|
||||
func (s *jtiReplayStore) init() {
|
||||
s.once.Do(func() {
|
||||
dir := filepath.Clean(s.configDir)
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
s.initErr = fmt.Errorf("configDir is required")
|
||||
return
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
s.initErr = fmt.Errorf("create config dir: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(dir, "handoff_jti.db")
|
||||
dsn := dbPath + "?" + url.Values{
|
||||
"_pragma": []string{
|
||||
"busy_timeout(30000)",
|
||||
"journal_mode(WAL)",
|
||||
"synchronous(NORMAL)",
|
||||
},
|
||||
}.Encode()
|
||||
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
s.initErr = fmt.Errorf("open handoff jti db: %w", err)
|
||||
return
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS handoff_jti (
|
||||
jti TEXT PRIMARY KEY,
|
||||
expires_at INTEGER NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_handoff_jti_expires_at ON handoff_jti(expires_at);
|
||||
`
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
_ = db.Close()
|
||||
s.initErr = fmt.Errorf("init handoff jti schema: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.db = db
|
||||
})
|
||||
}
|
||||
|
||||
func (s *jtiReplayStore) checkAndStore(jti string, expiresAt time.Time) (stored bool, err error) {
|
||||
s.init()
|
||||
if s.initErr != nil {
|
||||
return false, s.initErr
|
||||
}
|
||||
if s.db == nil {
|
||||
return false, fmt.Errorf("handoff jti store not initialized")
|
||||
}
|
||||
jti = strings.TrimSpace(jti)
|
||||
if jti == "" {
|
||||
return false, fmt.Errorf("jti is required")
|
||||
}
|
||||
expiresAt = expiresAt.UTC()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now().UTC().Unix()
|
||||
if _, err := s.db.Exec(`DELETE FROM handoff_jti WHERE expires_at <= ?`, now); err != nil {
|
||||
return false, fmt.Errorf("cleanup handoff jti: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(`INSERT INTO handoff_jti (jti, expires_at) VALUES (?, ?)`, jti, expiresAt.Unix())
|
||||
if err != nil {
|
||||
if isSQLiteUniqueViolation(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("store handoff jti: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isSQLiteUniqueViolation(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := err.Error()
|
||||
return strings.Contains(s, "UNIQUE constraint failed") || strings.Contains(s, "constraint failed")
|
||||
}
|
||||
|
||||
func tenantIDFromRequest(r *http.Request) string {
|
||||
if v := strings.TrimSpace(os.Getenv("PULSE_TENANT_ID")); v != "" {
|
||||
return v
|
||||
}
|
||||
host := strings.TrimSpace(r.Host)
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
// Host is expected to be "<tenant-id>.<baseDomain>".
|
||||
if i := strings.IndexByte(host, '.'); i > 0 {
|
||||
return host[:i]
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// HandleHandoffExchange verifies a control-plane-minted handoff JWT and records its jti to prevent replay.
|
||||
//
|
||||
// Route (wiring happens elsewhere): POST /api/cloud/handoff/exchange
|
||||
//
|
||||
// This minimal implementation returns success JSON containing user info derived from the token.
|
||||
// Wiring into RBAC/session minting is intentionally deferred.
|
||||
func HandleHandoffExchange(configDir string) http.HandlerFunc {
|
||||
configDir = filepath.Clean(configDir)
|
||||
keyPath := filepath.Join(configDir, "secrets", "handoff.key")
|
||||
replay := &jtiReplayStore{configDir: configDir}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
key, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tenantID := tenantIDFromRequest(r)
|
||||
if tenantID == "" {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
tokenStr := strings.TrimSpace(r.FormValue("token"))
|
||||
if tokenStr == "" {
|
||||
http.Error(w, "missing token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var claims cloudHandoffClaims
|
||||
parsed, err := jwt.ParseWithClaims(
|
||||
tokenStr,
|
||||
&claims,
|
||||
func(t *jwt.Token) (any, error) { return key, nil },
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuer(cloudHandoffIssuer),
|
||||
jwt.WithAudience(tenantID),
|
||||
)
|
||||
if err != nil || parsed == nil || !parsed.Valid {
|
||||
http.Error(w, "invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if claims.ExpiresAt == nil {
|
||||
http.Error(w, "invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := replay.checkAndStore(claims.ID, claims.ExpiresAt.Time)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
http.Error(w, "replayed token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"ok": true,
|
||||
"tenant_id": tenantID,
|
||||
"account_id": claims.AccountID,
|
||||
"user_id": claims.Subject,
|
||||
"email": claims.Email,
|
||||
"role": claims.Role,
|
||||
"jti": claims.ID,
|
||||
"exp": claims.ExpiresAt.Time.UTC().Format(time.RFC3339),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}
|
||||
346
internal/cloudcp/account/handlers.go
Normal file
346
internal/cloudcp/account/handlers.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
type memberResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role registry.MemberRole `json:"role"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// HandleListMembers returns an authenticated handler that lists all members of an account.
|
||||
func HandleListMembers(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
if accountID == "" {
|
||||
http.Error(w, "missing account_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
memberships, err := reg.ListMembersByAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if memberships == nil {
|
||||
memberships = []*registry.AccountMembership{}
|
||||
}
|
||||
|
||||
resp := make([]memberResponse, 0, len(memberships))
|
||||
for _, m := range memberships {
|
||||
u, err := reg.GetUser(m.UserID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if u == nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp = append(resp, memberResponse{
|
||||
UserID: m.UserID,
|
||||
Email: u.Email,
|
||||
Role: m.Role,
|
||||
CreatedAt: m.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}
|
||||
|
||||
type inviteMemberRequest struct {
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// HandleInviteMember returns an authenticated handler that invites a user to an account.
|
||||
func HandleInviteMember(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
if accountID == "" {
|
||||
http.Error(w, "missing account_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var req inviteMemberRequest
|
||||
if err := decodeJSON(w, r, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
email := normalizeEmail(req.Email)
|
||||
if email == "" {
|
||||
http.Error(w, "invalid email", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
role, ok := parseMemberRole(req.Role)
|
||||
if !ok {
|
||||
http.Error(w, "invalid role", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := reg.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if u == nil {
|
||||
userID, err := registry.GenerateUserID()
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
u = ®istry.User{
|
||||
ID: userID,
|
||||
Email: email,
|
||||
}
|
||||
if err := reg.CreateUser(u); err != nil {
|
||||
// If a concurrent request created the user, fall back to lookup.
|
||||
u2, gerr := reg.GetUserByEmail(email)
|
||||
if gerr != nil || u2 == nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
u = u2
|
||||
}
|
||||
}
|
||||
|
||||
if err := reg.CreateMembership(®istry.AccountMembership{
|
||||
AccountID: accountID,
|
||||
UserID: u.ID,
|
||||
Role: role,
|
||||
}); err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
http.Error(w, "membership already exists", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
}
|
||||
|
||||
type updateMemberRoleRequest struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// HandleUpdateMemberRole returns an authenticated handler that updates a member's role.
|
||||
func HandleUpdateMemberRole(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPatch {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
userID := strings.TrimSpace(r.PathValue("user_id"))
|
||||
if accountID == "" || userID == "" {
|
||||
http.Error(w, "missing account_id or user_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var req updateMemberRoleRequest
|
||||
if err := decodeJSON(w, r, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
role, ok := parseMemberRole(req.Role)
|
||||
if !ok {
|
||||
http.Error(w, "invalid role", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := reg.UpdateMembershipRole(accountID, userID, role); err != nil {
|
||||
if isNotFoundErr(err) {
|
||||
http.Error(w, "membership not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleRemoveMember returns an authenticated handler that removes a user from an account.
|
||||
func HandleRemoveMember(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
userID := strings.TrimSpace(r.PathValue("user_id"))
|
||||
if accountID == "" || userID == "" {
|
||||
http.Error(w, "missing account_id or user_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
m, err := reg.GetMembership(accountID, userID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if m == nil {
|
||||
http.Error(w, "membership not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if m.Role == registry.MemberRoleOwner {
|
||||
memberships, err := reg.ListMembersByAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
owners := 0
|
||||
for _, mm := range memberships {
|
||||
if mm.Role == registry.MemberRoleOwner {
|
||||
owners++
|
||||
}
|
||||
}
|
||||
if owners <= 1 {
|
||||
http.Error(w, "cannot remove last owner", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := reg.DeleteMembership(accountID, userID); err != nil {
|
||||
if isNotFoundErr(err) {
|
||||
http.Error(w, "membership not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeEmail(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
s = strings.ToLower(s)
|
||||
// Minimal sanity; deeper validation comes later with session auth flows.
|
||||
if s == "" || !strings.Contains(s, "@") {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func parseMemberRole(s string) (registry.MemberRole, bool) {
|
||||
switch registry.MemberRole(strings.TrimSpace(s)) {
|
||||
case registry.MemberRoleOwner:
|
||||
return registry.MemberRoleOwner, true
|
||||
case registry.MemberRoleAdmin:
|
||||
return registry.MemberRoleAdmin, true
|
||||
case registry.MemberRoleTech:
|
||||
return registry.MemberRoleTech, true
|
||||
case registry.MemberRoleReadOnly:
|
||||
return registry.MemberRoleReadOnly, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJSON(w http.ResponseWriter, r *http.Request, dst any) error {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB
|
||||
dec := json.NewDecoder(r.Body)
|
||||
dec.DisallowUnknownFields()
|
||||
if err := dec.Decode(dst); err != nil {
|
||||
http.Error(w, "invalid JSON body", http.StatusBadRequest)
|
||||
return err
|
||||
}
|
||||
if err := dec.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
http.Error(w, "invalid JSON body", http.StatusBadRequest)
|
||||
return errors.New("multiple JSON values")
|
||||
}
|
||||
http.Error(w, "invalid JSON body", http.StatusBadRequest)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNotFoundErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// Registry uses fmt.Errorf("... not found") (no sentinel errors yet).
|
||||
return strings.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func isUniqueViolation(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// modernc.org/sqlite returns strings containing "UNIQUE constraint failed".
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "unique constraint failed")
|
||||
}
|
||||
343
internal/cloudcp/account/handlers_test.go
Normal file
343
internal/cloudcp/account/handlers_test.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
func newTestRegistry(t *testing.T) *registry.TenantRegistry {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
reg, err := registry.NewTenantRegistry(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTenantRegistry: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = reg.Close() })
|
||||
return reg
|
||||
}
|
||||
|
||||
func newTestMux(reg *registry.TenantRegistry) *http.ServeMux {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
listMembers := HandleListMembers(reg)
|
||||
inviteMember := HandleInviteMember(reg)
|
||||
updateRole := HandleUpdateMemberRole(reg)
|
||||
removeMember := HandleRemoveMember(reg)
|
||||
|
||||
collection := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
listMembers(w, r)
|
||||
case http.MethodPost:
|
||||
inviteMember(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
member := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPatch:
|
||||
updateRole(w, r)
|
||||
case http.MethodDelete:
|
||||
removeMember(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
|
||||
mux.Handle("/api/accounts/{account_id}/members", admin.AdminKeyMiddleware("secret-key", collection))
|
||||
mux.Handle("/api/accounts/{account_id}/members/{user_id}", admin.AdminKeyMiddleware("secret-key", member))
|
||||
return mux
|
||||
}
|
||||
|
||||
func doRequest(t *testing.T, h http.Handler, req *http.Request) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
req.Header.Set("X-Admin-Key", "secret-key")
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
func TestInviteMember(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := `{"email":"tech@msp.com","role":"tech"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/members", bytes.NewBufferString(body))
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String())
|
||||
}
|
||||
|
||||
u, err := reg.GetUserByEmail("tech@msp.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if u == nil {
|
||||
t.Fatal("expected user to be created")
|
||||
}
|
||||
|
||||
m, err := reg.GetMembership(accountID, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m == nil {
|
||||
t.Fatal("expected membership to be created")
|
||||
}
|
||||
if m.Role != registry.MemberRoleTech {
|
||||
t.Fatalf("role = %q, want %q", m.Role, registry.MemberRoleTech)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteExistingUser(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
userID, err := registry.GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(®istry.User{ID: userID, Email: "existing@msp.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := `{"email":"existing@msp.com","role":"tech"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/members", bytes.NewBufferString(body))
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String())
|
||||
}
|
||||
|
||||
u, err := reg.GetUserByEmail("existing@msp.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if u == nil || u.ID != userID {
|
||||
t.Fatalf("user = %+v, want id=%q", u, userID)
|
||||
}
|
||||
|
||||
m, err := reg.GetMembership(accountID, userID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m == nil {
|
||||
t.Fatal("expected membership to be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListMembers(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u1ID, err := registry.GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
u2ID, err := registry.GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(®istry.User{ID: u1ID, Email: "owner@msp.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(®istry.User{ID: u2ID, Email: "tech@msp.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: u1ID, Role: registry.MemberRoleOwner}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: u2ID, Role: registry.MemberRoleTech}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/accounts/"+accountID+"/members", nil)
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var got []struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role registry.MemberRole `json:"role"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 members, got %d (%+v)", len(got), got)
|
||||
}
|
||||
|
||||
sort.Slice(got, func(i, j int) bool { return got[i].Email < got[j].Email })
|
||||
if got[0].Email != "owner@msp.com" || got[0].Role != registry.MemberRoleOwner {
|
||||
t.Fatalf("member[0]=%+v, want owner@msp.com owner", got[0])
|
||||
}
|
||||
if got[1].Email != "tech@msp.com" || got[1].Role != registry.MemberRoleTech {
|
||||
t.Fatalf("member[1]=%+v, want tech@msp.com tech", got[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMemberRole(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
userID, err := registry.GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(®istry.User{ID: userID, Email: "tech@msp.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: userID, Role: registry.MemberRoleTech}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := `{"role":"admin"}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/accounts/"+accountID+"/members/"+userID, bytes.NewBufferString(body))
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
m, err := reg.GetMembership(accountID, userID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m == nil {
|
||||
t.Fatal("expected membership to exist")
|
||||
}
|
||||
if m.Role != registry.MemberRoleAdmin {
|
||||
t.Fatalf("role = %q, want %q", m.Role, registry.MemberRoleAdmin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveMember(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ownerID, err := registry.GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
techID, err := registry.GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(®istry.User{ID: ownerID, Email: "owner@msp.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(®istry.User{ID: techID, Email: "tech@msp.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: ownerID, Role: registry.MemberRoleOwner}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: techID, Role: registry.MemberRoleTech}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/members/"+techID, nil)
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusNoContent, rec.Body.String())
|
||||
}
|
||||
|
||||
m, err := reg.GetMembership(accountID, techID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m != nil {
|
||||
t.Fatalf("expected membership to be deleted, got %+v", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotRemoveLastOwner(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ownerID, err := registry.GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(®istry.User{ID: ownerID, Email: "owner@msp.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: ownerID, Role: registry.MemberRoleOwner}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/members/"+ownerID, nil)
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusConflict && rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d or %d (body=%q)", rec.Code, http.StatusConflict, http.StatusBadRequest, rec.Body.String())
|
||||
}
|
||||
|
||||
m, err := reg.GetMembership(accountID, ownerID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m == nil {
|
||||
t.Fatal("expected owner membership to remain")
|
||||
}
|
||||
}
|
||||
279
internal/cloudcp/account/tenant_handlers.go
Normal file
279
internal/cloudcp/account/tenant_handlers.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
// WorkspaceProvisioner is the minimal interface needed by the MSP portal tenant handlers.
|
||||
// Implemented by internal/cloudcp/stripe.Provisioner.
|
||||
type WorkspaceProvisioner interface {
|
||||
ProvisionWorkspace(ctx context.Context, accountID, displayName string) (*registry.Tenant, error)
|
||||
DeprovisionWorkspaceContainer(ctx context.Context, tenant *registry.Tenant) error
|
||||
}
|
||||
|
||||
// HandleListTenants lists all tenants for an account.
|
||||
// Route: GET /api/accounts/{account_id}/tenants
|
||||
func HandleListTenants(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
if accountID == "" {
|
||||
http.Error(w, "missing account_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
tenants, err := reg.ListByAccountID(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if tenants == nil {
|
||||
tenants = []*registry.Tenant{}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(tenants)
|
||||
}
|
||||
}
|
||||
|
||||
type createTenantRequest struct {
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// HandleCreateTenant creates a new tenant under an account.
|
||||
// Route: POST /api/accounts/{account_id}/tenants
|
||||
func HandleCreateTenant(reg *registry.TenantRegistry, provisioner WorkspaceProvisioner) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
if accountID == "" {
|
||||
http.Error(w, "missing account_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var req createTenantRequest
|
||||
if err := decodeJSON(w, r, &req); err != nil {
|
||||
return
|
||||
}
|
||||
displayName := strings.TrimSpace(req.DisplayName)
|
||||
if displayName == "" {
|
||||
http.Error(w, "invalid display_name", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if provisioner == nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tenant, err := provisioner.ProvisionWorkspace(r.Context(), accountID, displayName)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(tenant)
|
||||
}
|
||||
}
|
||||
|
||||
type updateTenantRequest struct {
|
||||
DisplayName *string `json:"display_name,omitempty"`
|
||||
Status *string `json:"status,omitempty"`
|
||||
State *string `json:"state,omitempty"`
|
||||
}
|
||||
|
||||
func parseTenantState(s string) (registry.TenantState, bool) {
|
||||
switch registry.TenantState(strings.TrimSpace(s)) {
|
||||
case registry.TenantStateActive:
|
||||
return registry.TenantStateActive, true
|
||||
case registry.TenantStateSuspended:
|
||||
return registry.TenantStateSuspended, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func loadTenantForAccount(reg *registry.TenantRegistry, accountID, tenantID string) (*registry.Tenant, error) {
|
||||
t, err := reg.Get(tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID {
|
||||
return nil, nil
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// HandleUpdateTenant updates display name and/or state.
|
||||
// Route: PATCH /api/accounts/{account_id}/tenants/{tenant_id}
|
||||
func HandleUpdateTenant(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPatch {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
tenantID := strings.TrimSpace(r.PathValue("tenant_id"))
|
||||
if accountID == "" || tenantID == "" {
|
||||
http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
tenant, err := loadTenantForAccount(reg, accountID, tenantID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if tenant == nil {
|
||||
http.Error(w, "tenant not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var req updateTenantRequest
|
||||
if err := decodeJSON(w, r, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if req.DisplayName != nil {
|
||||
name := strings.TrimSpace(*req.DisplayName)
|
||||
if name == "" {
|
||||
http.Error(w, "invalid display_name", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
tenant.DisplayName = name
|
||||
}
|
||||
|
||||
stateVal := req.Status
|
||||
if stateVal == nil {
|
||||
stateVal = req.State
|
||||
}
|
||||
if stateVal != nil {
|
||||
st, ok := parseTenantState(*stateVal)
|
||||
if !ok {
|
||||
http.Error(w, "invalid status", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
tenant.State = st
|
||||
}
|
||||
|
||||
if err := reg.Update(tenant); err != nil {
|
||||
if isNotFoundErr(err) {
|
||||
http.Error(w, "tenant not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(tenant)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDeleteTenant soft-deletes a tenant and deprovisions its container if Docker is available.
|
||||
// Route: DELETE /api/accounts/{account_id}/tenants/{tenant_id}
|
||||
func HandleDeleteTenant(reg *registry.TenantRegistry, provisioner WorkspaceProvisioner) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
tenantID := strings.TrimSpace(r.PathValue("tenant_id"))
|
||||
if accountID == "" || tenantID == "" {
|
||||
http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
tenant, err := loadTenantForAccount(reg, accountID, tenantID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if tenant == nil {
|
||||
http.Error(w, "tenant not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
tenant.State = registry.TenantStateDeleting
|
||||
if err := reg.Update(tenant); err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if provisioner != nil {
|
||||
if err := provisioner.DeprovisionWorkspaceContainer(r.Context(), tenant); err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tenant.ContainerID = ""
|
||||
tenant.State = registry.TenantStateDeleted
|
||||
if err := reg.Update(tenant); err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
220
internal/cloudcp/account/tenant_handlers_test.go
Normal file
220
internal/cloudcp/account/tenant_handlers_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
cpstripe "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/stripe"
|
||||
)
|
||||
|
||||
func newTestTenantMux(reg *registry.TenantRegistry, tenantsDir string) (*http.ServeMux, *cpstripe.Provisioner) {
|
||||
mux := http.NewServeMux()
|
||||
provisioner := cpstripe.NewProvisioner(reg, tenantsDir, nil, nil, "https://cloud.example.com")
|
||||
|
||||
listTenants := HandleListTenants(reg)
|
||||
createTenant := HandleCreateTenant(reg, provisioner)
|
||||
updateTenant := HandleUpdateTenant(reg)
|
||||
deleteTenant := HandleDeleteTenant(reg, provisioner)
|
||||
|
||||
collection := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
listTenants(w, r)
|
||||
case http.MethodPost:
|
||||
createTenant(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
tenant := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPatch:
|
||||
updateTenant(w, r)
|
||||
case http.MethodDelete:
|
||||
deleteTenant(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
|
||||
mux.Handle("/api/accounts/{account_id}/tenants", admin.AdminKeyMiddleware("secret-key", collection))
|
||||
mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}", admin.AdminKeyMiddleware("secret-key", tenant))
|
||||
return mux, provisioner
|
||||
}
|
||||
|
||||
func TestCreateWorkspace(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
tenantsDir := t.TempDir()
|
||||
mux, _ := newTestTenantMux(reg, tenantsDir)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := `{"display_name":"Acme Dental"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/tenants", bytes.NewBufferString(body))
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusCreated, rec.Body.String())
|
||||
}
|
||||
|
||||
var got registry.Tenant
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if got.AccountID != accountID {
|
||||
t.Fatalf("account_id = %q, want %q", got.AccountID, accountID)
|
||||
}
|
||||
if got.DisplayName != "Acme Dental" {
|
||||
t.Fatalf("display_name = %q, want %q", got.DisplayName, "Acme Dental")
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(tenantsDir, got.ID, "secrets", "handoff.key")
|
||||
info, err := os.Stat(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("handoff.key missing: %v", err)
|
||||
}
|
||||
if info.Mode().Perm() != 0o600 {
|
||||
t.Fatalf("handoff.key perms = %o, want %o", info.Mode().Perm(), 0o600)
|
||||
}
|
||||
b, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read handoff.key: %v", err)
|
||||
}
|
||||
if len(b) != 32 {
|
||||
t.Fatalf("handoff.key size = %d, want 32", len(b))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListWorkspaces(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
tenantsDir := t.TempDir()
|
||||
mux, provisioner := newTestTenantMux(reg, tenantsDir)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t1, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client One")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t2, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client Two")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/accounts/"+accountID+"/tenants", nil)
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var got []*registry.Tenant
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 tenants, got %d (%+v)", len(got), got)
|
||||
}
|
||||
|
||||
ids := map[string]bool{}
|
||||
for _, tt := range got {
|
||||
if tt.AccountID != accountID {
|
||||
t.Fatalf("tenant account_id = %q, want %q", tt.AccountID, accountID)
|
||||
}
|
||||
ids[tt.ID] = true
|
||||
}
|
||||
if !ids[t1.ID] || !ids[t2.ID] {
|
||||
t.Fatalf("missing ids: got=%v want=%q,%q", ids, t1.ID, t2.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteWorkspace(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
tenantsDir := t.TempDir()
|
||||
mux, provisioner := newTestTenantMux(reg, tenantsDir)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test MSP"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tenant, err := provisioner.ProvisionWorkspace(context.Background(), accountID, "Client")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/accounts/"+accountID+"/tenants/"+tenant.ID, nil)
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusNoContent, rec.Body.String())
|
||||
}
|
||||
|
||||
t2, err := reg.Get(tenant.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if t2 == nil {
|
||||
t.Fatal("expected tenant to exist")
|
||||
}
|
||||
if t2.State != registry.TenantStateDeleted {
|
||||
t.Fatalf("state = %q, want %q", t2.State, registry.TenantStateDeleted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantBelongsToAccount(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
tenantsDir := t.TempDir()
|
||||
mux, provisioner := newTestTenantMux(reg, tenantsDir)
|
||||
|
||||
account1, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
account2, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: account1, Kind: registry.AccountKindMSP, DisplayName: "A1"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: account2, Kind: registry.AccountKindMSP, DisplayName: "A2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tenant, err := provisioner.ProvisionWorkspace(context.Background(), account1, "Client")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := `{"display_name":"New Name"}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/accounts/"+account2+"/tenants/"+tenant.ID, bytes.NewBufferString(body))
|
||||
rec := doRequest(t, mux, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound && rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want 404/403 (body=%q)", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
165
internal/cloudcp/handoff/handler.go
Normal file
165
internal/cloudcp/handoff/handler.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package handoff
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
const (
|
||||
handoffKeyRelPath = "secrets/handoff.key"
|
||||
)
|
||||
|
||||
var handoffHTMLTemplate = template.Must(template.New("handoff").Parse(`<!DOCTYPE html>
|
||||
<html><body>
|
||||
<form method="POST" action="https://{{.TenantID}}.{{.BaseDomain}}/api/cloud/handoff/exchange">
|
||||
<input type="hidden" name="token" value="{{.Token}}" />
|
||||
</form>
|
||||
<script>document.forms[0].submit()</script>
|
||||
</body></html>
|
||||
`))
|
||||
|
||||
type handoffHTMLData struct {
|
||||
TenantID string
|
||||
BaseDomain string
|
||||
Token string
|
||||
GeneratedAt time.Time
|
||||
}
|
||||
|
||||
// HandleHandoff mints a tenant handoff token and returns an auto-submit HTML page.
|
||||
// Route (wiring happens elsewhere): POST /api/accounts/{account_id}/tenants/{tenant_id}/handoff
|
||||
//
|
||||
// Auth: admin-key for now. User identity is provided by X-User-ID (temporary; session auth replaces this).
|
||||
func HandleHandoff(reg *registry.TenantRegistry, tenantsDir string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if reg == nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(r.PathValue("account_id"))
|
||||
tenantID := strings.TrimSpace(r.PathValue("tenant_id"))
|
||||
if accountID == "" || tenantID == "" {
|
||||
http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
t, err := reg.Get(tenantID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if t == nil || strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID {
|
||||
http.Error(w, "tenant not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Temporary request-scoped identity until control plane session auth exists.
|
||||
userID := strings.TrimSpace(r.Header.Get("X-User-ID"))
|
||||
if userID == "" {
|
||||
userID = strings.TrimSpace(r.Header.Get("X-User-Id"))
|
||||
}
|
||||
if userID == "" {
|
||||
http.Error(w, "missing user identity", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
m, err := reg.GetMembership(accountID, userID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if m == nil {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := reg.GetUser(userID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if u == nil || strings.TrimSpace(u.Email) == "" {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(filepath.Clean(tenantsDir), tenantID, filepath.FromSlash(handoffKeyRelPath))
|
||||
secret, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
token, err := MintHandoffToken(secret, HandoffClaims{
|
||||
TenantID: tenantID,
|
||||
UserID: userID,
|
||||
AccountID: accountID,
|
||||
Email: u.Email,
|
||||
Role: m.Role,
|
||||
IssuedAt: now,
|
||||
ExpiresAt: now.Add(defaultTTL),
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
baseDomain := deriveBaseDomain(r)
|
||||
if baseDomain == "" {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = handoffHTMLTemplate.Execute(w, handoffHTMLData{
|
||||
TenantID: tenantID,
|
||||
BaseDomain: baseDomain,
|
||||
Token: token,
|
||||
GeneratedAt: now,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func deriveBaseDomain(r *http.Request) string {
|
||||
if v := strings.TrimSpace(os.Getenv("CP_BASE_DOMAIN")); v != "" {
|
||||
return v
|
||||
}
|
||||
host := strings.TrimSpace(r.Host)
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
// Strip port if present.
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
} else {
|
||||
// net.SplitHostPort errors if there's no port; that's fine.
|
||||
if strings.Count(host, ":") > 1 {
|
||||
// IPv6 without port isn't valid for our use case.
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return host
|
||||
}
|
||||
116
internal/cloudcp/handoff/handler_test.go
Normal file
116
internal/cloudcp/handoff/handler_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package handoff
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
func newTestRegistry(t *testing.T) *registry.TenantRegistry {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
reg, err := registry.NewTenantRegistry(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTenantRegistry: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = reg.Close() })
|
||||
return reg
|
||||
}
|
||||
|
||||
func TestHandoffHandler(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
tenantsDir := t.TempDir()
|
||||
|
||||
accountID := "a_TEST"
|
||||
tenantID := "t-TEST"
|
||||
userID := "u_TEST"
|
||||
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Test"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(®istry.User{ID: userID, Email: "tech@example.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateMembership(®istry.AccountMembership{AccountID: accountID, UserID: userID, Role: registry.MemberRoleTech}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(®istry.Tenant{ID: tenantID, AccountID: accountID, DisplayName: "Client", State: registry.TenantStateActive}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
keyPath := filepath.Join(tenantsDir, tenantID, "secrets", "handoff.key")
|
||||
if err := os.MkdirAll(filepath.Dir(keyPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, secret, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
h := HandleHandoff(reg, tenantsDir)
|
||||
mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}/handoff", admin.AdminKeyMiddleware("secret-key", h))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/accounts/"+accountID+"/tenants/"+tenantID+"/handoff", nil)
|
||||
req.Host = "cloud.example.com"
|
||||
req.Header.Set("X-Admin-Key", "secret-key")
|
||||
req.Header.Set("X-User-ID", userID)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
body := rec.Body.String()
|
||||
wantAction := "https://" + tenantID + ".cloud.example.com/api/cloud/handoff/exchange"
|
||||
if !regexp.MustCompile(regexp.QuoteMeta(wantAction)).MatchString(body) {
|
||||
t.Fatalf("missing form action %q in body", wantAction)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`name="token" value="([^"]+)"`)
|
||||
m := re.FindStringSubmatch(body)
|
||||
if len(m) != 2 {
|
||||
t.Fatalf("failed to extract token from HTML")
|
||||
}
|
||||
tokenStr := m[1]
|
||||
|
||||
var got jwtHandoffClaims
|
||||
parsed, err := jwt.ParseWithClaims(
|
||||
tokenStr,
|
||||
&got,
|
||||
func(t *jwt.Token) (any, error) { return secret, nil },
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuer(issuer),
|
||||
jwt.WithAudience(tenantID),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseWithClaims: %v", err)
|
||||
}
|
||||
if !parsed.Valid {
|
||||
t.Fatalf("token valid = false, want true")
|
||||
}
|
||||
if got.Subject != userID {
|
||||
t.Fatalf("sub = %q, want %q", got.Subject, userID)
|
||||
}
|
||||
if got.AccountID != accountID {
|
||||
t.Fatalf("account_id = %q, want %q", got.AccountID, accountID)
|
||||
}
|
||||
if got.Email != "tech@example.com" {
|
||||
t.Fatalf("email = %q, want %q", got.Email, "tech@example.com")
|
||||
}
|
||||
if got.Role != registry.MemberRoleTech {
|
||||
t.Fatalf("role = %q, want %q", got.Role, registry.MemberRoleTech)
|
||||
}
|
||||
if got.ExpiresAt == nil || time.Until(got.ExpiresAt.Time) > 60*time.Second+2*time.Second {
|
||||
t.Fatalf("exp looks wrong: %v", got.ExpiresAt)
|
||||
}
|
||||
}
|
||||
131
internal/cloudcp/handoff/handoff.go
Normal file
131
internal/cloudcp/handoff/handoff.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package handoff
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
const (
|
||||
issuer = "pulse-cloud-control-plane"
|
||||
defaultTTL = 60 * time.Second
|
||||
handoffKeyLen = 32
|
||||
)
|
||||
|
||||
// HandoffClaims are the logical claims that the control plane wants to assert
|
||||
// when handing off an authenticated account user into a tenant container.
|
||||
type HandoffClaims struct {
|
||||
TenantID string
|
||||
UserID string
|
||||
AccountID string
|
||||
Email string
|
||||
Role registry.MemberRole
|
||||
|
||||
IssuedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
JTI string
|
||||
}
|
||||
|
||||
type jwtHandoffClaims struct {
|
||||
AccountID string `json:"account_id"`
|
||||
Email string `json:"email"`
|
||||
Role registry.MemberRole `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateHandoffKey returns 32 cryptographically random bytes suitable for HS256 signing.
|
||||
// Intended for writing to /data/tenants/<tenant_id>/secrets/handoff.key (0600).
|
||||
func GenerateHandoffKey() ([]byte, error) {
|
||||
key := make([]byte, handoffKeyLen)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return nil, fmt.Errorf("generate handoff key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// MintHandoffToken mints a short-lived HS256 JWT signed with the per-tenant handoff.key.
|
||||
//
|
||||
// JWT registered claims:
|
||||
// - iss: pulse-cloud-control-plane
|
||||
// - aud: <tenant-id>
|
||||
// - sub: <user-id>
|
||||
// - iat, exp, jti
|
||||
//
|
||||
// Custom claims:
|
||||
// - account_id, email, role
|
||||
func MintHandoffToken(secret []byte, claims HandoffClaims) (string, error) {
|
||||
if len(secret) == 0 {
|
||||
return "", fmt.Errorf("secret is required")
|
||||
}
|
||||
claims.TenantID = sanitizeID(claims.TenantID)
|
||||
claims.UserID = sanitizeID(claims.UserID)
|
||||
claims.AccountID = sanitizeID(claims.AccountID)
|
||||
if claims.TenantID == "" || claims.UserID == "" || claims.AccountID == "" {
|
||||
return "", fmt.Errorf("tenantID, userID, and accountID are required")
|
||||
}
|
||||
if claims.Email == "" {
|
||||
return "", fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if claims.IssuedAt.IsZero() {
|
||||
claims.IssuedAt = now
|
||||
}
|
||||
claims.IssuedAt = claims.IssuedAt.UTC()
|
||||
|
||||
if claims.ExpiresAt.IsZero() {
|
||||
claims.ExpiresAt = claims.IssuedAt.Add(defaultTTL)
|
||||
}
|
||||
claims.ExpiresAt = claims.ExpiresAt.UTC()
|
||||
if !claims.ExpiresAt.After(claims.IssuedAt) {
|
||||
return "", fmt.Errorf("expiresAt must be after issuedAt")
|
||||
}
|
||||
|
||||
if claims.JTI == "" {
|
||||
jti, err := randomJTI128()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.JTI = jti
|
||||
}
|
||||
|
||||
jc := jwtHandoffClaims{
|
||||
AccountID: claims.AccountID,
|
||||
Email: claims.Email,
|
||||
Role: claims.Role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: issuer,
|
||||
Subject: claims.UserID,
|
||||
Audience: jwt.ClaimStrings{claims.TenantID},
|
||||
IssuedAt: jwt.NewNumericDate(claims.IssuedAt),
|
||||
ExpiresAt: jwt.NewNumericDate(claims.ExpiresAt),
|
||||
ID: claims.JTI,
|
||||
},
|
||||
}
|
||||
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jc)
|
||||
signed, err := tok.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign jwt: %w", err)
|
||||
}
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func randomJTI128() (string, error) {
|
||||
b := make([]byte, 16) // 128-bit
|
||||
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||
return "", fmt.Errorf("generate jti: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func sanitizeID(s string) string {
|
||||
// IDs are generated by the registry; trimming prevents surprising mismatches.
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
167
internal/cloudcp/handoff/handoff_test.go
Normal file
167
internal/cloudcp/handoff/handoff_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package handoff
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
func TestMintAndVerify(t *testing.T) {
|
||||
secret := []byte("0123456789abcdef0123456789abcdef") // 32 bytes
|
||||
now := time.Now().UTC()
|
||||
|
||||
tokenStr, err := MintHandoffToken(secret, HandoffClaims{
|
||||
TenantID: "t-TESTTENANT",
|
||||
UserID: "u_TESTUSER",
|
||||
AccountID: "a_TESTACCOUNT",
|
||||
Email: "tech@example.com",
|
||||
Role: registry.MemberRoleTech,
|
||||
IssuedAt: now,
|
||||
ExpiresAt: now.Add(60 * time.Second),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("MintHandoffToken: %v", err)
|
||||
}
|
||||
|
||||
var got jwtHandoffClaims
|
||||
parsed, err := jwt.ParseWithClaims(
|
||||
tokenStr,
|
||||
&got,
|
||||
func(t *jwt.Token) (any, error) { return secret, nil },
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuer(issuer),
|
||||
jwt.WithAudience("t-TESTTENANT"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseWithClaims: %v", err)
|
||||
}
|
||||
if !parsed.Valid {
|
||||
t.Fatalf("token valid = false, want true")
|
||||
}
|
||||
|
||||
if got.Subject != "u_TESTUSER" {
|
||||
t.Fatalf("sub = %q, want %q", got.Subject, "u_TESTUSER")
|
||||
}
|
||||
if got.AccountID != "a_TESTACCOUNT" {
|
||||
t.Fatalf("account_id = %q, want %q", got.AccountID, "a_TESTACCOUNT")
|
||||
}
|
||||
if got.Email != "tech@example.com" {
|
||||
t.Fatalf("email = %q, want %q", got.Email, "tech@example.com")
|
||||
}
|
||||
if got.Role != registry.MemberRoleTech {
|
||||
t.Fatalf("role = %q, want %q", got.Role, registry.MemberRoleTech)
|
||||
}
|
||||
if got.ID == "" {
|
||||
t.Fatalf("jti empty")
|
||||
}
|
||||
if got.IssuedAt == nil || got.ExpiresAt == nil {
|
||||
t.Fatalf("missing iat/exp")
|
||||
}
|
||||
if got.ExpiresAt.Time.Sub(got.IssuedAt.Time) != 60*time.Second {
|
||||
t.Fatalf("exp-iat = %v, want %v", got.ExpiresAt.Time.Sub(got.IssuedAt.Time), 60*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiredToken(t *testing.T) {
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
past := time.Now().UTC().Add(-2 * time.Minute)
|
||||
|
||||
tokenStr, err := MintHandoffToken(secret, HandoffClaims{
|
||||
TenantID: "t-EXPIRED",
|
||||
UserID: "u_USER",
|
||||
AccountID: "a_ACCOUNT",
|
||||
Email: "x@example.com",
|
||||
Role: registry.MemberRoleReadOnly,
|
||||
IssuedAt: past,
|
||||
ExpiresAt: past.Add(30 * time.Second),
|
||||
JTI: "deadbeefdeadbeefdeadbeefdeadbeef",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("MintHandoffToken: %v", err)
|
||||
}
|
||||
|
||||
var got jwtHandoffClaims
|
||||
_, err = jwt.ParseWithClaims(
|
||||
tokenStr,
|
||||
&got,
|
||||
func(t *jwt.Token) (any, error) { return secret, nil },
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuer(issuer),
|
||||
jwt.WithAudience("t-EXPIRED"),
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("expected expired token error")
|
||||
}
|
||||
if !errors.Is(err, jwt.ErrTokenExpired) {
|
||||
// Leave room for wrapped validation errors.
|
||||
t.Fatalf("error = %v, want ErrTokenExpired", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrongSecret(t *testing.T) {
|
||||
secretA := []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
|
||||
secretB := []byte("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
|
||||
now := time.Now().UTC()
|
||||
|
||||
tokenStr, err := MintHandoffToken(secretA, HandoffClaims{
|
||||
TenantID: "t-TENANT",
|
||||
UserID: "u_USER",
|
||||
AccountID: "a_ACCOUNT",
|
||||
Email: "x@example.com",
|
||||
Role: registry.MemberRoleAdmin,
|
||||
IssuedAt: now,
|
||||
ExpiresAt: now.Add(60 * time.Second),
|
||||
JTI: "0123456789abcdef0123456789abcdef",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("MintHandoffToken: %v", err)
|
||||
}
|
||||
|
||||
var got jwtHandoffClaims
|
||||
_, err = jwt.ParseWithClaims(
|
||||
tokenStr,
|
||||
&got,
|
||||
func(t *jwt.Token) (any, error) { return secretB, nil },
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuer(issuer),
|
||||
jwt.WithAudience("t-TENANT"),
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("expected signature verification failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrongAudience(t *testing.T) {
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
now := time.Now().UTC()
|
||||
|
||||
tokenStr, err := MintHandoffToken(secret, HandoffClaims{
|
||||
TenantID: "t-A",
|
||||
UserID: "u_USER",
|
||||
AccountID: "a_ACCOUNT",
|
||||
Email: "x@example.com",
|
||||
Role: registry.MemberRoleOwner,
|
||||
IssuedAt: now,
|
||||
ExpiresAt: now.Add(60 * time.Second),
|
||||
JTI: "0123456789abcdef0123456789abcdef",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("MintHandoffToken: %v", err)
|
||||
}
|
||||
|
||||
var got jwtHandoffClaims
|
||||
_, err = jwt.ParseWithClaims(
|
||||
tokenStr,
|
||||
&got,
|
||||
func(t *jwt.Token) (any, error) { return secret, nil },
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuer(issuer),
|
||||
jwt.WithAudience("t-B"),
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("expected wrong audience error")
|
||||
}
|
||||
}
|
||||
202
internal/cloudcp/portal/handlers.go
Normal file
202
internal/cloudcp/portal/handlers.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package portal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
type accountInfo struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Kind registry.AccountKind `json:"kind"`
|
||||
}
|
||||
|
||||
type workspaceSummaryItem struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
State registry.TenantState `json:"state"`
|
||||
HealthCheckOK bool `json:"health_check_ok"`
|
||||
LastHealthCheck *time.Time `json:"last_health_check"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type dashboardSummary struct {
|
||||
Total int `json:"total"`
|
||||
Active int `json:"active"`
|
||||
Healthy int `json:"healthy"`
|
||||
Unhealthy int `json:"unhealthy"`
|
||||
Suspended int `json:"suspended"`
|
||||
}
|
||||
|
||||
type dashboardResponse struct {
|
||||
Account accountInfo `json:"account"`
|
||||
Workspaces []workspaceSummaryItem `json:"workspaces"`
|
||||
Summary dashboardSummary `json:"summary"`
|
||||
}
|
||||
|
||||
type workspaceDetailResponse struct {
|
||||
Account accountInfo `json:"account"`
|
||||
Workspace *registry.Tenant `json:"workspace"`
|
||||
}
|
||||
|
||||
func accountIDFromRequest(r *http.Request) string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
if v := strings.TrimSpace(r.URL.Query().Get("account_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
// Convenience for future callers; spec says query param is fine for now.
|
||||
if v := strings.TrimSpace(r.Header.Get("X-Account-ID")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(r.Header.Get("X-Account-Id")); v != "" {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// HandlePortalDashboard returns a portal-oriented dashboard response for an account.
|
||||
// Route: GET /api/portal/dashboard?account_id=...
|
||||
//
|
||||
// Auth: admin-key for now (session auth in M-4).
|
||||
func HandlePortalDashboard(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if reg == nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := accountIDFromRequest(r)
|
||||
if accountID == "" {
|
||||
http.Error(w, "missing account_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
tenants, err := reg.ListByAccountID(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if tenants == nil {
|
||||
tenants = []*registry.Tenant{}
|
||||
}
|
||||
|
||||
resp := dashboardResponse{
|
||||
Account: accountInfo{
|
||||
ID: a.ID,
|
||||
DisplayName: a.DisplayName,
|
||||
Kind: a.Kind,
|
||||
},
|
||||
Workspaces: make([]workspaceSummaryItem, 0, len(tenants)),
|
||||
}
|
||||
|
||||
for _, t := range tenants {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
resp.Workspaces = append(resp.Workspaces, workspaceSummaryItem{
|
||||
ID: t.ID,
|
||||
DisplayName: t.DisplayName,
|
||||
State: t.State,
|
||||
HealthCheckOK: t.HealthCheckOK,
|
||||
LastHealthCheck: t.LastHealthCheck,
|
||||
CreatedAt: t.CreatedAt,
|
||||
})
|
||||
|
||||
resp.Summary.Total++
|
||||
|
||||
switch t.State {
|
||||
case registry.TenantStateActive:
|
||||
resp.Summary.Active++
|
||||
if t.HealthCheckOK {
|
||||
resp.Summary.Healthy++
|
||||
} else {
|
||||
resp.Summary.Unhealthy++
|
||||
}
|
||||
case registry.TenantStateSuspended:
|
||||
resp.Summary.Suspended++
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// HandlePortalWorkspaceDetail returns a portal-oriented detail response for a single tenant.
|
||||
// Route: GET /api/portal/workspaces/{tenant_id}?account_id=...
|
||||
//
|
||||
// Auth: admin-key for now (session auth in M-4).
|
||||
func HandlePortalWorkspaceDetail(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if reg == nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := accountIDFromRequest(r)
|
||||
tenantID := strings.TrimSpace(r.PathValue("tenant_id"))
|
||||
if accountID == "" || tenantID == "" {
|
||||
http.Error(w, "missing account_id or tenant_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if a == nil {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
t, err := reg.Get(tenantID)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if t == nil || strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID {
|
||||
http.Error(w, "tenant not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
resp := workspaceDetailResponse{
|
||||
Account: accountInfo{
|
||||
ID: a.ID,
|
||||
DisplayName: a.DisplayName,
|
||||
Kind: a.Kind,
|
||||
},
|
||||
Workspace: t,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}
|
||||
310
internal/cloudcp/portal/handlers_test.go
Normal file
310
internal/cloudcp/portal/handlers_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package portal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
func newTestRegistry(t *testing.T) *registry.TenantRegistry {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
reg, err := registry.NewTenantRegistry(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTenantRegistry: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = reg.Close() })
|
||||
return reg
|
||||
}
|
||||
|
||||
func newTestMux(reg *registry.TenantRegistry) *http.ServeMux {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/api/portal/dashboard", admin.AdminKeyMiddleware("secret-key", HandlePortalDashboard(reg)))
|
||||
mux.Handle("/api/portal/workspaces/{tenant_id}", admin.AdminKeyMiddleware("secret-key", HandlePortalWorkspaceDetail(reg)))
|
||||
return mux
|
||||
}
|
||||
|
||||
func doRequest(t *testing.T, h http.Handler, req *http.Request) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
req.Header.Set("X-Admin-Key", "secret-key")
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
type dashboardResp struct {
|
||||
Account struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Kind registry.AccountKind `json:"kind"`
|
||||
} `json:"account"`
|
||||
Workspaces []struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
State registry.TenantState `json:"state"`
|
||||
HealthCheckOK bool `json:"health_check_ok"`
|
||||
LastHealthCheck *time.Time `json:"last_health_check"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
} `json:"workspaces"`
|
||||
Summary struct {
|
||||
Total int `json:"total"`
|
||||
Active int `json:"active"`
|
||||
Healthy int `json:"healthy"`
|
||||
Unhealthy int `json:"unhealthy"`
|
||||
Suspended int `json:"suspended"`
|
||||
} `json:"summary"`
|
||||
}
|
||||
|
||||
func TestPortalDashboard(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Example MSP"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tenantActiveID, err := registry.GenerateTenantID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tenantSuspendedID, err := registry.GenerateTenantID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
created1 := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC)
|
||||
created2 := time.Date(2026, 2, 10, 11, 0, 0, 0, time.UTC)
|
||||
lastCheck := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
if err := reg.Create(®istry.Tenant{
|
||||
ID: tenantActiveID,
|
||||
AccountID: accountID,
|
||||
DisplayName: "Acme Dental",
|
||||
State: registry.TenantStateActive,
|
||||
CreatedAt: created1,
|
||||
LastHealthCheck: &lastCheck,
|
||||
HealthCheckOK: true,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(®istry.Tenant{
|
||||
ID: tenantSuspendedID,
|
||||
AccountID: accountID,
|
||||
DisplayName: "Suspended Workspace",
|
||||
State: registry.TenantStateSuspended,
|
||||
CreatedAt: created2,
|
||||
HealthCheckOK: false,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/portal/dashboard?account_id="+accountID, nil)
|
||||
rec := doRequest(t, mux, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp dashboardResp
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v (body=%q)", err, rec.Body.String())
|
||||
}
|
||||
|
||||
if resp.Account.ID != accountID {
|
||||
t.Fatalf("account.id = %q, want %q", resp.Account.ID, accountID)
|
||||
}
|
||||
if resp.Account.DisplayName != "Example MSP" {
|
||||
t.Fatalf("account.display_name = %q, want %q", resp.Account.DisplayName, "Example MSP")
|
||||
}
|
||||
if resp.Account.Kind != registry.AccountKindMSP {
|
||||
t.Fatalf("account.kind = %q, want %q", resp.Account.Kind, registry.AccountKindMSP)
|
||||
}
|
||||
|
||||
if len(resp.Workspaces) != 2 {
|
||||
t.Fatalf("workspaces len = %d, want %d", len(resp.Workspaces), 2)
|
||||
}
|
||||
|
||||
// Make assertions order-independent.
|
||||
sort.Slice(resp.Workspaces, func(i, j int) bool { return resp.Workspaces[i].ID < resp.Workspaces[j].ID })
|
||||
wsByID := map[string]dashboardRespWorkspace{}
|
||||
|
||||
// local helper type for easier indexing
|
||||
for _, ws := range resp.Workspaces {
|
||||
wsByID[ws.ID] = dashboardRespWorkspace{
|
||||
ID: ws.ID,
|
||||
DisplayName: ws.DisplayName,
|
||||
State: ws.State,
|
||||
HealthCheckOK: ws.HealthCheckOK,
|
||||
LastHealthCheck: ws.LastHealthCheck,
|
||||
CreatedAt: ws.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
active := wsByID[tenantActiveID]
|
||||
if active.ID == "" {
|
||||
t.Fatalf("missing active workspace id %q", tenantActiveID)
|
||||
}
|
||||
if active.DisplayName != "Acme Dental" {
|
||||
t.Fatalf("active.display_name = %q, want %q", active.DisplayName, "Acme Dental")
|
||||
}
|
||||
if active.State != registry.TenantStateActive {
|
||||
t.Fatalf("active.state = %q, want %q", active.State, registry.TenantStateActive)
|
||||
}
|
||||
if !active.HealthCheckOK {
|
||||
t.Fatalf("active.health_check_ok = false, want true")
|
||||
}
|
||||
if active.LastHealthCheck == nil || !active.LastHealthCheck.Equal(lastCheck) {
|
||||
t.Fatalf("active.last_health_check = %v, want %v", active.LastHealthCheck, lastCheck)
|
||||
}
|
||||
if !active.CreatedAt.Equal(created1) {
|
||||
t.Fatalf("active.created_at = %v, want %v", active.CreatedAt, created1)
|
||||
}
|
||||
|
||||
susp := wsByID[tenantSuspendedID]
|
||||
if susp.ID == "" {
|
||||
t.Fatalf("missing suspended workspace id %q", tenantSuspendedID)
|
||||
}
|
||||
if susp.State != registry.TenantStateSuspended {
|
||||
t.Fatalf("suspended.state = %q, want %q", susp.State, registry.TenantStateSuspended)
|
||||
}
|
||||
|
||||
if resp.Summary.Total != 2 {
|
||||
t.Fatalf("summary.total = %d, want %d", resp.Summary.Total, 2)
|
||||
}
|
||||
if resp.Summary.Active != 1 {
|
||||
t.Fatalf("summary.active = %d, want %d", resp.Summary.Active, 1)
|
||||
}
|
||||
if resp.Summary.Healthy != 1 {
|
||||
t.Fatalf("summary.healthy = %d, want %d", resp.Summary.Healthy, 1)
|
||||
}
|
||||
if resp.Summary.Unhealthy != 0 {
|
||||
t.Fatalf("summary.unhealthy = %d, want %d", resp.Summary.Unhealthy, 0)
|
||||
}
|
||||
if resp.Summary.Suspended != 1 {
|
||||
t.Fatalf("summary.suspended = %d, want %d", resp.Summary.Suspended, 1)
|
||||
}
|
||||
}
|
||||
|
||||
type dashboardRespWorkspace struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
State registry.TenantState
|
||||
HealthCheckOK bool
|
||||
LastHealthCheck *time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func TestPortalDashboardEmpty(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Empty MSP"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/portal/dashboard?account_id="+accountID, nil)
|
||||
rec := doRequest(t, mux, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp dashboardResp
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v (body=%q)", err, rec.Body.String())
|
||||
}
|
||||
|
||||
if len(resp.Workspaces) != 0 {
|
||||
t.Fatalf("workspaces len = %d, want %d", len(resp.Workspaces), 0)
|
||||
}
|
||||
if resp.Summary.Total != 0 || resp.Summary.Active != 0 || resp.Summary.Healthy != 0 || resp.Summary.Unhealthy != 0 || resp.Summary.Suspended != 0 {
|
||||
t.Fatalf("summary = %+v, want all zeros", resp.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortalWorkspaceDetail(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
mux := newTestMux(reg)
|
||||
|
||||
accountID, err := registry.GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(®istry.Account{ID: accountID, Kind: registry.AccountKindMSP, DisplayName: "Example MSP"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tenantID, err := registry.GenerateTenantID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
created := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC)
|
||||
lastCheck := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC)
|
||||
if err := reg.Create(®istry.Tenant{
|
||||
ID: tenantID,
|
||||
AccountID: accountID,
|
||||
DisplayName: "Acme Dental",
|
||||
State: registry.TenantStateActive,
|
||||
CreatedAt: created,
|
||||
LastHealthCheck: &lastCheck,
|
||||
HealthCheckOK: true,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/portal/workspaces/"+tenantID+"?account_id="+accountID, nil)
|
||||
rec := doRequest(t, mux, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Account struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Kind registry.AccountKind `json:"kind"`
|
||||
} `json:"account"`
|
||||
Workspace registry.Tenant `json:"workspace"`
|
||||
}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v (body=%q)", err, rec.Body.String())
|
||||
}
|
||||
|
||||
if resp.Account.ID != accountID {
|
||||
t.Fatalf("account.id = %q, want %q", resp.Account.ID, accountID)
|
||||
}
|
||||
if resp.Workspace.ID != tenantID {
|
||||
t.Fatalf("workspace.id = %q, want %q", resp.Workspace.ID, tenantID)
|
||||
}
|
||||
if resp.Workspace.AccountID != accountID {
|
||||
t.Fatalf("workspace.account_id = %q, want %q", resp.Workspace.AccountID, accountID)
|
||||
}
|
||||
if resp.Workspace.DisplayName != "Acme Dental" {
|
||||
t.Fatalf("workspace.display_name = %q, want %q", resp.Workspace.DisplayName, "Acme Dental")
|
||||
}
|
||||
if resp.Workspace.State != registry.TenantStateActive {
|
||||
t.Fatalf("workspace.state = %q, want %q", resp.Workspace.State, registry.TenantStateActive)
|
||||
}
|
||||
if !resp.Workspace.HealthCheckOK {
|
||||
t.Fatalf("workspace.health_check_ok = false, want true")
|
||||
}
|
||||
if resp.Workspace.LastHealthCheck == nil || !resp.Workspace.LastHealthCheck.Equal(lastCheck) {
|
||||
t.Fatalf("workspace.last_health_check = %v, want %v", resp.Workspace.LastHealthCheck, lastCheck)
|
||||
}
|
||||
if !resp.Workspace.CreatedAt.Equal(created) {
|
||||
t.Fatalf("workspace.created_at = %v, want %v", resp.Workspace.CreatedAt, created)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,47 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type AccountKind string
|
||||
|
||||
const (
|
||||
AccountKindIndividual AccountKind = "individual"
|
||||
AccountKindMSP AccountKind = "msp"
|
||||
)
|
||||
|
||||
// Account represents an account record in the control plane registry.
|
||||
type Account struct {
|
||||
ID string `json:"id"`
|
||||
Kind AccountKind `json:"kind"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// User represents a user record in the control plane registry.
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
|
||||
}
|
||||
|
||||
type MemberRole string
|
||||
|
||||
const (
|
||||
MemberRoleOwner MemberRole = "owner"
|
||||
MemberRoleAdmin MemberRole = "admin"
|
||||
MemberRoleTech MemberRole = "tech"
|
||||
MemberRoleReadOnly MemberRole = "read_only"
|
||||
)
|
||||
|
||||
// AccountMembership represents a mapping between an account and a user.
|
||||
type AccountMembership struct {
|
||||
AccountID string `json:"account_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Role MemberRole `json:"role"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// TenantState represents the lifecycle state of a tenant.
|
||||
type TenantState string
|
||||
|
||||
@@ -15,12 +56,14 @@ const (
|
||||
TenantStateActive TenantState = "active"
|
||||
TenantStateSuspended TenantState = "suspended"
|
||||
TenantStateCanceled TenantState = "canceled"
|
||||
TenantStateDeleting TenantState = "deleting"
|
||||
TenantStateDeleted TenantState = "deleted"
|
||||
)
|
||||
|
||||
// Tenant represents a Cloud tenant record in the registry.
|
||||
type Tenant struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"account_id"`
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"display_name"`
|
||||
State TenantState `json:"state"`
|
||||
@@ -37,6 +80,20 @@ type Tenant struct {
|
||||
HealthCheckOK bool `json:"health_check_ok"`
|
||||
}
|
||||
|
||||
// StripeAccount maps a control-plane account to a single Stripe customer +
|
||||
// subscription for consolidated (MSP-style) billing.
|
||||
type StripeAccount struct {
|
||||
AccountID string `json:"account_id"`
|
||||
StripeCustomerID string `json:"stripe_customer_id"`
|
||||
StripeSubscriptionID string `json:"stripe_subscription_id"`
|
||||
StripeSubItemWorkspacesID string `json:"stripe_sub_item_workspaces_id"`
|
||||
PlanVersion string `json:"plan_version"`
|
||||
SubscriptionState string `json:"subscription_state"` // trial, active, past_due, canceled
|
||||
TrialEndsAt *int64 `json:"trial_ends_at"`
|
||||
CurrentPeriodEnd *int64 `json:"current_period_end"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// crockfordBase32 is the Crockford base32 alphabet (excludes I, L, O, U).
|
||||
const crockfordBase32 = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
||||
|
||||
@@ -54,3 +111,33 @@ func GenerateTenantID() (string, error) {
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// GenerateAccountID returns an account ID of the form "a_" followed by 10 random
|
||||
// Crockford base32 characters (50 bits of entropy).
|
||||
func GenerateAccountID() (string, error) {
|
||||
b := make([]byte, 10)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate account id: %w", err)
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.WriteString("a_")
|
||||
for _, v := range b {
|
||||
sb.WriteByte(crockfordBase32[int(v)%len(crockfordBase32)])
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// GenerateUserID returns a user ID of the form "u_" followed by 10 random
|
||||
// Crockford base32 characters (50 bits of entropy).
|
||||
func GenerateUserID() (string, error) {
|
||||
b := make([]byte, 10)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate user id: %w", err)
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.WriteString("u_")
|
||||
for _, v := range b {
|
||||
sb.WriteByte(crockfordBase32[int(v)%len(crockfordBase32)])
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
@@ -51,6 +52,7 @@ func (r *TenantRegistry) initSchema() error {
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS tenants (
|
||||
id TEXT PRIMARY KEY,
|
||||
account_id TEXT NOT NULL DEFAULT '',
|
||||
email TEXT NOT NULL DEFAULT '',
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
state TEXT NOT NULL DEFAULT 'provisioning',
|
||||
@@ -68,13 +70,106 @@ func (r *TenantRegistry) initSchema() error {
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_tenants_state ON tenants(state);
|
||||
CREATE INDEX IF NOT EXISTS idx_tenants_stripe_customer_id ON tenants(stripe_customer_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id TEXT PRIMARY KEY,
|
||||
kind TEXT NOT NULL DEFAULT 'individual',
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS stripe_accounts (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
stripe_customer_id TEXT NOT NULL UNIQUE,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_sub_item_workspaces_id TEXT,
|
||||
plan_version TEXT NOT NULL DEFAULT '',
|
||||
subscription_state TEXT NOT NULL DEFAULT 'trial',
|
||||
trial_ends_at INTEGER,
|
||||
current_period_end INTEGER,
|
||||
updated_at INTEGER NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_stripe_accounts_customer ON stripe_accounts(stripe_customer_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS stripe_events (
|
||||
stripe_event_id TEXT PRIMARY KEY,
|
||||
event_type TEXT NOT NULL,
|
||||
received_at INTEGER NOT NULL,
|
||||
processed_at INTEGER,
|
||||
processing_error TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
created_at INTEGER NOT NULL,
|
||||
last_login_at INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS account_memberships (
|
||||
account_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL DEFAULT 'tech',
|
||||
created_at INTEGER NOT NULL,
|
||||
PRIMARY KEY (account_id, user_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_memberships_user_id ON account_memberships(user_id);
|
||||
`
|
||||
if _, err := r.db.Exec(schema); err != nil {
|
||||
return fmt.Errorf("init tenant registry schema: %w", err)
|
||||
}
|
||||
|
||||
// Migration: add account_id to tenants if not present.
|
||||
// (SQLite makes it awkward to add FK constraints via ALTER TABLE, and FK
|
||||
// enforcement is off by default; this keeps the change backwards-compatible.)
|
||||
hasAccountID, err := r.tenantsHasColumn("account_id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasAccountID {
|
||||
if _, err := r.db.Exec(`ALTER TABLE tenants ADD COLUMN account_id TEXT NOT NULL DEFAULT ''`); err != nil {
|
||||
return fmt.Errorf("migrate tenants: add account_id: %w", err)
|
||||
}
|
||||
}
|
||||
if _, err := r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_tenants_account_id ON tenants(account_id)`); err != nil {
|
||||
return fmt.Errorf("init tenant registry schema: create idx_tenants_account_id: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TenantRegistry) tenantsHasColumn(name string) (bool, error) {
|
||||
rows, err := r.db.Query(`PRAGMA table_info(tenants)`)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("pragma table_info(tenants): %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
cid int
|
||||
colName string
|
||||
colType string
|
||||
notNull int
|
||||
dflt sql.NullString
|
||||
pk int
|
||||
)
|
||||
if err := rows.Scan(&cid, &colName, &colType, ¬Null, &dflt, &pk); err != nil {
|
||||
return false, fmt.Errorf("scan table_info(tenants): %w", err)
|
||||
}
|
||||
if colName == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("iterate table_info(tenants): %w", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Ping checks database connectivity (used for readiness probes).
|
||||
func (r *TenantRegistry) Ping() error {
|
||||
return r.db.Ping()
|
||||
@@ -101,12 +196,12 @@ func (r *TenantRegistry) Create(t *Tenant) error {
|
||||
|
||||
_, err := r.db.Exec(`
|
||||
INSERT INTO tenants (
|
||||
id, email, display_name, state,
|
||||
id, account_id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
t.ID, t.Email, t.DisplayName, string(t.State),
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
t.ID, t.AccountID, t.Email, t.DisplayName, string(t.State),
|
||||
t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
|
||||
t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
|
||||
t.CreatedAt.Unix(), t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
|
||||
@@ -120,7 +215,7 @@ func (r *TenantRegistry) Create(t *Tenant) error {
|
||||
// Get retrieves a tenant by ID.
|
||||
func (r *TenantRegistry) Get(id string) (*Tenant, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
id, email, display_name, state,
|
||||
id, account_id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
@@ -131,7 +226,7 @@ func (r *TenantRegistry) Get(id string) (*Tenant, error) {
|
||||
// GetByStripeCustomerID retrieves a tenant by Stripe customer ID.
|
||||
func (r *TenantRegistry) GetByStripeCustomerID(customerID string) (*Tenant, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
id, email, display_name, state,
|
||||
id, account_id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
@@ -148,12 +243,12 @@ func (r *TenantRegistry) Update(t *Tenant) error {
|
||||
|
||||
res, err := r.db.Exec(`
|
||||
UPDATE tenants SET
|
||||
email = ?, display_name = ?, state = ?,
|
||||
account_id = ?, email = ?, display_name = ?, state = ?,
|
||||
stripe_customer_id = ?, stripe_subscription_id = ?, stripe_price_id = ?,
|
||||
plan_version = ?, container_id = ?, current_image_digest = ?, desired_image_digest = ?,
|
||||
updated_at = ?, last_health_check = ?, health_check_ok = ?
|
||||
WHERE id = ?`,
|
||||
t.Email, t.DisplayName, string(t.State),
|
||||
t.AccountID, t.Email, t.DisplayName, string(t.State),
|
||||
t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
|
||||
t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
|
||||
t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
|
||||
@@ -172,7 +267,7 @@ func (r *TenantRegistry) Update(t *Tenant) error {
|
||||
// List returns all tenants.
|
||||
func (r *TenantRegistry) List() ([]*Tenant, error) {
|
||||
rows, err := r.db.Query(`SELECT
|
||||
id, email, display_name, state,
|
||||
id, account_id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
@@ -187,7 +282,7 @@ func (r *TenantRegistry) List() ([]*Tenant, error) {
|
||||
// ListByState returns all tenants matching the given state.
|
||||
func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) {
|
||||
rows, err := r.db.Query(`SELECT
|
||||
id, email, display_name, state,
|
||||
id, account_id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
@@ -199,6 +294,21 @@ func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) {
|
||||
return scanTenants(rows)
|
||||
}
|
||||
|
||||
// ListByAccountID returns all tenants belonging to the given account ID.
|
||||
func (r *TenantRegistry) ListByAccountID(accountID string) ([]*Tenant, error) {
|
||||
rows, err := r.db.Query(`SELECT
|
||||
id, account_id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
FROM tenants WHERE account_id = ? ORDER BY created_at DESC`, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tenants by account id: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanTenants(rows)
|
||||
}
|
||||
|
||||
// CountByState returns a map of state -> count.
|
||||
func (r *TenantRegistry) CountByState() (map[TenantState]int, error) {
|
||||
rows, err := r.db.Query(`SELECT state, COUNT(*) FROM tenants GROUP BY state`)
|
||||
@@ -244,7 +354,7 @@ func scanTenant(s scanner) (*Tenant, error) {
|
||||
var healthOK int
|
||||
|
||||
err := s.Scan(
|
||||
&t.ID, &t.Email, &t.DisplayName, &state,
|
||||
&t.ID, &t.AccountID, &t.Email, &t.DisplayName, &state,
|
||||
&t.StripeCustomerID, &t.StripeSubscriptionID, &t.StripePriceID,
|
||||
&t.PlanVersion, &t.ContainerID, &t.CurrentImageDigest, &t.DesiredImageDigest,
|
||||
&createdAt, &updatedAt, &lastHealthCheck, &healthOK,
|
||||
@@ -279,6 +389,474 @@ func scanTenants(rows *sql.Rows) ([]*Tenant, error) {
|
||||
return tenants, rows.Err()
|
||||
}
|
||||
|
||||
// CreateAccount inserts a new account record.
|
||||
func (r *TenantRegistry) CreateAccount(a *Account) error {
|
||||
if a == nil {
|
||||
return fmt.Errorf("account is nil")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if a.CreatedAt.IsZero() {
|
||||
a.CreatedAt = now
|
||||
}
|
||||
a.UpdatedAt = now
|
||||
|
||||
kind := string(a.Kind)
|
||||
if kind == "" {
|
||||
kind = string(AccountKindIndividual)
|
||||
}
|
||||
|
||||
_, err := r.db.Exec(`
|
||||
INSERT INTO accounts (
|
||||
id, kind, display_name, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?)`,
|
||||
a.ID, kind, a.DisplayName, a.CreatedAt.Unix(), a.UpdatedAt.Unix(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create account: %w", err)
|
||||
}
|
||||
a.Kind = AccountKind(kind)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccount retrieves an account by ID.
|
||||
func (r *TenantRegistry) GetAccount(id string) (*Account, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
id, kind, display_name, created_at, updated_at
|
||||
FROM accounts WHERE id = ?`, id)
|
||||
return scanAccount(row)
|
||||
}
|
||||
|
||||
// UpdateAccount modifies an existing account record.
|
||||
func (r *TenantRegistry) UpdateAccount(a *Account) error {
|
||||
if a == nil {
|
||||
return fmt.Errorf("account is nil")
|
||||
}
|
||||
a.UpdatedAt = time.Now().UTC()
|
||||
|
||||
kind := string(a.Kind)
|
||||
if kind == "" {
|
||||
kind = string(AccountKindIndividual)
|
||||
}
|
||||
|
||||
res, err := r.db.Exec(`
|
||||
UPDATE accounts SET
|
||||
kind = ?, display_name = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
kind, a.DisplayName, a.UpdatedAt.Unix(),
|
||||
a.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("account %q not found", a.ID)
|
||||
}
|
||||
a.Kind = AccountKind(kind)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAccounts returns all accounts.
|
||||
func (r *TenantRegistry) ListAccounts() ([]*Account, error) {
|
||||
rows, err := r.db.Query(`SELECT
|
||||
id, kind, display_name, created_at, updated_at
|
||||
FROM accounts ORDER BY created_at DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanAccounts(rows)
|
||||
}
|
||||
|
||||
// CreateUser inserts a new user record.
|
||||
func (r *TenantRegistry) CreateUser(u *User) error {
|
||||
if u == nil {
|
||||
return fmt.Errorf("user is nil")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if u.CreatedAt.IsZero() {
|
||||
u.CreatedAt = now
|
||||
}
|
||||
|
||||
_, err := r.db.Exec(`
|
||||
INSERT INTO users (
|
||||
id, email, created_at, last_login_at
|
||||
) VALUES (?, ?, ?, ?)`,
|
||||
u.ID, u.Email, u.CreatedAt.Unix(), nullableTimeUnix(u.LastLoginAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUser retrieves a user by ID.
|
||||
func (r *TenantRegistry) GetUser(id string) (*User, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
id, email, created_at, last_login_at
|
||||
FROM users WHERE id = ?`, id)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves a user by email.
|
||||
func (r *TenantRegistry) GetUserByEmail(email string) (*User, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
id, email, created_at, last_login_at
|
||||
FROM users WHERE email = ?`, email)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// UpdateUserLastLogin sets last_login_at for the given user ID to the current time.
|
||||
func (r *TenantRegistry) UpdateUserLastLogin(id string) error {
|
||||
now := time.Now().UTC()
|
||||
res, err := r.db.Exec(`UPDATE users SET last_login_at = ? WHERE id = ?`, now.Unix(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update user last login: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("user %q not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateMembership inserts a new membership record.
|
||||
func (r *TenantRegistry) CreateMembership(m *AccountMembership) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("membership is nil")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if m.CreatedAt.IsZero() {
|
||||
m.CreatedAt = now
|
||||
}
|
||||
role := string(m.Role)
|
||||
if role == "" {
|
||||
role = string(MemberRoleTech)
|
||||
}
|
||||
|
||||
_, err := r.db.Exec(`
|
||||
INSERT INTO account_memberships (
|
||||
account_id, user_id, role, created_at
|
||||
) VALUES (?, ?, ?, ?)`,
|
||||
m.AccountID, m.UserID, role, m.CreatedAt.Unix(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create membership: %w", err)
|
||||
}
|
||||
m.Role = MemberRole(role)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMembership retrieves a membership record by account ID and user ID.
|
||||
func (r *TenantRegistry) GetMembership(accountID, userID string) (*AccountMembership, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
account_id, user_id, role, created_at
|
||||
FROM account_memberships
|
||||
WHERE account_id = ? AND user_id = ?`, accountID, userID)
|
||||
return scanMembership(row)
|
||||
}
|
||||
|
||||
// ListMembersByAccount returns all membership records for a given account ID.
|
||||
func (r *TenantRegistry) ListMembersByAccount(accountID string) ([]*AccountMembership, error) {
|
||||
rows, err := r.db.Query(`SELECT
|
||||
account_id, user_id, role, created_at
|
||||
FROM account_memberships
|
||||
WHERE account_id = ?
|
||||
ORDER BY created_at DESC`, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list members by account: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMemberships(rows)
|
||||
}
|
||||
|
||||
// ListAccountsByUser returns account IDs for all accounts the given user belongs to.
|
||||
func (r *TenantRegistry) ListAccountsByUser(userID string) ([]string, error) {
|
||||
rows, err := r.db.Query(`SELECT account_id FROM account_memberships WHERE user_id = ? ORDER BY created_at DESC`, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by user: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accountIDs []string
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, fmt.Errorf("scan account id: %w", err)
|
||||
}
|
||||
accountIDs = append(accountIDs, id)
|
||||
}
|
||||
return accountIDs, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateMembershipRole updates a membership role.
|
||||
func (r *TenantRegistry) UpdateMembershipRole(accountID, userID string, role MemberRole) error {
|
||||
res, err := r.db.Exec(`UPDATE account_memberships SET role = ? WHERE account_id = ? AND user_id = ?`, string(role), accountID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update membership role: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("membership (%q, %q) not found", accountID, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteMembership deletes a membership record.
|
||||
func (r *TenantRegistry) DeleteMembership(accountID, userID string) error {
|
||||
res, err := r.db.Exec(`DELETE FROM account_memberships WHERE account_id = ? AND user_id = ?`, accountID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete membership: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("membership (%q, %q) not found", accountID, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateStripeAccount inserts a new StripeAccount mapping row.
|
||||
func (r *TenantRegistry) CreateStripeAccount(sa *StripeAccount) error {
|
||||
if sa == nil {
|
||||
return fmt.Errorf("stripe account is nil")
|
||||
}
|
||||
sa.AccountID = strings.TrimSpace(sa.AccountID)
|
||||
sa.StripeCustomerID = strings.TrimSpace(sa.StripeCustomerID)
|
||||
sa.StripeSubscriptionID = strings.TrimSpace(sa.StripeSubscriptionID)
|
||||
sa.StripeSubItemWorkspacesID = strings.TrimSpace(sa.StripeSubItemWorkspacesID)
|
||||
sa.PlanVersion = strings.TrimSpace(sa.PlanVersion)
|
||||
sa.SubscriptionState = strings.TrimSpace(sa.SubscriptionState)
|
||||
|
||||
if sa.AccountID == "" {
|
||||
return fmt.Errorf("missing account id")
|
||||
}
|
||||
if sa.StripeCustomerID == "" {
|
||||
return fmt.Errorf("missing stripe customer id")
|
||||
}
|
||||
if sa.SubscriptionState == "" {
|
||||
sa.SubscriptionState = "trial"
|
||||
}
|
||||
if sa.UpdatedAt == 0 {
|
||||
sa.UpdatedAt = time.Now().UTC().Unix()
|
||||
}
|
||||
|
||||
_, err := r.db.Exec(`
|
||||
INSERT INTO stripe_accounts (
|
||||
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
|
||||
plan_version, subscription_state, trial_ends_at, current_period_end, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
sa.AccountID,
|
||||
sa.StripeCustomerID,
|
||||
nullableString(sa.StripeSubscriptionID),
|
||||
nullableString(sa.StripeSubItemWorkspacesID),
|
||||
sa.PlanVersion,
|
||||
sa.SubscriptionState,
|
||||
nullableInt64Ptr(sa.TrialEndsAt),
|
||||
nullableInt64Ptr(sa.CurrentPeriodEnd),
|
||||
sa.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create stripe account: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStripeAccount retrieves the StripeAccount row by account ID.
|
||||
func (r *TenantRegistry) GetStripeAccount(accountID string) (*StripeAccount, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
|
||||
plan_version, subscription_state, trial_ends_at, current_period_end, updated_at
|
||||
FROM stripe_accounts WHERE account_id = ?`, strings.TrimSpace(accountID))
|
||||
return scanStripeAccount(row)
|
||||
}
|
||||
|
||||
// GetStripeAccountByCustomerID retrieves the StripeAccount row by Stripe customer ID.
|
||||
func (r *TenantRegistry) GetStripeAccountByCustomerID(customerID string) (*StripeAccount, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
|
||||
plan_version, subscription_state, trial_ends_at, current_period_end, updated_at
|
||||
FROM stripe_accounts WHERE stripe_customer_id = ?`, strings.TrimSpace(customerID))
|
||||
return scanStripeAccount(row)
|
||||
}
|
||||
|
||||
// UpdateStripeAccount modifies an existing StripeAccount row.
|
||||
func (r *TenantRegistry) UpdateStripeAccount(sa *StripeAccount) error {
|
||||
if sa == nil {
|
||||
return fmt.Errorf("stripe account is nil")
|
||||
}
|
||||
sa.AccountID = strings.TrimSpace(sa.AccountID)
|
||||
sa.StripeCustomerID = strings.TrimSpace(sa.StripeCustomerID)
|
||||
sa.StripeSubscriptionID = strings.TrimSpace(sa.StripeSubscriptionID)
|
||||
sa.StripeSubItemWorkspacesID = strings.TrimSpace(sa.StripeSubItemWorkspacesID)
|
||||
sa.PlanVersion = strings.TrimSpace(sa.PlanVersion)
|
||||
sa.SubscriptionState = strings.TrimSpace(sa.SubscriptionState)
|
||||
|
||||
if sa.AccountID == "" {
|
||||
return fmt.Errorf("missing account id")
|
||||
}
|
||||
if sa.StripeCustomerID == "" {
|
||||
return fmt.Errorf("missing stripe customer id")
|
||||
}
|
||||
if sa.SubscriptionState == "" {
|
||||
sa.SubscriptionState = "trial"
|
||||
}
|
||||
|
||||
sa.UpdatedAt = time.Now().UTC().Unix()
|
||||
|
||||
res, err := r.db.Exec(`
|
||||
UPDATE stripe_accounts SET
|
||||
stripe_customer_id = ?, stripe_subscription_id = ?, stripe_sub_item_workspaces_id = ?,
|
||||
plan_version = ?, subscription_state = ?, trial_ends_at = ?, current_period_end = ?, updated_at = ?
|
||||
WHERE account_id = ?`,
|
||||
sa.StripeCustomerID,
|
||||
nullableString(sa.StripeSubscriptionID),
|
||||
nullableString(sa.StripeSubItemWorkspacesID),
|
||||
sa.PlanVersion,
|
||||
sa.SubscriptionState,
|
||||
nullableInt64Ptr(sa.TrialEndsAt),
|
||||
nullableInt64Ptr(sa.CurrentPeriodEnd),
|
||||
sa.UpdatedAt,
|
||||
sa.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update stripe account: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("stripe account %q not found", sa.AccountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordStripeEvent inserts a webhook event record and returns true if the
|
||||
// event was already recorded (duplicate Stripe delivery).
|
||||
func (r *TenantRegistry) RecordStripeEvent(eventID, eventType string) (alreadyProcessed bool, err error) {
|
||||
eventID = strings.TrimSpace(eventID)
|
||||
eventType = strings.TrimSpace(eventType)
|
||||
if eventID == "" {
|
||||
return false, fmt.Errorf("missing stripe event id")
|
||||
}
|
||||
if eventType == "" {
|
||||
return false, fmt.Errorf("missing stripe event type")
|
||||
}
|
||||
|
||||
// INSERT OR IGNORE avoids driver-specific error parsing for duplicates.
|
||||
res, err := r.db.Exec(`
|
||||
INSERT OR IGNORE INTO stripe_events (
|
||||
stripe_event_id, event_type, received_at, processed_at, processing_error
|
||||
) VALUES (?, ?, ?, NULL, NULL)`,
|
||||
eventID, eventType, time.Now().UTC().Unix(),
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("record stripe event: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// MarkStripeEventProcessed marks a previously recorded event as processed.
|
||||
// processingError is stored (nullable) for troubleshooting.
|
||||
func (r *TenantRegistry) MarkStripeEventProcessed(eventID string, processingError string) error {
|
||||
eventID = strings.TrimSpace(eventID)
|
||||
if eventID == "" {
|
||||
return fmt.Errorf("missing stripe event id")
|
||||
}
|
||||
processingError = strings.TrimSpace(processingError)
|
||||
|
||||
res, err := r.db.Exec(`
|
||||
UPDATE stripe_events SET
|
||||
processed_at = ?, processing_error = ?
|
||||
WHERE stripe_event_id = ?`,
|
||||
time.Now().UTC().Unix(),
|
||||
nullableString(processingError),
|
||||
eventID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mark stripe event processed: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("stripe event %q not found", eventID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanAccount(s scanner) (*Account, error) {
|
||||
var a Account
|
||||
var kind string
|
||||
var createdAt, updatedAt int64
|
||||
if err := s.Scan(&a.ID, &kind, &a.DisplayName, &createdAt, &updatedAt); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("scan account: %w", err)
|
||||
}
|
||||
a.Kind = AccountKind(kind)
|
||||
a.CreatedAt = time.Unix(createdAt, 0).UTC()
|
||||
a.UpdatedAt = time.Unix(updatedAt, 0).UTC()
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
func scanAccounts(rows *sql.Rows) ([]*Account, error) {
|
||||
var accounts []*Account
|
||||
for rows.Next() {
|
||||
a, err := scanAccount(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accounts = append(accounts, a)
|
||||
}
|
||||
return accounts, rows.Err()
|
||||
}
|
||||
|
||||
func scanUser(s scanner) (*User, error) {
|
||||
var u User
|
||||
var createdAt int64
|
||||
var lastLogin sql.NullInt64
|
||||
if err := s.Scan(&u.ID, &u.Email, &createdAt, &lastLogin); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("scan user: %w", err)
|
||||
}
|
||||
u.CreatedAt = time.Unix(createdAt, 0).UTC()
|
||||
if lastLogin.Valid {
|
||||
ts := time.Unix(lastLogin.Int64, 0).UTC()
|
||||
u.LastLoginAt = &ts
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func scanMembership(s scanner) (*AccountMembership, error) {
|
||||
var m AccountMembership
|
||||
var role string
|
||||
var createdAt int64
|
||||
if err := s.Scan(&m.AccountID, &m.UserID, &role, &createdAt); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("scan membership: %w", err)
|
||||
}
|
||||
m.Role = MemberRole(role)
|
||||
m.CreatedAt = time.Unix(createdAt, 0).UTC()
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func scanMemberships(rows *sql.Rows) ([]*AccountMembership, error) {
|
||||
var memberships []*AccountMembership
|
||||
for rows.Next() {
|
||||
m, err := scanMembership(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memberships = append(memberships, m)
|
||||
}
|
||||
return memberships, rows.Err()
|
||||
}
|
||||
|
||||
func nullableTimeUnix(t *time.Time) any {
|
||||
if t == nil {
|
||||
return nil
|
||||
@@ -286,9 +864,60 @@ func nullableTimeUnix(t *time.Time) any {
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func nullableInt64Ptr(v *int64) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
func nullableString(s string) any {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func scanStripeAccount(s scanner) (*StripeAccount, error) {
|
||||
var sa StripeAccount
|
||||
var subID, subItemID sql.NullString
|
||||
var trialEnds, periodEnd sql.NullInt64
|
||||
if err := s.Scan(
|
||||
&sa.AccountID,
|
||||
&sa.StripeCustomerID,
|
||||
&subID,
|
||||
&subItemID,
|
||||
&sa.PlanVersion,
|
||||
&sa.SubscriptionState,
|
||||
&trialEnds,
|
||||
&periodEnd,
|
||||
&sa.UpdatedAt,
|
||||
); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("scan stripe account: %w", err)
|
||||
}
|
||||
if subID.Valid {
|
||||
sa.StripeSubscriptionID = subID.String
|
||||
}
|
||||
if subItemID.Valid {
|
||||
sa.StripeSubItemWorkspacesID = subItemID.String
|
||||
}
|
||||
if trialEnds.Valid {
|
||||
v := trialEnds.Int64
|
||||
sa.TrialEndsAt = &v
|
||||
}
|
||||
if periodEnd.Valid {
|
||||
v := periodEnd.Int64
|
||||
sa.CurrentPeriodEnd = &v
|
||||
}
|
||||
return &sa, nil
|
||||
}
|
||||
|
||||
@@ -59,6 +59,32 @@ func TestGenerateTenantID_CrockfordCharset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAccountID(t *testing.T) {
|
||||
id, err := GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAccountID: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(id, "a_") {
|
||||
t.Errorf("expected prefix a_, got %q", id)
|
||||
}
|
||||
if len(id) != 12 { // "a_" + 10 chars
|
||||
t.Errorf("expected length 12, got %d (%q)", len(id), id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateUserID(t *testing.T) {
|
||||
id, err := GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateUserID: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(id, "u_") {
|
||||
t.Errorf("expected prefix u_, got %q", id)
|
||||
}
|
||||
if len(id) != 12 { // "u_" + 10 chars
|
||||
t.Errorf("expected length 12, got %d (%q)", len(id), id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCRUD(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
@@ -137,6 +163,226 @@ func TestCRUD(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountCRUD(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
accountID, err := GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a := &Account{
|
||||
ID: accountID,
|
||||
Kind: AccountKindMSP,
|
||||
DisplayName: "Test MSP",
|
||||
}
|
||||
|
||||
// Create
|
||||
if err := reg.CreateAccount(a); err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
if a.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should be set")
|
||||
}
|
||||
if a.UpdatedAt.IsZero() {
|
||||
t.Error("UpdatedAt should be set")
|
||||
}
|
||||
|
||||
// Get
|
||||
got, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccount: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("GetAccount returned nil")
|
||||
}
|
||||
if got.Kind != AccountKindMSP {
|
||||
t.Errorf("Kind = %q, want %q", got.Kind, AccountKindMSP)
|
||||
}
|
||||
if got.DisplayName != "Test MSP" {
|
||||
t.Errorf("DisplayName = %q, want %q", got.DisplayName, "Test MSP")
|
||||
}
|
||||
|
||||
// Update
|
||||
got.DisplayName = "Renamed MSP"
|
||||
if err := reg.UpdateAccount(got); err != nil {
|
||||
t.Fatalf("UpdateAccount: %v", err)
|
||||
}
|
||||
got2, err := reg.GetAccount(accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccount after update: %v", err)
|
||||
}
|
||||
if got2.DisplayName != "Renamed MSP" {
|
||||
t.Errorf("DisplayName after update = %q, want %q", got2.DisplayName, "Renamed MSP")
|
||||
}
|
||||
|
||||
// List
|
||||
accounts, err := reg.ListAccounts()
|
||||
if err != nil {
|
||||
t.Fatalf("ListAccounts: %v", err)
|
||||
}
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
if accounts[0].ID != accountID {
|
||||
t.Errorf("accounts[0].ID = %q, want %q", accounts[0].ID, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserCRUD(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
userID, err := GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
u := &User{
|
||||
ID: userID,
|
||||
Email: "user@example.com",
|
||||
}
|
||||
|
||||
// Create
|
||||
if err := reg.CreateUser(u); err != nil {
|
||||
t.Fatalf("CreateUser: %v", err)
|
||||
}
|
||||
if u.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should be set")
|
||||
}
|
||||
|
||||
// Get by ID
|
||||
got, err := reg.GetUser(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetUser: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("GetUser returned nil")
|
||||
}
|
||||
if got.Email != "user@example.com" {
|
||||
t.Errorf("Email = %q, want %q", got.Email, "user@example.com")
|
||||
}
|
||||
if got.LastLoginAt != nil {
|
||||
t.Errorf("LastLoginAt = %v, want nil", got.LastLoginAt)
|
||||
}
|
||||
|
||||
// Get by email
|
||||
got2, err := reg.GetUserByEmail("user@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserByEmail: %v", err)
|
||||
}
|
||||
if got2 == nil || got2.ID != userID {
|
||||
t.Fatalf("GetUserByEmail returned %+v, want id=%q", got2, userID)
|
||||
}
|
||||
|
||||
// Update last login
|
||||
before := time.Now().UTC()
|
||||
if err := reg.UpdateUserLastLogin(userID); err != nil {
|
||||
t.Fatalf("UpdateUserLastLogin: %v", err)
|
||||
}
|
||||
after := time.Now().UTC()
|
||||
|
||||
got3, err := reg.GetUser(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetUser after last login update: %v", err)
|
||||
}
|
||||
if got3.LastLoginAt == nil {
|
||||
t.Fatal("LastLoginAt should be set")
|
||||
}
|
||||
ll := *got3.LastLoginAt
|
||||
if ll.Before(before.Add(-2*time.Second)) || ll.After(after.Add(2*time.Second)) {
|
||||
t.Fatalf("LastLoginAt=%s out of expected range [%s, %s]", ll, before, after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMembershipCRUD(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
accountID, err := GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
userID, err := GenerateUserID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := reg.CreateAccount(&Account{ID: accountID, Kind: AccountKindMSP, DisplayName: "Account"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateUser(&User{ID: userID, Email: "member@example.com"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m := &AccountMembership{
|
||||
AccountID: accountID,
|
||||
UserID: userID,
|
||||
Role: MemberRoleOwner,
|
||||
}
|
||||
|
||||
// Create
|
||||
if err := reg.CreateMembership(m); err != nil {
|
||||
t.Fatalf("CreateMembership: %v", err)
|
||||
}
|
||||
if m.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should be set")
|
||||
}
|
||||
|
||||
// Get
|
||||
got, err := reg.GetMembership(accountID, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetMembership: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("GetMembership returned nil")
|
||||
}
|
||||
if got.Role != MemberRoleOwner {
|
||||
t.Errorf("Role = %q, want %q", got.Role, MemberRoleOwner)
|
||||
}
|
||||
|
||||
// List by account
|
||||
members, err := reg.ListMembersByAccount(accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListMembersByAccount: %v", err)
|
||||
}
|
||||
if len(members) != 1 {
|
||||
t.Fatalf("expected 1 member, got %d", len(members))
|
||||
}
|
||||
if members[0].UserID != userID {
|
||||
t.Errorf("members[0].UserID = %q, want %q", members[0].UserID, userID)
|
||||
}
|
||||
|
||||
// List accounts by user
|
||||
accounts, err := reg.ListAccountsByUser(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAccountsByUser: %v", err)
|
||||
}
|
||||
if len(accounts) != 1 || accounts[0] != accountID {
|
||||
t.Fatalf("accounts=%v, want [%q]", accounts, accountID)
|
||||
}
|
||||
|
||||
// Update role
|
||||
if err := reg.UpdateMembershipRole(accountID, userID, MemberRoleAdmin); err != nil {
|
||||
t.Fatalf("UpdateMembershipRole: %v", err)
|
||||
}
|
||||
got2, err := reg.GetMembership(accountID, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetMembership after role update: %v", err)
|
||||
}
|
||||
if got2.Role != MemberRoleAdmin {
|
||||
t.Errorf("Role after update = %q, want %q", got2.Role, MemberRoleAdmin)
|
||||
}
|
||||
|
||||
// Delete
|
||||
if err := reg.DeleteMembership(accountID, userID); err != nil {
|
||||
t.Fatalf("DeleteMembership: %v", err)
|
||||
}
|
||||
got3, err := reg.GetMembership(accountID, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetMembership after delete: %v", err)
|
||||
}
|
||||
if got3 != nil {
|
||||
t.Fatalf("expected nil membership after delete, got %+v", got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
@@ -199,6 +445,46 @@ func TestListByState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListByAccountID(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
accountID, err := GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(&Account{ID: accountID, Kind: AccountKindMSP, DisplayName: "Account"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := reg.Create(&Tenant{ID: "t-ACCNT0001", AccountID: accountID, State: TenantStateActive}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(&Tenant{ID: "t-ACCNT0002", AccountID: accountID, State: TenantStateActive}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(&Tenant{ID: "t-ACCNT0003", State: TenantStateActive}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tenants, err := reg.ListByAccountID(accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByAccountID: %v", err)
|
||||
}
|
||||
if len(tenants) != 2 {
|
||||
t.Fatalf("expected 2 tenants, got %d", len(tenants))
|
||||
}
|
||||
seen := make(map[string]bool)
|
||||
for _, tnt := range tenants {
|
||||
seen[tnt.ID] = true
|
||||
if tnt.AccountID != accountID {
|
||||
t.Errorf("tenant %s AccountID=%q, want %q", tnt.ID, tnt.AccountID, accountID)
|
||||
}
|
||||
}
|
||||
if !seen["t-ACCNT0001"] || !seen["t-ACCNT0002"] {
|
||||
t.Fatalf("expected tenants t-ACCNT0001 and t-ACCNT0002, got %+v", tenants)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountByState(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
@@ -280,3 +566,107 @@ func TestNewTenantRegistry_InvalidDir(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripeAccountCRUD(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
accountID, err := GenerateAccountID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.CreateAccount(&Account{
|
||||
ID: accountID,
|
||||
Kind: AccountKindMSP,
|
||||
DisplayName: "Test MSP",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
trialEnds := time.Now().UTC().Add(7 * 24 * time.Hour).Unix()
|
||||
periodEnd := time.Now().UTC().Add(30 * 24 * time.Hour).Unix()
|
||||
|
||||
sa := &StripeAccount{
|
||||
AccountID: accountID,
|
||||
StripeCustomerID: "cus_test_123",
|
||||
StripeSubscriptionID: "sub_test_123",
|
||||
StripeSubItemWorkspacesID: "si_workspaces_123",
|
||||
PlanVersion: "msp_hosted_v1",
|
||||
SubscriptionState: "trial",
|
||||
TrialEndsAt: &trialEnds,
|
||||
CurrentPeriodEnd: &periodEnd,
|
||||
}
|
||||
|
||||
// Create
|
||||
if err := reg.CreateStripeAccount(sa); err != nil {
|
||||
t.Fatalf("CreateStripeAccount: %v", err)
|
||||
}
|
||||
|
||||
// Get by account id
|
||||
got, err := reg.GetStripeAccount(accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetStripeAccount: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("GetStripeAccount returned nil")
|
||||
}
|
||||
if got.StripeCustomerID != "cus_test_123" {
|
||||
t.Errorf("StripeCustomerID = %q, want %q", got.StripeCustomerID, "cus_test_123")
|
||||
}
|
||||
if got.StripeSubscriptionID != "sub_test_123" {
|
||||
t.Errorf("StripeSubscriptionID = %q, want %q", got.StripeSubscriptionID, "sub_test_123")
|
||||
}
|
||||
|
||||
// Get by customer id
|
||||
got2, err := reg.GetStripeAccountByCustomerID("cus_test_123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetStripeAccountByCustomerID: %v", err)
|
||||
}
|
||||
if got2 == nil || got2.AccountID != accountID {
|
||||
t.Fatalf("expected accountID %q, got %#v", accountID, got2)
|
||||
}
|
||||
|
||||
// Update
|
||||
got2.SubscriptionState = "active"
|
||||
got2.PlanVersion = "msp_hosted_v2"
|
||||
got2.StripeSubscriptionID = "sub_test_456"
|
||||
if err := reg.UpdateStripeAccount(got2); err != nil {
|
||||
t.Fatalf("UpdateStripeAccount: %v", err)
|
||||
}
|
||||
|
||||
got3, err := reg.GetStripeAccount(accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetStripeAccount after update: %v", err)
|
||||
}
|
||||
if got3.SubscriptionState != "active" {
|
||||
t.Errorf("SubscriptionState = %q, want %q", got3.SubscriptionState, "active")
|
||||
}
|
||||
if got3.PlanVersion != "msp_hosted_v2" {
|
||||
t.Errorf("PlanVersion = %q, want %q", got3.PlanVersion, "msp_hosted_v2")
|
||||
}
|
||||
if got3.StripeSubscriptionID != "sub_test_456" {
|
||||
t.Errorf("StripeSubscriptionID = %q, want %q", got3.StripeSubscriptionID, "sub_test_456")
|
||||
}
|
||||
if got3.UpdatedAt == 0 {
|
||||
t.Error("UpdatedAt should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripeEventIdempotency(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
already, err := reg.RecordStripeEvent("evt_test_123", "customer.subscription.updated")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordStripeEvent: %v", err)
|
||||
}
|
||||
if already {
|
||||
t.Fatalf("expected alreadyProcessed=false on first insert")
|
||||
}
|
||||
|
||||
already2, err := reg.RecordStripeEvent("evt_test_123", "customer.subscription.updated")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordStripeEvent duplicate: %v", err)
|
||||
}
|
||||
if !already2 {
|
||||
t.Fatalf("expected alreadyProcessed=true on duplicate insert")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
|
||||
cpauth "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/auth"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/docker"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/handoff"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/portal"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
cpstripe "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/stripe"
|
||||
)
|
||||
@@ -99,4 +101,12 @@ func RegisterRoutes(mux *http.ServeMux, deps *Deps) {
|
||||
|
||||
mux.Handle("/api/accounts/{account_id}/tenants", admin.AdminKeyMiddleware(deps.Config.AdminKey, tenantsCollection))
|
||||
mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}", admin.AdminKeyMiddleware(deps.Config.AdminKey, tenant))
|
||||
|
||||
// Tenant switching handoff (admin-key authenticated for now; session auth in M-4)
|
||||
handoffHandler := handoff.HandleHandoff(deps.Registry, deps.Config.TenantsDir())
|
||||
mux.Handle("/api/accounts/{account_id}/tenants/{tenant_id}/handoff", admin.AdminKeyMiddleware(deps.Config.AdminKey, handoffHandler))
|
||||
|
||||
// MSP portal API (admin-key authenticated for now; session auth in M-4)
|
||||
mux.Handle("/api/portal/dashboard", admin.AdminKeyMiddleware(deps.Config.AdminKey, portal.HandlePortalDashboard(deps.Registry)))
|
||||
mux.Handle("/api/portal/workspaces/{tenant_id}", admin.AdminKeyMiddleware(deps.Config.AdminKey, portal.HandlePortalWorkspaceDetail(deps.Registry)))
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
"github.com/rs/zerolog/log"
|
||||
stripelib "github.com/stripe/stripe-go/v82"
|
||||
"github.com/stripe/stripe-go/v82/webhook"
|
||||
@@ -98,14 +99,14 @@ func (h *WebhookHandler) handleEvent(r *http.Request, event *stripelib.Event) er
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
return fmt.Errorf("decode subscription: %w", err)
|
||||
}
|
||||
return h.provisioner.HandleSubscriptionUpdated(r.Context(), sub)
|
||||
return h.routeSubscriptionUpdated(r, sub)
|
||||
|
||||
case "customer.subscription.deleted":
|
||||
var sub Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
return fmt.Errorf("decode subscription: %w", err)
|
||||
}
|
||||
return h.provisioner.HandleSubscriptionDeleted(r.Context(), sub)
|
||||
return h.routeSubscriptionDeleted(r, sub)
|
||||
|
||||
default:
|
||||
log.Info().
|
||||
@@ -116,6 +117,46 @@ func (h *WebhookHandler) handleEvent(r *http.Request, event *stripelib.Event) er
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WebhookHandler) routeSubscriptionUpdated(r *http.Request, sub Subscription) error {
|
||||
customerID := strings.TrimSpace(sub.Customer)
|
||||
if customerID != "" {
|
||||
sa, err := h.provisioner.registry.GetStripeAccountByCustomerID(customerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup stripe account by customer: %w", err)
|
||||
}
|
||||
if sa != nil {
|
||||
acct, err := h.provisioner.registry.GetAccount(sa.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup account: %w", err)
|
||||
}
|
||||
if acct != nil && acct.Kind == registry.AccountKindMSP {
|
||||
return h.provisioner.HandleMSPSubscriptionUpdated(r.Context(), sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
return h.provisioner.HandleSubscriptionUpdated(r.Context(), sub)
|
||||
}
|
||||
|
||||
func (h *WebhookHandler) routeSubscriptionDeleted(r *http.Request, sub Subscription) error {
|
||||
customerID := strings.TrimSpace(sub.Customer)
|
||||
if customerID != "" {
|
||||
sa, err := h.provisioner.registry.GetStripeAccountByCustomerID(customerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup stripe account by customer: %w", err)
|
||||
}
|
||||
if sa != nil {
|
||||
acct, err := h.provisioner.registry.GetAccount(sa.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup account: %w", err)
|
||||
}
|
||||
if acct != nil && acct.Kind == registry.AccountKindMSP {
|
||||
return h.provisioner.HandleMSPSubscriptionDeleted(r.Context(), sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
return h.provisioner.HandleSubscriptionDeleted(r.Context(), sub)
|
||||
}
|
||||
|
||||
// CheckoutSession is a minimal representation of a Stripe checkout.session event.
|
||||
type CheckoutSession struct {
|
||||
ID string `json:"id"`
|
||||
|
||||
Reference in New Issue
Block a user