mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-18 00:17:39 +01:00
feat: add cmd/pulse-control-plane binary for Cloud tenant lifecycle
Implements the control plane for Pulse Cloud's container-per-tenant architecture. Manages tenant provisioning via Stripe webhooks, billing state via FileBillingStore, Docker container orchestration with Traefik labels, admin API with key auth, and background health monitoring.
This commit is contained in:
6
Makefile
6
Makefile
@@ -1,6 +1,6 @@
|
||||
# Pulse Makefile for development
|
||||
|
||||
.PHONY: build run dev frontend backend all clean distclean dev-hot lint lint-backend lint-frontend format format-backend format-frontend build-agents
|
||||
.PHONY: build run dev frontend backend all clean distclean dev-hot lint lint-backend lint-frontend format format-backend format-frontend build-agents control-plane
|
||||
|
||||
FRONTEND_DIR := frontend-modern
|
||||
FRONTEND_DIST := $(FRONTEND_DIR)/dist
|
||||
@@ -69,6 +69,10 @@ format-backend:
|
||||
format-frontend:
|
||||
npm --prefix $(FRONTEND_DIR) run format
|
||||
|
||||
# Build control plane binary
|
||||
control-plane:
|
||||
go build -o pulse-control-plane ./cmd/pulse-control-plane
|
||||
|
||||
test:
|
||||
@./scripts/ensure_test_assets.sh
|
||||
@echo "Running backend tests (excluding tmp tooling)..."
|
||||
|
||||
49
cmd/pulse-control-plane/main.go
Normal file
49
cmd/pulse-control-plane/main.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
Version = "dev"
|
||||
BuildTime = "unknown"
|
||||
GitCommit = "unknown"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "pulse-control-plane",
|
||||
Short: "Pulse Cloud Control Plane",
|
||||
Long: `Control plane for Pulse Cloud — manages tenant lifecycle, containers, and billing.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cloudcp.Run(context.Background(), Version)
|
||||
},
|
||||
}
|
||||
|
||||
var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print version information",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Printf("Pulse Control Plane %s\n", Version)
|
||||
if BuildTime != "unknown" {
|
||||
fmt.Printf("Built: %s\n", BuildTime)
|
||||
}
|
||||
if GitCommit != "unknown" {
|
||||
fmt.Printf("Commit: %s\n", GitCommit)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
70
internal/cloudcp/admin/handlers.go
Normal file
70
internal/cloudcp/admin/handlers.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
// HandleListTenants returns an authenticated handler that lists all tenants.
|
||||
func HandleListTenants(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Optional state filter
|
||||
stateFilter := strings.TrimSpace(r.URL.Query().Get("state"))
|
||||
|
||||
var tenants []*registry.Tenant
|
||||
var err error
|
||||
|
||||
if stateFilter != "" {
|
||||
tenants, err = reg.ListByState(registry.TenantState(stateFilter))
|
||||
} else {
|
||||
tenants, err = reg.List()
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if tenants == nil {
|
||||
tenants = []*registry.Tenant{}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"tenants": tenants,
|
||||
"count": len(tenants),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// AdminKeyMiddleware returns middleware that requires a valid admin API key.
|
||||
func AdminKeyMiddleware(adminKey string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := strings.TrimSpace(r.Header.Get("X-Admin-Key"))
|
||||
if key == "" {
|
||||
// Also check Authorization: Bearer <key>
|
||||
auth := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
key = strings.TrimSpace(strings.TrimPrefix(auth, "Bearer "))
|
||||
}
|
||||
}
|
||||
|
||||
if key == "" || key != adminKey {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
64
internal/cloudcp/admin/status.go
Normal file
64
internal/cloudcp/admin/status.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
// HandleHealthz returns 200 "ok" unconditionally (liveness probe).
|
||||
func HandleHealthz(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
// HandleReadyz returns a handler that checks database connectivity (readiness probe).
|
||||
func HandleReadyz(reg *registry.TenantRegistry) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := reg.Ping(); err != nil {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = w.Write([]byte("not ready"))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ready"))
|
||||
}
|
||||
}
|
||||
|
||||
// HandleStatus returns a handler that reports aggregate tenant status.
|
||||
func HandleStatus(reg *registry.TenantRegistry, version string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
counts, err := reg.CountByState()
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
total := 0
|
||||
for _, c := range counts {
|
||||
total += c
|
||||
}
|
||||
|
||||
healthy, unhealthy, err := reg.HealthSummary()
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"version": version,
|
||||
"total_tenants": total,
|
||||
"healthy": healthy,
|
||||
"unhealthy": unhealthy,
|
||||
"by_state": counts,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}
|
||||
130
internal/cloudcp/admin/status_test.go
Normal file
130
internal/cloudcp/admin/status_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
)
|
||||
|
||||
func newTestRegistry(t *testing.T) *registry.TenantRegistry {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
reg, err := registry.NewTenantRegistry(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTenantRegistry: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = reg.Close() })
|
||||
return reg
|
||||
}
|
||||
|
||||
func TestHandleHealthz(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
HandleHealthz(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
if rec.Body.String() != "ok" {
|
||||
t.Errorf("body = %q, want %q", rec.Body.String(), "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReadyz(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
handler := HandleReadyz(reg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
if rec.Body.String() != "ready" {
|
||||
t.Errorf("body = %q, want %q", rec.Body.String(), "ready")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStatus(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
// Seed data
|
||||
if err := reg.Create(®istry.Tenant{
|
||||
ID: "t-STATUS001", State: registry.TenantStateActive, HealthCheckOK: true,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
handler := HandleStatus(reg, "test-version")
|
||||
req := httptest.NewRequest(http.MethodGet, "/status", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if resp["version"] != "test-version" {
|
||||
t.Errorf("version = %v, want test-version", resp["version"])
|
||||
}
|
||||
if resp["total_tenants"] != float64(1) {
|
||||
t.Errorf("total_tenants = %v, want 1", resp["total_tenants"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminKeyMiddleware(t *testing.T) {
|
||||
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("authorized"))
|
||||
})
|
||||
|
||||
handler := AdminKeyMiddleware("secret-key", inner)
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/tenants", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong key", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/tenants", nil)
|
||||
req.Header.Set("X-Admin-Key", "wrong")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("correct X-Admin-Key", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/tenants", nil)
|
||||
req.Header.Set("X-Admin-Key", "secret-key")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("correct Bearer token", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/tenants", nil)
|
||||
req.Header.Set("Authorization", "Bearer secret-key")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
104
internal/cloudcp/config.go
Normal file
104
internal/cloudcp/config.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package cloudcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// CPConfig holds all configuration for the control plane.
|
||||
type CPConfig struct {
|
||||
DataDir string
|
||||
BindAddress string
|
||||
Port int
|
||||
AdminKey string
|
||||
BaseURL string
|
||||
PulseImage string
|
||||
DockerNetwork string
|
||||
TenantMemoryLimit int64 // bytes
|
||||
TenantCPUShares int64
|
||||
StripeWebhookSecret string
|
||||
StripeAPIKey string
|
||||
}
|
||||
|
||||
// TenantsDir returns the directory where per-tenant data is stored.
|
||||
func (c *CPConfig) TenantsDir() string {
|
||||
return filepath.Join(c.DataDir, "tenants")
|
||||
}
|
||||
|
||||
// ControlPlaneDir returns the directory for control plane's own data (registry DB, etc).
|
||||
func (c *CPConfig) ControlPlaneDir() string {
|
||||
return filepath.Join(c.DataDir, "control-plane")
|
||||
}
|
||||
|
||||
// LoadConfig loads control plane configuration from environment variables.
|
||||
// A .env file is loaded if present but not required.
|
||||
func LoadConfig() (*CPConfig, error) {
|
||||
// Best-effort .env loading (not required)
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := &CPConfig{
|
||||
DataDir: envOrDefault("CP_DATA_DIR", "/data"),
|
||||
BindAddress: envOrDefault("CP_BIND_ADDRESS", "0.0.0.0"),
|
||||
Port: envOrDefaultInt("CP_PORT", 8443),
|
||||
AdminKey: strings.TrimSpace(os.Getenv("CP_ADMIN_KEY")),
|
||||
BaseURL: strings.TrimSpace(os.Getenv("CP_BASE_URL")),
|
||||
PulseImage: envOrDefault("CP_PULSE_IMAGE", "ghcr.io/rcourtman/pulse:latest"),
|
||||
DockerNetwork: envOrDefault("CP_DOCKER_NETWORK", "pulse-cloud"),
|
||||
TenantMemoryLimit: envOrDefaultInt64("CP_TENANT_MEMORY_LIMIT", 512*1024*1024), // 512 MiB
|
||||
TenantCPUShares: envOrDefaultInt64("CP_TENANT_CPU_SHARES", 256),
|
||||
StripeWebhookSecret: strings.TrimSpace(os.Getenv("STRIPE_WEBHOOK_SECRET")),
|
||||
StripeAPIKey: strings.TrimSpace(os.Getenv("STRIPE_API_KEY")),
|
||||
}
|
||||
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *CPConfig) validate() error {
|
||||
var missing []string
|
||||
if c.AdminKey == "" {
|
||||
missing = append(missing, "CP_ADMIN_KEY")
|
||||
}
|
||||
if c.BaseURL == "" {
|
||||
missing = append(missing, "CP_BASE_URL")
|
||||
}
|
||||
if c.StripeWebhookSecret == "" {
|
||||
missing = append(missing, "STRIPE_WEBHOOK_SECRET")
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("missing required environment variables: %s", strings.Join(missing, ", "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func envOrDefault(key, fallback string) string {
|
||||
if v := strings.TrimSpace(os.Getenv(key)); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func envOrDefaultInt(key string, fallback int) int {
|
||||
if v := strings.TrimSpace(os.Getenv(key)); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func envOrDefaultInt64(key string, fallback int64) int64 {
|
||||
if v := strings.TrimSpace(os.Getenv(key)); v != "" {
|
||||
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
95
internal/cloudcp/config_test.go
Normal file
95
internal/cloudcp/config_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package cloudcp
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfig_MissingRequired(t *testing.T) {
|
||||
// Clear relevant env vars
|
||||
for _, key := range []string{
|
||||
"CP_ADMIN_KEY", "CP_BASE_URL", "STRIPE_WEBHOOK_SECRET",
|
||||
"CP_DATA_DIR", "CP_BIND_ADDRESS", "CP_PORT",
|
||||
} {
|
||||
t.Setenv(key, "")
|
||||
}
|
||||
|
||||
_, err := LoadConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing required vars")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_AllRequired(t *testing.T) {
|
||||
t.Setenv("CP_ADMIN_KEY", "test-key")
|
||||
t.Setenv("CP_BASE_URL", "https://cloud.example.com")
|
||||
t.Setenv("STRIPE_WEBHOOK_SECRET", "whsec_test")
|
||||
|
||||
// Clear optional vars to use defaults
|
||||
for _, key := range []string{
|
||||
"CP_DATA_DIR", "CP_BIND_ADDRESS", "CP_PORT",
|
||||
"CP_PULSE_IMAGE", "CP_DOCKER_NETWORK",
|
||||
"CP_TENANT_MEMORY_LIMIT", "CP_TENANT_CPU_SHARES",
|
||||
"STRIPE_API_KEY",
|
||||
} {
|
||||
os.Unsetenv(key)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
if cfg.AdminKey != "test-key" {
|
||||
t.Errorf("AdminKey = %q, want %q", cfg.AdminKey, "test-key")
|
||||
}
|
||||
if cfg.BaseURL != "https://cloud.example.com" {
|
||||
t.Errorf("BaseURL = %q", cfg.BaseURL)
|
||||
}
|
||||
if cfg.Port != 8443 {
|
||||
t.Errorf("Port = %d, want 8443", cfg.Port)
|
||||
}
|
||||
if cfg.DataDir != "/data" {
|
||||
t.Errorf("DataDir = %q, want /data", cfg.DataDir)
|
||||
}
|
||||
if cfg.BindAddress != "0.0.0.0" {
|
||||
t.Errorf("BindAddress = %q, want 0.0.0.0", cfg.BindAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_CustomValues(t *testing.T) {
|
||||
t.Setenv("CP_ADMIN_KEY", "key")
|
||||
t.Setenv("CP_BASE_URL", "https://test.example.com")
|
||||
t.Setenv("STRIPE_WEBHOOK_SECRET", "whsec_x")
|
||||
t.Setenv("CP_PORT", "9000")
|
||||
t.Setenv("CP_DATA_DIR", "/custom/data")
|
||||
t.Setenv("CP_BIND_ADDRESS", "127.0.0.1")
|
||||
|
||||
cfg, err := LoadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
if cfg.Port != 9000 {
|
||||
t.Errorf("Port = %d, want 9000", cfg.Port)
|
||||
}
|
||||
if cfg.DataDir != "/custom/data" {
|
||||
t.Errorf("DataDir = %q", cfg.DataDir)
|
||||
}
|
||||
if cfg.BindAddress != "127.0.0.1" {
|
||||
t.Errorf("BindAddress = %q", cfg.BindAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantsDir(t *testing.T) {
|
||||
cfg := &CPConfig{DataDir: "/data"}
|
||||
if got := cfg.TenantsDir(); got != "/data/tenants" {
|
||||
t.Errorf("TenantsDir = %q, want /data/tenants", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlPlaneDir(t *testing.T) {
|
||||
cfg := &CPConfig{DataDir: "/data"}
|
||||
if got := cfg.ControlPlaneDir(); got != "/data/control-plane" {
|
||||
t.Errorf("ControlPlaneDir = %q, want /data/control-plane", got)
|
||||
}
|
||||
}
|
||||
25
internal/cloudcp/docker/labels.go
Normal file
25
internal/cloudcp/docker/labels.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package docker
|
||||
|
||||
import "fmt"
|
||||
|
||||
// TraefikLabels generates Docker labels for Traefik reverse-proxy routing.
|
||||
// Each tenant gets a subdomain: <tenantID>.cloud.pulserelay.pro
|
||||
func TraefikLabels(tenantID, baseDomain string, containerPort int) map[string]string {
|
||||
svc := "pulse-" + tenantID
|
||||
host := fmt.Sprintf("%s.%s", tenantID, baseDomain)
|
||||
|
||||
return map[string]string{
|
||||
"traefik.enable": "true",
|
||||
|
||||
// HTTP router
|
||||
fmt.Sprintf("traefik.http.routers.%s.rule", svc): fmt.Sprintf("Host(`%s`)", host),
|
||||
fmt.Sprintf("traefik.http.routers.%s.entrypoints", svc): "websecure",
|
||||
fmt.Sprintf("traefik.http.routers.%s.tls.certresolver", svc): "le",
|
||||
|
||||
// Service
|
||||
fmt.Sprintf("traefik.http.services.%s.loadbalancer.server.port", svc): fmt.Sprintf("%d", containerPort),
|
||||
|
||||
// Metadata
|
||||
"pulse.tenant.id": tenantID,
|
||||
}
|
||||
}
|
||||
157
internal/cloudcp/docker/manager.go
Normal file
157
internal/cloudcp/docker/manager.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package docker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/mount"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ManagerConfig holds Docker manager settings.
|
||||
type ManagerConfig struct {
|
||||
Image string
|
||||
Network string
|
||||
BaseDomain string
|
||||
MemoryLimit int64 // bytes
|
||||
CPUShares int64
|
||||
ContainerPort int // port inside the container (default 7655)
|
||||
}
|
||||
|
||||
// Manager orchestrates Docker containers for tenant lifecycle.
|
||||
type Manager struct {
|
||||
cli *client.Client
|
||||
cfg ManagerConfig
|
||||
}
|
||||
|
||||
// NewManager creates a Docker manager connected to the local daemon.
|
||||
func NewManager(cfg ManagerConfig) (*Manager, error) {
|
||||
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create docker client: %w", err)
|
||||
}
|
||||
if cfg.ContainerPort == 0 {
|
||||
cfg.ContainerPort = 7655
|
||||
}
|
||||
return &Manager{cli: cli, cfg: cfg}, nil
|
||||
}
|
||||
|
||||
// Close closes the Docker client.
|
||||
func (m *Manager) Close() error {
|
||||
if m.cli != nil {
|
||||
return m.cli.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateAndStart creates and starts a tenant container.
|
||||
// tenantDataDir is the host path that gets bind-mounted to /etc/pulse in the container.
|
||||
func (m *Manager) CreateAndStart(ctx context.Context, tenantID, tenantDataDir string) (containerID string, err error) {
|
||||
labels := TraefikLabels(tenantID, m.cfg.BaseDomain, m.cfg.ContainerPort)
|
||||
labels["pulse.managed"] = "true"
|
||||
|
||||
containerName := "pulse-" + tenantID
|
||||
|
||||
resp, err := m.cli.ContainerCreate(ctx,
|
||||
&container.Config{
|
||||
Image: m.cfg.Image,
|
||||
Labels: labels,
|
||||
Env: []string{
|
||||
"PULSE_DATA_DIR=/etc/pulse",
|
||||
"PULSE_HOSTED_MODE=true",
|
||||
},
|
||||
},
|
||||
&container.HostConfig{
|
||||
RestartPolicy: container.RestartPolicy{Name: "unless-stopped"},
|
||||
Resources: container.Resources{
|
||||
Memory: m.cfg.MemoryLimit,
|
||||
CPUShares: m.cfg.CPUShares,
|
||||
},
|
||||
Mounts: []mount.Mount{
|
||||
{
|
||||
Type: mount.TypeBind,
|
||||
Source: tenantDataDir,
|
||||
Target: "/etc/pulse",
|
||||
},
|
||||
},
|
||||
},
|
||||
&network.NetworkingConfig{
|
||||
EndpointsConfig: map[string]*network.EndpointSettings{
|
||||
m.cfg.Network: {},
|
||||
},
|
||||
},
|
||||
nil, // platform
|
||||
containerName,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create container for %s: %w", tenantID, err)
|
||||
}
|
||||
|
||||
if err := m.cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
|
||||
return resp.ID, fmt.Errorf("start container for %s: %w", tenantID, err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("tenant_id", tenantID).
|
||||
Str("container_id", resp.ID[:12]).
|
||||
Str("container_name", containerName).
|
||||
Msg("Tenant container started")
|
||||
|
||||
return resp.ID, nil
|
||||
}
|
||||
|
||||
// Stop stops a tenant container gracefully.
|
||||
func (m *Manager) Stop(ctx context.Context, containerID string) error {
|
||||
timeout := 30
|
||||
return m.cli.ContainerStop(ctx, containerID, container.StopOptions{Timeout: &timeout})
|
||||
}
|
||||
|
||||
// Remove removes a stopped tenant container.
|
||||
func (m *Manager) Remove(ctx context.Context, containerID string) error {
|
||||
return m.cli.ContainerRemove(ctx, containerID, container.RemoveOptions{
|
||||
Force: true,
|
||||
})
|
||||
}
|
||||
|
||||
// StopAndRemove stops then removes a tenant container.
|
||||
func (m *Manager) StopAndRemove(ctx context.Context, containerID string) error {
|
||||
if err := m.Stop(ctx, containerID); err != nil {
|
||||
log.Warn().Err(err).Str("container_id", containerID).Msg("Failed to stop container, forcing remove")
|
||||
}
|
||||
return m.Remove(ctx, containerID)
|
||||
}
|
||||
|
||||
// HealthCheck performs an HTTP health check against a running container.
|
||||
// It connects to the container's published port via the Docker network.
|
||||
func (m *Manager) HealthCheck(ctx context.Context, containerID string) (bool, error) {
|
||||
inspect, err := m.cli.ContainerInspect(ctx, containerID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("inspect container: %w", err)
|
||||
}
|
||||
|
||||
if !inspect.State.Running {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Find the container's IP on our network
|
||||
netSettings, ok := inspect.NetworkSettings.Networks[m.cfg.Network]
|
||||
if !ok || netSettings.IPAddress == "" {
|
||||
return false, fmt.Errorf("container not connected to network %s", m.cfg.Network)
|
||||
}
|
||||
|
||||
healthURL := fmt.Sprintf("http://%s:%d/api/health", netSettings.IPAddress, m.cfg.ContainerPort)
|
||||
httpClient := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := httpClient.Get(healthURL)
|
||||
if err != nil {
|
||||
return false, nil // unreachable, not an error condition
|
||||
}
|
||||
defer func() { _, _ = io.Copy(io.Discard, resp.Body); resp.Body.Close() }()
|
||||
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
107
internal/cloudcp/health/monitor.go
Normal file
107
internal/cloudcp/health/monitor.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/docker"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MonitorConfig holds health monitor settings.
|
||||
type MonitorConfig struct {
|
||||
Interval time.Duration // how often to check (default 60s)
|
||||
RestartOnFail bool // restart unhealthy containers
|
||||
FailThreshold int // consecutive failures before restart (default 3)
|
||||
}
|
||||
|
||||
// Monitor periodically health-checks active tenant containers and optionally
|
||||
// restarts unhealthy ones.
|
||||
type Monitor struct {
|
||||
registry *registry.TenantRegistry
|
||||
docker *docker.Manager
|
||||
cfg MonitorConfig
|
||||
}
|
||||
|
||||
// NewMonitor creates a health monitor.
|
||||
func NewMonitor(reg *registry.TenantRegistry, mgr *docker.Manager, cfg MonitorConfig) *Monitor {
|
||||
if cfg.Interval == 0 {
|
||||
cfg.Interval = 60 * time.Second
|
||||
}
|
||||
if cfg.FailThreshold == 0 {
|
||||
cfg.FailThreshold = 3
|
||||
}
|
||||
return &Monitor{
|
||||
registry: reg,
|
||||
docker: mgr,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the health check loop. It blocks until ctx is cancelled.
|
||||
func (m *Monitor) Run(ctx context.Context) {
|
||||
log.Info().
|
||||
Dur("interval", m.cfg.Interval).
|
||||
Bool("restart_on_fail", m.cfg.RestartOnFail).
|
||||
Msg("Health monitor started")
|
||||
|
||||
ticker := time.NewTicker(m.cfg.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info().Msg("Health monitor stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.checkAll(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Monitor) checkAll(ctx context.Context) {
|
||||
tenants, err := m.registry.ListByState(registry.TenantStateActive)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Health monitor: failed to list active tenants")
|
||||
return
|
||||
}
|
||||
|
||||
for _, tenant := range tenants {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if tenant.ContainerID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
healthy, err := m.docker.HealthCheck(ctx, tenant.ContainerID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Str("tenant_id", tenant.ID).
|
||||
Str("container_id", tenant.ContainerID).
|
||||
Msg("Health check error")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
tenant.LastHealthCheck = &now
|
||||
tenant.HealthCheckOK = healthy
|
||||
|
||||
if err := m.registry.Update(tenant); err != nil {
|
||||
log.Error().Err(err).Str("tenant_id", tenant.ID).Msg("Failed to update health status")
|
||||
continue
|
||||
}
|
||||
|
||||
if !healthy && m.cfg.RestartOnFail {
|
||||
log.Warn().
|
||||
Str("tenant_id", tenant.ID).
|
||||
Str("container_id", tenant.ContainerID).
|
||||
Msg("Container unhealthy, attempting restart")
|
||||
|
||||
if err := m.docker.Stop(ctx, tenant.ContainerID); err != nil {
|
||||
log.Error().Err(err).Str("tenant_id", tenant.ID).Msg("Failed to stop unhealthy container")
|
||||
}
|
||||
// Docker restart policy (unless-stopped) will restart the container
|
||||
}
|
||||
}
|
||||
}
|
||||
56
internal/cloudcp/registry/models.go
Normal file
56
internal/cloudcp/registry/models.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TenantState represents the lifecycle state of a tenant.
|
||||
type TenantState string
|
||||
|
||||
const (
|
||||
TenantStateProvisioning TenantState = "provisioning"
|
||||
TenantStateActive TenantState = "active"
|
||||
TenantStateSuspended TenantState = "suspended"
|
||||
TenantStateCanceled TenantState = "canceled"
|
||||
TenantStateDeleted TenantState = "deleted"
|
||||
)
|
||||
|
||||
// Tenant represents a Cloud tenant record in the registry.
|
||||
type Tenant struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"display_name"`
|
||||
State TenantState `json:"state"`
|
||||
StripeCustomerID string `json:"stripe_customer_id"`
|
||||
StripeSubscriptionID string `json:"stripe_subscription_id"`
|
||||
StripePriceID string `json:"stripe_price_id"`
|
||||
PlanVersion string `json:"plan_version"`
|
||||
ContainerID string `json:"container_id"`
|
||||
CurrentImageDigest string `json:"current_image_digest"`
|
||||
DesiredImageDigest string `json:"desired_image_digest"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
LastHealthCheck *time.Time `json:"last_health_check,omitempty"`
|
||||
HealthCheckOK bool `json:"health_check_ok"`
|
||||
}
|
||||
|
||||
// crockfordBase32 is the Crockford base32 alphabet (excludes I, L, O, U).
|
||||
const crockfordBase32 = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
||||
|
||||
// GenerateTenantID returns a tenant ID of the form "t-" followed by 10 random
|
||||
// Crockford base32 characters (50 bits of entropy).
|
||||
func GenerateTenantID() (string, error) {
|
||||
b := make([]byte, 10)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate tenant id: %w", err)
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.WriteString("t-")
|
||||
for _, v := range b {
|
||||
sb.WriteByte(crockfordBase32[int(v)%len(crockfordBase32)])
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
294
internal/cloudcp/registry/registry.go
Normal file
294
internal/cloudcp/registry/registry.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// TenantRegistry provides CRUD operations for tenant records backed by SQLite.
|
||||
type TenantRegistry struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewTenantRegistry opens (or creates) the tenant registry database in dir.
|
||||
func NewTenantRegistry(dir string) (*TenantRegistry, error) {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create registry dir: %w", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(dir, "tenants.db")
|
||||
dsn := dbPath + "?" + url.Values{
|
||||
"_pragma": []string{
|
||||
"busy_timeout(30000)",
|
||||
"journal_mode(WAL)",
|
||||
"synchronous(NORMAL)",
|
||||
},
|
||||
}.Encode()
|
||||
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open tenant registry db: %w", err)
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
r := &TenantRegistry{db: db}
|
||||
if err := r.initSchema(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *TenantRegistry) initSchema() error {
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS tenants (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT NOT NULL DEFAULT '',
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
state TEXT NOT NULL DEFAULT 'provisioning',
|
||||
stripe_customer_id TEXT NOT NULL DEFAULT '',
|
||||
stripe_subscription_id TEXT NOT NULL DEFAULT '',
|
||||
stripe_price_id TEXT NOT NULL DEFAULT '',
|
||||
plan_version TEXT NOT NULL DEFAULT '',
|
||||
container_id TEXT NOT NULL DEFAULT '',
|
||||
current_image_digest TEXT NOT NULL DEFAULT '',
|
||||
desired_image_digest TEXT NOT NULL DEFAULT '',
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
last_health_check INTEGER,
|
||||
health_check_ok INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_tenants_state ON tenants(state);
|
||||
CREATE INDEX IF NOT EXISTS idx_tenants_stripe_customer_id ON tenants(stripe_customer_id);
|
||||
`
|
||||
if _, err := r.db.Exec(schema); err != nil {
|
||||
return fmt.Errorf("init tenant registry schema: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping checks database connectivity (used for readiness probes).
|
||||
func (r *TenantRegistry) Ping() error {
|
||||
return r.db.Ping()
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection.
|
||||
func (r *TenantRegistry) Close() error {
|
||||
if r == nil || r.db == nil {
|
||||
return nil
|
||||
}
|
||||
return r.db.Close()
|
||||
}
|
||||
|
||||
// Create inserts a new tenant record.
|
||||
func (r *TenantRegistry) Create(t *Tenant) error {
|
||||
if t == nil {
|
||||
return fmt.Errorf("tenant is nil")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if t.CreatedAt.IsZero() {
|
||||
t.CreatedAt = now
|
||||
}
|
||||
t.UpdatedAt = now
|
||||
|
||||
_, err := r.db.Exec(`
|
||||
INSERT INTO tenants (
|
||||
id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
t.ID, t.Email, t.DisplayName, string(t.State),
|
||||
t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
|
||||
t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
|
||||
t.CreatedAt.Unix(), t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create tenant: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a tenant by ID.
|
||||
func (r *TenantRegistry) Get(id string) (*Tenant, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
FROM tenants WHERE id = ?`, id)
|
||||
return scanTenant(row)
|
||||
}
|
||||
|
||||
// GetByStripeCustomerID retrieves a tenant by Stripe customer ID.
|
||||
func (r *TenantRegistry) GetByStripeCustomerID(customerID string) (*Tenant, error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
FROM tenants WHERE stripe_customer_id = ?`, customerID)
|
||||
return scanTenant(row)
|
||||
}
|
||||
|
||||
// Update modifies an existing tenant record.
|
||||
func (r *TenantRegistry) Update(t *Tenant) error {
|
||||
if t == nil {
|
||||
return fmt.Errorf("tenant is nil")
|
||||
}
|
||||
t.UpdatedAt = time.Now().UTC()
|
||||
|
||||
res, err := r.db.Exec(`
|
||||
UPDATE tenants SET
|
||||
email = ?, display_name = ?, state = ?,
|
||||
stripe_customer_id = ?, stripe_subscription_id = ?, stripe_price_id = ?,
|
||||
plan_version = ?, container_id = ?, current_image_digest = ?, desired_image_digest = ?,
|
||||
updated_at = ?, last_health_check = ?, health_check_ok = ?
|
||||
WHERE id = ?`,
|
||||
t.Email, t.DisplayName, string(t.State),
|
||||
t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
|
||||
t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
|
||||
t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
|
||||
t.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update tenant: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("tenant %q not found", t.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns all tenants.
|
||||
func (r *TenantRegistry) List() ([]*Tenant, error) {
|
||||
rows, err := r.db.Query(`SELECT
|
||||
id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
FROM tenants ORDER BY created_at DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tenants: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanTenants(rows)
|
||||
}
|
||||
|
||||
// ListByState returns all tenants matching the given state.
|
||||
func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) {
|
||||
rows, err := r.db.Query(`SELECT
|
||||
id, email, display_name, state,
|
||||
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
plan_version, container_id, current_image_digest, desired_image_digest,
|
||||
created_at, updated_at, last_health_check, health_check_ok
|
||||
FROM tenants WHERE state = ? ORDER BY created_at DESC`, string(state))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tenants by state: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanTenants(rows)
|
||||
}
|
||||
|
||||
// CountByState returns a map of state -> count.
|
||||
func (r *TenantRegistry) CountByState() (map[TenantState]int, error) {
|
||||
rows, err := r.db.Query(`SELECT state, COUNT(*) FROM tenants GROUP BY state`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count tenants by state: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
counts := make(map[TenantState]int)
|
||||
for rows.Next() {
|
||||
var state string
|
||||
var count int
|
||||
if err := rows.Scan(&state, &count); err != nil {
|
||||
return nil, fmt.Errorf("scan count: %w", err)
|
||||
}
|
||||
counts[TenantState(state)] = count
|
||||
}
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
// HealthSummary returns the number of healthy and unhealthy active tenants.
|
||||
func (r *TenantRegistry) HealthSummary() (healthy, unhealthy int, err error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
COALESCE(SUM(CASE WHEN health_check_ok = 1 THEN 1 ELSE 0 END), 0),
|
||||
COALESCE(SUM(CASE WHEN health_check_ok = 0 THEN 1 ELSE 0 END), 0)
|
||||
FROM tenants WHERE state = ?`, string(TenantStateActive))
|
||||
if err := row.Scan(&healthy, &unhealthy); err != nil {
|
||||
return 0, 0, fmt.Errorf("health summary: %w", err)
|
||||
}
|
||||
return healthy, unhealthy, nil
|
||||
}
|
||||
|
||||
// scanner is an interface satisfied by both *sql.Row and *sql.Rows.
|
||||
type scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanTenant(s scanner) (*Tenant, error) {
|
||||
var t Tenant
|
||||
var state string
|
||||
var createdAt, updatedAt int64
|
||||
var lastHealthCheck sql.NullInt64
|
||||
var healthOK int
|
||||
|
||||
err := s.Scan(
|
||||
&t.ID, &t.Email, &t.DisplayName, &state,
|
||||
&t.StripeCustomerID, &t.StripeSubscriptionID, &t.StripePriceID,
|
||||
&t.PlanVersion, &t.ContainerID, &t.CurrentImageDigest, &t.DesiredImageDigest,
|
||||
&createdAt, &updatedAt, &lastHealthCheck, &healthOK,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("scan tenant: %w", err)
|
||||
}
|
||||
|
||||
t.State = TenantState(state)
|
||||
t.CreatedAt = time.Unix(createdAt, 0).UTC()
|
||||
t.UpdatedAt = time.Unix(updatedAt, 0).UTC()
|
||||
if lastHealthCheck.Valid {
|
||||
ts := time.Unix(lastHealthCheck.Int64, 0).UTC()
|
||||
t.LastHealthCheck = &ts
|
||||
}
|
||||
t.HealthCheckOK = healthOK != 0
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func scanTenants(rows *sql.Rows) ([]*Tenant, error) {
|
||||
var tenants []*Tenant
|
||||
for rows.Next() {
|
||||
t, err := scanTenant(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tenants = append(tenants, t)
|
||||
}
|
||||
return tenants, rows.Err()
|
||||
}
|
||||
|
||||
func nullableTimeUnix(t *time.Time) any {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
282
internal/cloudcp/registry/registry_test.go
Normal file
282
internal/cloudcp/registry/registry_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestRegistry(t *testing.T) *TenantRegistry {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
reg, err := NewTenantRegistry(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTenantRegistry: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = reg.Close() })
|
||||
return reg
|
||||
}
|
||||
|
||||
func TestGenerateTenantID(t *testing.T) {
|
||||
id, err := GenerateTenantID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateTenantID: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(id, "t-") {
|
||||
t.Errorf("expected prefix t-, got %q", id)
|
||||
}
|
||||
if len(id) != 12 { // "t-" + 10 chars
|
||||
t.Errorf("expected length 12, got %d (%q)", len(id), id)
|
||||
}
|
||||
|
||||
// Uniqueness
|
||||
seen := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
id, err := GenerateTenantID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if seen[id] {
|
||||
t.Fatalf("duplicate tenant ID: %s", id)
|
||||
}
|
||||
seen[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTenantID_CrockfordCharset(t *testing.T) {
|
||||
for i := 0; i < 50; i++ {
|
||||
id, err := GenerateTenantID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
suffix := id[2:] // strip "t-"
|
||||
for _, c := range suffix {
|
||||
if !strings.ContainsRune(crockfordBase32, c) {
|
||||
t.Errorf("character %q not in Crockford base32 alphabet (id=%s)", c, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCRUD(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
tenant := &Tenant{
|
||||
ID: "t-TEST00001",
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test Tenant",
|
||||
State: TenantStateProvisioning,
|
||||
StripeCustomerID: "cus_test123",
|
||||
PlanVersion: "stripe",
|
||||
}
|
||||
|
||||
// Create
|
||||
if err := reg.Create(tenant); err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
if tenant.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should be set")
|
||||
}
|
||||
|
||||
// Get
|
||||
got, err := reg.Get("t-TEST00001")
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("Get returned nil")
|
||||
}
|
||||
if got.Email != "test@example.com" {
|
||||
t.Errorf("Email = %q, want %q", got.Email, "test@example.com")
|
||||
}
|
||||
if got.State != TenantStateProvisioning {
|
||||
t.Errorf("State = %q, want %q", got.State, TenantStateProvisioning)
|
||||
}
|
||||
|
||||
// Get not found
|
||||
notFound, err := reg.Get("t-NONEXIST1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get not found: %v", err)
|
||||
}
|
||||
if notFound != nil {
|
||||
t.Error("expected nil for non-existent tenant")
|
||||
}
|
||||
|
||||
// GetByStripeCustomerID
|
||||
got2, err := reg.GetByStripeCustomerID("cus_test123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByStripeCustomerID: %v", err)
|
||||
}
|
||||
if got2 == nil || got2.ID != "t-TEST00001" {
|
||||
t.Error("GetByStripeCustomerID should find the tenant")
|
||||
}
|
||||
|
||||
// Update
|
||||
got.State = TenantStateActive
|
||||
got.ContainerID = "abc123"
|
||||
if err := reg.Update(got); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
got3, err := reg.Get("t-TEST00001")
|
||||
if err != nil {
|
||||
t.Fatalf("Get after update: %v", err)
|
||||
}
|
||||
if got3.State != TenantStateActive {
|
||||
t.Errorf("State after update = %q, want %q", got3.State, TenantStateActive)
|
||||
}
|
||||
if got3.ContainerID != "abc123" {
|
||||
t.Errorf("ContainerID = %q, want %q", got3.ContainerID, "abc123")
|
||||
}
|
||||
|
||||
// Update not found
|
||||
phantom := &Tenant{ID: "t-NONEXIST1"}
|
||||
if err := reg.Update(phantom); err == nil {
|
||||
t.Error("Update non-existent tenant should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
// Empty list
|
||||
tenants, err := reg.List()
|
||||
if err != nil {
|
||||
t.Fatalf("List: %v", err)
|
||||
}
|
||||
if len(tenants) != 0 {
|
||||
t.Errorf("expected 0 tenants, got %d", len(tenants))
|
||||
}
|
||||
|
||||
// Add two tenants
|
||||
for _, id := range []string{"t-LIST00001", "t-LIST00002"} {
|
||||
if err := reg.Create(&Tenant{
|
||||
ID: id,
|
||||
Email: id + "@example.com",
|
||||
State: TenantStateActive,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
tenants, err = reg.List()
|
||||
if err != nil {
|
||||
t.Fatalf("List: %v", err)
|
||||
}
|
||||
if len(tenants) != 2 {
|
||||
t.Errorf("expected 2 tenants, got %d", len(tenants))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListByState(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
if err := reg.Create(&Tenant{ID: "t-STATE0001", State: TenantStateActive}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(&Tenant{ID: "t-STATE0002", State: TenantStateSuspended}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(&Tenant{ID: "t-STATE0003", State: TenantStateActive}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
active, err := reg.ListByState(TenantStateActive)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(active) != 2 {
|
||||
t.Errorf("expected 2 active, got %d", len(active))
|
||||
}
|
||||
|
||||
suspended, err := reg.ListByState(TenantStateSuspended)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(suspended) != 1 {
|
||||
t.Errorf("expected 1 suspended, got %d", len(suspended))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountByState(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
if err := reg.Create(&Tenant{ID: "t-CNT000001", State: TenantStateActive}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(&Tenant{ID: "t-CNT000002", State: TenantStateActive}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(&Tenant{ID: "t-CNT000003", State: TenantStateCanceled}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
counts, err := reg.CountByState()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if counts[TenantStateActive] != 2 {
|
||||
t.Errorf("active count = %d, want 2", counts[TenantStateActive])
|
||||
}
|
||||
if counts[TenantStateCanceled] != 1 {
|
||||
t.Errorf("canceled count = %d, want 1", counts[TenantStateCanceled])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthSummary(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
if err := reg.Create(&Tenant{
|
||||
ID: "t-HLTH00001", State: TenantStateActive,
|
||||
HealthCheckOK: true, LastHealthCheck: &now,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := reg.Create(&Tenant{
|
||||
ID: "t-HLTH00002", State: TenantStateActive,
|
||||
HealthCheckOK: false, LastHealthCheck: &now,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Suspended tenant should not count
|
||||
if err := reg.Create(&Tenant{
|
||||
ID: "t-HLTH00003", State: TenantStateSuspended,
|
||||
HealthCheckOK: false,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
healthy, unhealthy, err := reg.HealthSummary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if healthy != 1 {
|
||||
t.Errorf("healthy = %d, want 1", healthy)
|
||||
}
|
||||
if unhealthy != 1 {
|
||||
t.Errorf("unhealthy = %d, want 1", unhealthy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
reg := newTestRegistry(t)
|
||||
if err := reg.Ping(); err != nil {
|
||||
t.Errorf("Ping: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTenantRegistry_InvalidDir(t *testing.T) {
|
||||
// Read-only path that doesn't exist
|
||||
_, err := NewTenantRegistry("/proc/nonexistent/path")
|
||||
if err == nil {
|
||||
// On macOS /proc doesn't exist, so MkdirAll will fail
|
||||
// On Linux with /proc it would also fail
|
||||
// But skip if somehow it works (unlikely)
|
||||
if _, statErr := os.Stat("/proc/nonexistent/path"); statErr != nil {
|
||||
t.Log("Skipping: path creation succeeded unexpectedly")
|
||||
}
|
||||
}
|
||||
}
|
||||
35
internal/cloudcp/routes.go
Normal file
35
internal/cloudcp/routes.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cloudcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/admin"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/docker"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
cpstripe "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/stripe"
|
||||
)
|
||||
|
||||
// Deps holds shared dependencies injected into HTTP handlers.
|
||||
type Deps struct {
|
||||
Config *CPConfig
|
||||
Registry *registry.TenantRegistry
|
||||
Docker *docker.Manager // nil if Docker is unavailable
|
||||
Version string
|
||||
}
|
||||
|
||||
// RegisterRoutes wires all HTTP handlers onto the given ServeMux.
|
||||
func RegisterRoutes(mux *http.ServeMux, deps *Deps) {
|
||||
// Health / readiness / status (unauthenticated)
|
||||
mux.HandleFunc("/healthz", admin.HandleHealthz)
|
||||
mux.HandleFunc("/readyz", admin.HandleReadyz(deps.Registry))
|
||||
mux.HandleFunc("/status", admin.HandleStatus(deps.Registry, deps.Version))
|
||||
|
||||
// Stripe webhook (signature-authenticated)
|
||||
provisioner := cpstripe.NewProvisioner(deps.Registry, deps.Config.TenantsDir())
|
||||
webhookHandler := cpstripe.NewWebhookHandler(deps.Config.StripeWebhookSecret, provisioner)
|
||||
mux.Handle("/api/stripe/webhook", webhookHandler)
|
||||
|
||||
// Admin API (key-authenticated)
|
||||
tenantsHandler := admin.HandleListTenants(deps.Registry)
|
||||
mux.Handle("/admin/tenants", admin.AdminKeyMiddleware(deps.Config.AdminKey, tenantsHandler))
|
||||
}
|
||||
146
internal/cloudcp/server.go
Normal file
146
internal/cloudcp/server.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package cloudcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
cpDocker "github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/docker"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/health"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Run starts the control plane HTTP server with graceful shutdown.
|
||||
func Run(ctx context.Context, version string) error {
|
||||
logging.Init(logging.Config{
|
||||
Format: "auto",
|
||||
Level: "info",
|
||||
Component: "control-plane",
|
||||
})
|
||||
|
||||
log.Info().Str("version", version).Msg("Starting Pulse Cloud Control Plane")
|
||||
|
||||
cfg, err := LoadConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
// Ensure data directories exist
|
||||
if err := os.MkdirAll(cfg.TenantsDir(), 0o755); err != nil {
|
||||
return fmt.Errorf("create tenants dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(cfg.ControlPlaneDir(), 0o755); err != nil {
|
||||
return fmt.Errorf("create control-plane dir: %w", err)
|
||||
}
|
||||
|
||||
// Open tenant registry
|
||||
reg, err := registry.NewTenantRegistry(cfg.ControlPlaneDir())
|
||||
if err != nil {
|
||||
return fmt.Errorf("open tenant registry: %w", err)
|
||||
}
|
||||
defer reg.Close()
|
||||
|
||||
// Initialize Docker manager (best-effort — control plane can run without Docker for dev/testing)
|
||||
var dockerMgr *cpDocker.Manager
|
||||
dockerMgr, err = cpDocker.NewManager(cpDocker.ManagerConfig{
|
||||
Image: cfg.PulseImage,
|
||||
Network: cfg.DockerNetwork,
|
||||
BaseDomain: baseDomainFromURL(cfg.BaseURL),
|
||||
MemoryLimit: cfg.TenantMemoryLimit,
|
||||
CPUShares: cfg.TenantCPUShares,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Docker unavailable — container management disabled")
|
||||
dockerMgr = nil
|
||||
} else {
|
||||
defer dockerMgr.Close()
|
||||
}
|
||||
|
||||
// Build HTTP routes
|
||||
mux := http.NewServeMux()
|
||||
deps := &Deps{
|
||||
Config: cfg,
|
||||
Registry: reg,
|
||||
Docker: dockerMgr,
|
||||
Version: version,
|
||||
}
|
||||
RegisterRoutes(mux, deps)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.BindAddress, cfg.Port)
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 15 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
// Create derived context for background goroutines
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Start health monitor if Docker is available
|
||||
if dockerMgr != nil {
|
||||
monitor := health.NewMonitor(reg, dockerMgr, health.MonitorConfig{
|
||||
Interval: 60 * time.Second,
|
||||
RestartOnFail: true,
|
||||
FailThreshold: 3,
|
||||
})
|
||||
go monitor.Run(ctx)
|
||||
}
|
||||
|
||||
// Start server in background
|
||||
go func() {
|
||||
log.Info().Str("addr", addr).Msg("Control plane listening")
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("Server failed")
|
||||
}
|
||||
}()
|
||||
|
||||
// Signal handling
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info().Msg("Context cancelled, shutting down...")
|
||||
case sig := <-sigChan:
|
||||
log.Info().Str("signal", sig.String()).Msg("Received signal, shutting down...")
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Error().Err(err).Msg("Server shutdown error")
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Info().Msg("Control plane stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// baseDomainFromURL extracts a base domain from a URL like "https://cloud.pulserelay.pro".
|
||||
func baseDomainFromURL(baseURL string) string {
|
||||
// Strip scheme
|
||||
domain := baseURL
|
||||
for _, prefix := range []string{"https://", "http://"} {
|
||||
if len(domain) > len(prefix) && domain[:len(prefix)] == prefix {
|
||||
domain = domain[len(prefix):]
|
||||
break
|
||||
}
|
||||
}
|
||||
// Strip port and path
|
||||
for i := 0; i < len(domain); i++ {
|
||||
if domain[i] == ':' || domain[i] == '/' {
|
||||
domain = domain[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
return domain
|
||||
}
|
||||
72
internal/cloudcp/stripe/helpers.go
Normal file
72
internal/cloudcp/stripe/helpers.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package stripe
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/license/entitlements"
|
||||
)
|
||||
|
||||
// MapSubscriptionStatus converts a Stripe subscription status string to the
|
||||
// internal SubscriptionState. Unknown statuses fail closed (expired).
|
||||
func MapSubscriptionStatus(status string) entitlements.SubscriptionState {
|
||||
switch strings.ToLower(strings.TrimSpace(status)) {
|
||||
case "active":
|
||||
return entitlements.SubStateActive
|
||||
case "trialing":
|
||||
return entitlements.SubStateTrial
|
||||
case "past_due", "unpaid":
|
||||
return entitlements.SubStateGrace
|
||||
case "canceled":
|
||||
return entitlements.SubStateCanceled
|
||||
case "paused":
|
||||
return entitlements.SubStateSuspended
|
||||
case "incomplete", "incomplete_expired":
|
||||
return entitlements.SubStateExpired
|
||||
default:
|
||||
return entitlements.SubStateExpired
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldGrantCapabilities returns true if the subscription state warrants
|
||||
// granting paid capabilities.
|
||||
func ShouldGrantCapabilities(state entitlements.SubscriptionState) bool {
|
||||
switch state {
|
||||
case entitlements.SubStateActive, entitlements.SubStateTrial, entitlements.SubStateGrace:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DerivePlanVersion extracts a plan version from event metadata, falling back
|
||||
// to a Stripe price ID prefix or a generic "stripe" string.
|
||||
func DerivePlanVersion(metadata map[string]string, priceID string) string {
|
||||
if metadata != nil {
|
||||
if v := strings.TrimSpace(metadata["plan_version"]); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(metadata["plan"]); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(priceID) != "" {
|
||||
return "stripe_price:" + strings.TrimSpace(priceID)
|
||||
}
|
||||
return "stripe"
|
||||
}
|
||||
|
||||
// IsSafeStripeID validates that a Stripe ID (cus_..., sub_...) is safe for
|
||||
// use as a lookup key. Keeps the check strict to avoid filesystem surprises.
|
||||
func IsSafeStripeID(id string) bool {
|
||||
if len(id) < 5 || len(id) > 128 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(id); i++ {
|
||||
c := id[i]
|
||||
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
107
internal/cloudcp/stripe/helpers_test.go
Normal file
107
internal/cloudcp/stripe/helpers_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package stripe
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/license/entitlements"
|
||||
)
|
||||
|
||||
func TestMapSubscriptionStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want entitlements.SubscriptionState
|
||||
}{
|
||||
{"active", entitlements.SubStateActive},
|
||||
{"Active", entitlements.SubStateActive},
|
||||
{"trialing", entitlements.SubStateTrial},
|
||||
{"past_due", entitlements.SubStateGrace},
|
||||
{"unpaid", entitlements.SubStateGrace},
|
||||
{"canceled", entitlements.SubStateCanceled},
|
||||
{"paused", entitlements.SubStateSuspended},
|
||||
{"incomplete", entitlements.SubStateExpired},
|
||||
{"incomplete_expired", entitlements.SubStateExpired},
|
||||
{"unknown_status", entitlements.SubStateExpired},
|
||||
{"", entitlements.SubStateExpired},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := MapSubscriptionStatus(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("MapSubscriptionStatus(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldGrantCapabilities(t *testing.T) {
|
||||
tests := []struct {
|
||||
state entitlements.SubscriptionState
|
||||
want bool
|
||||
}{
|
||||
{entitlements.SubStateActive, true},
|
||||
{entitlements.SubStateTrial, true},
|
||||
{entitlements.SubStateGrace, true},
|
||||
{entitlements.SubStateCanceled, false},
|
||||
{entitlements.SubStateSuspended, false},
|
||||
{entitlements.SubStateExpired, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.state), func(t *testing.T) {
|
||||
got := ShouldGrantCapabilities(tt.state)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldGrantCapabilities(%q) = %v, want %v", tt.state, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDerivePlanVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]string
|
||||
priceID string
|
||||
want string
|
||||
}{
|
||||
{"plan_version in metadata", map[string]string{"plan_version": "v2"}, "", "v2"},
|
||||
{"plan in metadata", map[string]string{"plan": "pro"}, "", "pro"},
|
||||
{"plan_version takes priority", map[string]string{"plan_version": "v3", "plan": "pro"}, "", "v3"},
|
||||
{"price ID fallback", nil, "price_123", "stripe_price:price_123"},
|
||||
{"generic fallback", nil, "", "stripe"},
|
||||
{"nil metadata with price", nil, "price_abc", "stripe_price:price_abc"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DerivePlanVersion(tt.metadata, tt.priceID)
|
||||
if got != tt.want {
|
||||
t.Errorf("DerivePlanVersion = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeStripeID(t *testing.T) {
|
||||
tests := []struct {
|
||||
id string
|
||||
want bool
|
||||
}{
|
||||
{"cus_test123", true},
|
||||
{"sub_abc-def", true},
|
||||
{"evt_12345678901234567890", true},
|
||||
{"", false},
|
||||
{"ab", false},
|
||||
{"cus_../etc/passwd", false},
|
||||
{"cus test", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.id, func(t *testing.T) {
|
||||
got := IsSafeStripeID(tt.id)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsSafeStripeID(%q) = %v, want %v", tt.id, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
215
internal/cloudcp/stripe/provisioner.go
Normal file
215
internal/cloudcp/stripe/provisioner.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package stripe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/cloudcp/registry"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/license"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/license/entitlements"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Provisioner orchestrates tenant creation, billing state updates, and (later)
|
||||
// container lifecycle in response to Stripe events.
|
||||
type Provisioner struct {
|
||||
registry *registry.TenantRegistry
|
||||
tenantsDir string
|
||||
}
|
||||
|
||||
// NewProvisioner creates a Provisioner.
|
||||
func NewProvisioner(reg *registry.TenantRegistry, tenantsDir string) *Provisioner {
|
||||
return &Provisioner{
|
||||
registry: reg,
|
||||
tenantsDir: tenantsDir,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleCheckout provisions a new tenant from a checkout.session.completed event.
|
||||
func (p *Provisioner) HandleCheckout(ctx context.Context, session CheckoutSession) error {
|
||||
customerID := strings.TrimSpace(session.Customer)
|
||||
if customerID == "" {
|
||||
return fmt.Errorf("checkout session missing customer")
|
||||
}
|
||||
if !IsSafeStripeID(customerID) {
|
||||
return fmt.Errorf("invalid stripe customer id: %s", customerID)
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(session.CustomerEmail))
|
||||
if email == "" {
|
||||
email = strings.ToLower(strings.TrimSpace(session.CustomerDetails.Email))
|
||||
}
|
||||
|
||||
// Check if a tenant already exists for this Stripe customer
|
||||
existing, err := p.registry.GetByStripeCustomerID(customerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup existing tenant: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
log.Info().
|
||||
Str("tenant_id", existing.ID).
|
||||
Str("customer_id", customerID).
|
||||
Msg("Tenant already exists for Stripe customer, skipping provisioning")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate tenant ID
|
||||
tenantID, err := registry.GenerateTenantID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate tenant id: %w", err)
|
||||
}
|
||||
|
||||
planVersion := DerivePlanVersion(session.Metadata, "")
|
||||
|
||||
// Write billing.json to tenant data dir.
|
||||
// FileBillingStore with baseDataDir=<tenantDir> + SaveBillingState("default", state)
|
||||
// writes billing.json at the root of the tenant data dir.
|
||||
tenantDataDir := p.tenantsDir + "/" + tenantID
|
||||
billingStore := config.NewFileBillingStore(tenantDataDir)
|
||||
state := &entitlements.BillingState{
|
||||
Capabilities: license.DeriveCapabilitiesFromTier(license.TierCloud, nil),
|
||||
Limits: map[string]int64{},
|
||||
MetersEnabled: []string{},
|
||||
PlanVersion: planVersion,
|
||||
SubscriptionState: entitlements.SubStateActive,
|
||||
StripeCustomerID: customerID,
|
||||
StripeSubscriptionID: strings.TrimSpace(session.Subscription),
|
||||
}
|
||||
if err := billingStore.SaveBillingState("default", state); err != nil {
|
||||
return fmt.Errorf("write billing state: %w", err)
|
||||
}
|
||||
|
||||
// Insert registry record
|
||||
tenant := ®istry.Tenant{
|
||||
ID: tenantID,
|
||||
Email: email,
|
||||
State: registry.TenantStateProvisioning,
|
||||
StripeCustomerID: customerID,
|
||||
StripeSubscriptionID: strings.TrimSpace(session.Subscription),
|
||||
PlanVersion: planVersion,
|
||||
}
|
||||
if err := p.registry.Create(tenant); err != nil {
|
||||
return fmt.Errorf("create tenant record: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("tenant_id", tenantID).
|
||||
Str("customer_id", customerID).
|
||||
Str("email", email).
|
||||
Str("plan_version", planVersion).
|
||||
Msg("Tenant provisioned from checkout")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleSubscriptionUpdated syncs billing state when a subscription changes.
|
||||
func (p *Provisioner) HandleSubscriptionUpdated(ctx context.Context, sub Subscription) error {
|
||||
customerID := strings.TrimSpace(sub.Customer)
|
||||
if customerID == "" {
|
||||
return fmt.Errorf("subscription missing customer")
|
||||
}
|
||||
|
||||
tenant, err := p.registry.GetByStripeCustomerID(customerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup tenant by customer: %w", err)
|
||||
}
|
||||
if tenant == nil {
|
||||
log.Warn().Str("customer_id", customerID).Msg("subscription.updated: tenant not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
subState := MapSubscriptionStatus(sub.Status)
|
||||
priceID := sub.FirstPriceID()
|
||||
planVersion := DerivePlanVersion(sub.Metadata, priceID)
|
||||
|
||||
// Update billing.json
|
||||
var caps []string
|
||||
if ShouldGrantCapabilities(subState) {
|
||||
caps = license.DeriveCapabilitiesFromTier(license.TierCloud, nil)
|
||||
}
|
||||
|
||||
tenantDataDir := p.tenantsDir + "/" + tenant.ID
|
||||
billingStore := config.NewFileBillingStore(tenantDataDir)
|
||||
state := &entitlements.BillingState{
|
||||
Capabilities: caps,
|
||||
Limits: map[string]int64{},
|
||||
MetersEnabled: []string{},
|
||||
PlanVersion: planVersion,
|
||||
SubscriptionState: subState,
|
||||
StripeCustomerID: customerID,
|
||||
StripeSubscriptionID: strings.TrimSpace(sub.ID),
|
||||
StripePriceID: priceID,
|
||||
}
|
||||
if err := billingStore.SaveBillingState("default", state); err != nil {
|
||||
return fmt.Errorf("save billing state: %w", err)
|
||||
}
|
||||
|
||||
// Update registry
|
||||
tenant.StripeSubscriptionID = strings.TrimSpace(sub.ID)
|
||||
tenant.StripePriceID = priceID
|
||||
tenant.PlanVersion = planVersion
|
||||
if subState == entitlements.SubStateSuspended {
|
||||
tenant.State = registry.TenantStateSuspended
|
||||
} else if subState == entitlements.SubStateActive || subState == entitlements.SubStateTrial || subState == entitlements.SubStateGrace {
|
||||
tenant.State = registry.TenantStateActive
|
||||
}
|
||||
if err := p.registry.Update(tenant); err != nil {
|
||||
return fmt.Errorf("update tenant record: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("tenant_id", tenant.ID).
|
||||
Str("customer_id", customerID).
|
||||
Str("subscription_state", string(subState)).
|
||||
Msg("Subscription updated")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleSubscriptionDeleted revokes capabilities on cancellation.
|
||||
func (p *Provisioner) HandleSubscriptionDeleted(ctx context.Context, sub Subscription) error {
|
||||
customerID := strings.TrimSpace(sub.Customer)
|
||||
if customerID == "" {
|
||||
return fmt.Errorf("subscription missing customer")
|
||||
}
|
||||
|
||||
tenant, err := p.registry.GetByStripeCustomerID(customerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup tenant by customer: %w", err)
|
||||
}
|
||||
if tenant == nil {
|
||||
log.Warn().Str("customer_id", customerID).Msg("subscription.deleted: tenant not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke capabilities immediately
|
||||
tenantDataDir := p.tenantsDir + "/" + tenant.ID
|
||||
billingStore := config.NewFileBillingStore(tenantDataDir)
|
||||
state := &entitlements.BillingState{
|
||||
Capabilities: []string{},
|
||||
Limits: map[string]int64{},
|
||||
MetersEnabled: []string{},
|
||||
PlanVersion: tenant.PlanVersion,
|
||||
SubscriptionState: entitlements.SubStateCanceled,
|
||||
StripeCustomerID: customerID,
|
||||
StripeSubscriptionID: strings.TrimSpace(sub.ID),
|
||||
}
|
||||
if err := billingStore.SaveBillingState("default", state); err != nil {
|
||||
return fmt.Errorf("save billing state: %w", err)
|
||||
}
|
||||
|
||||
// Update registry
|
||||
tenant.State = registry.TenantStateCanceled
|
||||
if err := p.registry.Update(tenant); err != nil {
|
||||
return fmt.Errorf("update tenant record: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("tenant_id", tenant.ID).
|
||||
Str("customer_id", customerID).
|
||||
Msg("Subscription deleted, capabilities revoked")
|
||||
|
||||
return nil
|
||||
}
|
||||
163
internal/cloudcp/stripe/webhook.go
Normal file
163
internal/cloudcp/stripe/webhook.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package stripe
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
stripelib "github.com/stripe/stripe-go/v82"
|
||||
"github.com/stripe/stripe-go/v82/webhook"
|
||||
)
|
||||
|
||||
const webhookBodyLimit = 1024 * 1024 // 1 MiB
|
||||
|
||||
// WebhookHandler handles incoming Stripe webhook events.
|
||||
type WebhookHandler struct {
|
||||
secret string
|
||||
provisioner *Provisioner
|
||||
}
|
||||
|
||||
// NewWebhookHandler creates a Stripe webhook HTTP handler.
|
||||
func NewWebhookHandler(secret string, provisioner *Provisioner) *WebhookHandler {
|
||||
return &WebhookHandler{
|
||||
secret: secret,
|
||||
provisioner: provisioner,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP verifies the Stripe signature and dispatches the event.
|
||||
func (h *WebhookHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(h.secret) == "" {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]string{
|
||||
"error": "webhook secret not configured",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, webhookBodyLimit)
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{
|
||||
"error": "failed to read request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sigHeader := r.Header.Get("Stripe-Signature")
|
||||
if strings.TrimSpace(sigHeader) == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{
|
||||
"error": "missing Stripe signature",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
event, err := webhook.ConstructEventWithOptions(payload, sigHeader, h.secret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{
|
||||
"error": "invalid Stripe signature",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.handleEvent(r, &event); err != nil {
|
||||
log.Error().Err(err).
|
||||
Str("event_id", event.ID).
|
||||
Str("type", string(event.Type)).
|
||||
Msg("Stripe webhook processing failed")
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]string{
|
||||
"error": "processing failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"received": true,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *WebhookHandler) handleEvent(r *http.Request, event *stripelib.Event) error {
|
||||
switch event.Type {
|
||||
case "checkout.session.completed":
|
||||
var session CheckoutSession
|
||||
if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
|
||||
return fmt.Errorf("decode checkout.session: %w", err)
|
||||
}
|
||||
return h.provisioner.HandleCheckout(r.Context(), session)
|
||||
|
||||
case "customer.subscription.updated":
|
||||
var sub Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
return fmt.Errorf("decode subscription: %w", err)
|
||||
}
|
||||
return h.provisioner.HandleSubscriptionUpdated(r.Context(), sub)
|
||||
|
||||
case "customer.subscription.deleted":
|
||||
var sub Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
return fmt.Errorf("decode subscription: %w", err)
|
||||
}
|
||||
return h.provisioner.HandleSubscriptionDeleted(r.Context(), sub)
|
||||
|
||||
default:
|
||||
log.Info().
|
||||
Str("type", string(event.Type)).
|
||||
Str("event_id", event.ID).
|
||||
Msg("Stripe webhook ignored (unhandled type)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CheckoutSession is a minimal representation of a Stripe checkout.session event.
|
||||
type CheckoutSession struct {
|
||||
ID string `json:"id"`
|
||||
Mode string `json:"mode"`
|
||||
Customer string `json:"customer"`
|
||||
Subscription string `json:"subscription"`
|
||||
CustomerEmail string `json:"customer_email"`
|
||||
CustomerDetails struct {
|
||||
Email string `json:"email"`
|
||||
} `json:"customer_details"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
// Subscription is a minimal representation of a Stripe subscription event.
|
||||
type Subscription struct {
|
||||
ID string `json:"id"`
|
||||
Customer string `json:"customer"`
|
||||
Status string `json:"status"`
|
||||
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
||||
Items struct {
|
||||
Data []struct {
|
||||
Price struct {
|
||||
ID string `json:"id"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
} `json:"price"`
|
||||
} `json:"data"`
|
||||
} `json:"items"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
// FirstPriceID returns the price ID from the first subscription item.
|
||||
func (s *Subscription) FirstPriceID() string {
|
||||
for _, item := range s.Items.Data {
|
||||
if id := strings.TrimSpace(item.Price.ID); id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
Reference in New Issue
Block a user