mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-18 00:17:39 +01:00
feat: add multi-tenant isolation foundation (disabled by default)
Implements multi-tenant infrastructure for organization-based data isolation. Feature is gated behind PULSE_MULTI_TENANT_ENABLED env var and requires Enterprise license - no impact on existing users. Core components: - TenantMiddleware: extracts org ID, validates access, 501/402 responses - AuthorizationChecker: token/user access validation for organizations - MultiTenantChecker: WebSocket upgrade gating with license check - Per-tenant audit logging via LogAuditEventForTenant - Organization model with membership support Gating behavior: - Feature flag disabled: 501 Not Implemented for non-default orgs - Flag enabled, no license: 402 Payment Required - Default org always works regardless of flag/license Documentation added: docs/MULTI_TENANT.md
This commit is contained in:
251
docs/MULTI_TENANT.md
Normal file
251
docs/MULTI_TENANT.md
Normal file
@@ -0,0 +1,251 @@
|
||||
# Multi-Tenant Feature Documentation
|
||||
|
||||
## Status: Disabled by Default
|
||||
|
||||
This feature is gated behind a feature flag and license check. It will not affect existing users unless explicitly enabled.
|
||||
|
||||
---
|
||||
|
||||
## How to Enable
|
||||
|
||||
### Requirements
|
||||
|
||||
1. **Feature flag**: Set environment variable
|
||||
```bash
|
||||
PULSE_MULTI_TENANT_ENABLED=true
|
||||
```
|
||||
|
||||
2. **License**: Enterprise license with `multi_tenant` feature enabled
|
||||
|
||||
### Behavior Without Enablement
|
||||
|
||||
| Condition | HTTP Response | WebSocket Response |
|
||||
|-----------|---------------|-------------------|
|
||||
| Feature flag disabled | 501 Not Implemented | 501 Not Implemented |
|
||||
| Flag enabled, no license | 402 Payment Required | 402 Payment Required |
|
||||
| Flag enabled + licensed | Normal operation | Normal operation |
|
||||
|
||||
The "default" organization always works regardless of feature flag or license status.
|
||||
|
||||
---
|
||||
|
||||
## What's Implemented
|
||||
|
||||
### Tenant Isolation
|
||||
|
||||
| Component | Status | Details |
|
||||
|-----------|--------|---------|
|
||||
| State/Monitor | ✅ | Each org gets its own `Monitor` instance via `MultiTenantMonitor` |
|
||||
| WebSocket | ✅ | Clients bound to tenant, broadcasts filtered by org |
|
||||
| Audit Logs | ✅ | `LogAuditEventForTenant()` writes to per-org audit DB |
|
||||
| Resources | ✅ | Per-tenant resource stores with `PopulateFromSnapshotForTenant()` |
|
||||
| Persistence | ✅ | `MultiTenantPersistence` provides per-org config directories |
|
||||
|
||||
### Gating & Authorization
|
||||
|
||||
| Component | Status | Details |
|
||||
|-----------|--------|---------|
|
||||
| Feature flag | ✅ | `PULSE_MULTI_TENANT_ENABLED` env var (default: false) |
|
||||
| License check | ✅ | Requires `multi_tenant` feature in Enterprise license |
|
||||
| HTTP middleware | ✅ | `TenantMiddleware` extracts org ID, validates access |
|
||||
| WebSocket gating | ✅ | `MultiTenantChecker` validates before upgrade |
|
||||
| Token authorization | ✅ | `AuthorizationChecker.TokenCanAccessOrg()` |
|
||||
| User authorization | ✅ | `AuthorizationChecker.UserCanAccessOrg()` via org membership |
|
||||
|
||||
### Tenant-Aware Endpoints
|
||||
|
||||
All user-facing data endpoints use `getTenantMonitor(ctx)`:
|
||||
|
||||
- `/api/state`
|
||||
- `/api/charts`
|
||||
- `/api/storage/{id}`
|
||||
- `/api/backups`, `/api/backups/pve`, `/api/backups/pbs`
|
||||
- `/api/snapshots`
|
||||
- `/api/resources/*`
|
||||
- `/api/metrics/*`
|
||||
|
||||
---
|
||||
|
||||
## Intentionally Global (Admin-Level)
|
||||
|
||||
These endpoints show system-wide data regardless of tenant context:
|
||||
|
||||
| Endpoint | Rationale |
|
||||
|----------|-----------|
|
||||
| `/api/health` | System uptime, not tenant-specific |
|
||||
| `/api/scheduler/health` | Process-level scheduler status |
|
||||
| `/api/diagnostics/*` | Admin diagnostics for full system |
|
||||
|
||||
Also global:
|
||||
- `security_setup_fix.go` - Clears unauthenticated agents on default monitor
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
### Key Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `internal/api/middleware_tenant.go` | Extracts org ID, validates access, injects context |
|
||||
| `internal/api/middleware_license.go` | Feature flag, license check, 501/402 responses |
|
||||
| `internal/api/authorization.go` | `AuthorizationChecker` interface, token/user access checks |
|
||||
| `internal/monitoring/multi_tenant_monitor.go` | Per-org monitor instances |
|
||||
| `internal/config/multi_tenant.go` | Per-org persistence (config directories) |
|
||||
| `internal/websocket/hub.go` | Tenant-aware client tracking, `MultiTenantChecker` |
|
||||
| `pkg/server/server.go` | Wires up org loader, multi-tenant checker |
|
||||
|
||||
### Request Flow
|
||||
|
||||
```
|
||||
Request
|
||||
│
|
||||
├─► TenantMiddleware
|
||||
│ ├─► Extract org ID (header/cookie/default)
|
||||
│ ├─► Feature flag check (501 if disabled)
|
||||
│ ├─► License check (402 if unlicensed)
|
||||
│ ├─► Authorization check (403 if denied)
|
||||
│ └─► Inject org ID into context
|
||||
│
|
||||
├─► Handler
|
||||
│ └─► getTenantMonitor(ctx) → org-specific Monitor
|
||||
│
|
||||
└─► Response (org-scoped data)
|
||||
```
|
||||
|
||||
### Org ID Sources (Priority Order)
|
||||
|
||||
1. `X-Pulse-Org-ID` header (API clients/agents)
|
||||
2. `pulse_org_id` cookie (browser sessions)
|
||||
3. Fallback: `"default"`
|
||||
|
||||
---
|
||||
|
||||
## Data Model
|
||||
|
||||
### Organization
|
||||
|
||||
```go
|
||||
type Organization struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
OwnerUserID string // Creator/owner
|
||||
Members []OrganizationMember // User membership
|
||||
}
|
||||
|
||||
type OrganizationMember struct {
|
||||
UserID string
|
||||
Role string // "owner", "admin", "member"
|
||||
AddedAt time.Time
|
||||
AddedBy string
|
||||
}
|
||||
```
|
||||
|
||||
### API Token Binding
|
||||
|
||||
```go
|
||||
type APITokenRecord struct {
|
||||
// ... existing fields ...
|
||||
OrgID string // Single org binding
|
||||
OrgIDs []string // Multi-org access (MSP tokens)
|
||||
}
|
||||
```
|
||||
|
||||
Legacy tokens (empty `OrgID`) have wildcard access during migration period.
|
||||
|
||||
---
|
||||
|
||||
## TODO / Deferred Items
|
||||
|
||||
### High Priority (Before GA)
|
||||
|
||||
- [ ] **Config deep copy**: `multi_tenant_monitor.go:59` does shallow copy; credential slices may be shared
|
||||
- [ ] **Migration script**: Move existing data to `/orgs/default/` with symlinks for backward compatibility
|
||||
- [ ] **UI integration**: Org switcher, org management screens
|
||||
|
||||
### Medium Priority
|
||||
|
||||
- [ ] **Per-tenant node credentials**: Load tenant-specific `nodes.enc` instead of inheriting base config
|
||||
- [ ] **Org CRUD endpoints**: Create/update/delete organizations via API
|
||||
- [ ] **Member management**: Add/remove users from organizations
|
||||
|
||||
### Low Priority / Policy Decisions
|
||||
|
||||
- [ ] Decide if diagnostics should be org-scoped or super-admin only
|
||||
- [ ] Decide if `security_setup_fix.go` agent cleanup should be org-scoped
|
||||
|
||||
---
|
||||
|
||||
## Testing Checklist
|
||||
|
||||
### Unit Tests
|
||||
|
||||
```bash
|
||||
# Tenant middleware tests
|
||||
go test ./internal/api -run TestTenantMiddleware
|
||||
|
||||
# WebSocket multi-tenant tests
|
||||
go test ./internal/websocket -run TestHandleWebSocket_MultiTenant
|
||||
```
|
||||
|
||||
### Manual Testing
|
||||
|
||||
1. **Default behavior (flag disabled)**
|
||||
- Start Pulse without `PULSE_MULTI_TENANT_ENABLED`
|
||||
- Verify normal operation
|
||||
- Attempt `X-Pulse-Org-ID: test-org` header → expect 501
|
||||
|
||||
2. **Flag enabled, no license**
|
||||
- Set `PULSE_MULTI_TENANT_ENABLED=true`
|
||||
- No Enterprise license
|
||||
- Attempt non-default org → expect 402
|
||||
|
||||
3. **Full multi-tenant**
|
||||
- Enable flag + Enterprise license
|
||||
- Create org "test-a" with PVE node A
|
||||
- Create org "test-b" with PVE node B
|
||||
- Open browser tabs for each org
|
||||
- Verify each sees only their nodes
|
||||
- Verify WebSocket updates are isolated
|
||||
- Attempt header spoofing with wrong token → expect 403
|
||||
|
||||
### Integration Test Script
|
||||
|
||||
```bash
|
||||
# 1. Verify default org works without flag
|
||||
curl -u admin:admin http://localhost:7655/api/state
|
||||
# → 200 OK
|
||||
|
||||
# 2. Verify non-default org blocked without flag
|
||||
curl -u admin:admin -H "X-Pulse-Org-ID: test-org" http://localhost:7655/api/state
|
||||
# → 501 Not Implemented
|
||||
|
||||
# 3. With flag enabled but no license
|
||||
export PULSE_MULTI_TENANT_ENABLED=true
|
||||
curl -u admin:admin -H "X-Pulse-Org-ID: test-org" http://localhost:7655/api/state
|
||||
# → 402 Payment Required
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Response Codes Reference
|
||||
|
||||
| Code | Meaning | When |
|
||||
|------|---------|------|
|
||||
| 200 | Success | Valid org access |
|
||||
| 400 | Bad Request | Invalid org ID format |
|
||||
| 402 | Payment Required | Feature enabled but not licensed |
|
||||
| 403 | Forbidden | Token/user not authorized for org |
|
||||
| 501 | Not Implemented | Feature flag disabled |
|
||||
|
||||
---
|
||||
|
||||
## Changelog
|
||||
|
||||
- **2024-01**: Initial implementation
|
||||
- Feature flag gating
|
||||
- License enforcement
|
||||
- Per-tenant state isolation
|
||||
- WebSocket tenant binding
|
||||
- Audit log isolation
|
||||
- Authorization framework
|
||||
206
internal/api/authorization.go
Normal file
206
internal/api/authorization.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// AuthorizationChecker provides methods to check if a user or token can access an organization.
|
||||
type AuthorizationChecker interface {
|
||||
// TokenCanAccessOrg checks if an API token is authorized to access the specified organization.
|
||||
TokenCanAccessOrg(token *config.APITokenRecord, orgID string) bool
|
||||
|
||||
// UserCanAccessOrg checks if a user is a member of the specified organization.
|
||||
UserCanAccessOrg(userID, orgID string) bool
|
||||
|
||||
// CheckAccess performs a comprehensive authorization check for a request.
|
||||
CheckAccess(token *config.APITokenRecord, userID, orgID string) AuthorizationResult
|
||||
}
|
||||
|
||||
// DefaultAuthorizationChecker implements AuthorizationChecker with the default logic.
|
||||
type DefaultAuthorizationChecker struct {
|
||||
// orgLoader is used to load organization data for membership checks.
|
||||
orgLoader OrganizationLoader
|
||||
}
|
||||
|
||||
// OrganizationLoader provides methods to load organization data.
|
||||
type OrganizationLoader interface {
|
||||
// GetOrganization returns the organization with the specified ID.
|
||||
GetOrganization(orgID string) (*models.Organization, error)
|
||||
}
|
||||
|
||||
// NewAuthorizationChecker creates a new DefaultAuthorizationChecker.
|
||||
func NewAuthorizationChecker(loader OrganizationLoader) *DefaultAuthorizationChecker {
|
||||
return &DefaultAuthorizationChecker{
|
||||
orgLoader: loader,
|
||||
}
|
||||
}
|
||||
|
||||
// MultiTenantOrganizationLoader implements OrganizationLoader using MultiTenantPersistence.
|
||||
type MultiTenantOrganizationLoader struct {
|
||||
persistence *config.MultiTenantPersistence
|
||||
}
|
||||
|
||||
// NewMultiTenantOrganizationLoader creates a new organization loader.
|
||||
func NewMultiTenantOrganizationLoader(persistence *config.MultiTenantPersistence) *MultiTenantOrganizationLoader {
|
||||
return &MultiTenantOrganizationLoader{
|
||||
persistence: persistence,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrganization loads the organization with the specified ID.
|
||||
func (l *MultiTenantOrganizationLoader) GetOrganization(orgID string) (*models.Organization, error) {
|
||||
if l.persistence == nil {
|
||||
return nil, fmt.Errorf("no persistence configured")
|
||||
}
|
||||
return l.persistence.LoadOrganization(orgID)
|
||||
}
|
||||
|
||||
// TokenCanAccessOrg checks if an API token is authorized to access the specified organization.
|
||||
// It uses the token's CanAccessOrg method and logs warnings for legacy tokens.
|
||||
func (c *DefaultAuthorizationChecker) TokenCanAccessOrg(token *config.APITokenRecord, orgID string) bool {
|
||||
if token == nil {
|
||||
// No token means session-based auth - defer to user membership check
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if token can access the org
|
||||
canAccess := token.CanAccessOrg(orgID)
|
||||
|
||||
// Log warning for legacy tokens with wildcard access
|
||||
if token.IsLegacyToken() && orgID != "default" {
|
||||
log.Warn().
|
||||
Str("token_id", token.ID).
|
||||
Str("token_name", token.Name).
|
||||
Str("org_id", orgID).
|
||||
Msg("Legacy token with wildcard access used for non-default org - consider binding to specific org")
|
||||
}
|
||||
|
||||
if !canAccess {
|
||||
log.Debug().
|
||||
Str("token_id", token.ID).
|
||||
Str("token_name", token.Name).
|
||||
Str("org_id", orgID).
|
||||
Strs("bound_orgs", token.GetBoundOrgs()).
|
||||
Msg("Token denied access to organization")
|
||||
}
|
||||
|
||||
return canAccess
|
||||
}
|
||||
|
||||
// UserCanAccessOrg checks if a user is a member of the specified organization.
|
||||
func (c *DefaultAuthorizationChecker) UserCanAccessOrg(userID, orgID string) bool {
|
||||
// Default org is always accessible
|
||||
if orgID == "default" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If no org loader is configured, deny access to non-default orgs
|
||||
if c.orgLoader == nil {
|
||||
log.Warn().
|
||||
Str("user_id", userID).
|
||||
Str("org_id", orgID).
|
||||
Msg("No organization loader configured, denying access to non-default org")
|
||||
return false
|
||||
}
|
||||
|
||||
org, err := c.orgLoader.GetOrganization(orgID)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("user_id", userID).
|
||||
Str("org_id", orgID).
|
||||
Msg("Failed to load organization for access check")
|
||||
return false
|
||||
}
|
||||
|
||||
if org == nil {
|
||||
log.Debug().
|
||||
Str("user_id", userID).
|
||||
Str("org_id", orgID).
|
||||
Msg("Organization not found for access check")
|
||||
return false
|
||||
}
|
||||
|
||||
canAccess := org.CanUserAccess(userID)
|
||||
if !canAccess {
|
||||
log.Debug().
|
||||
Str("user_id", userID).
|
||||
Str("org_id", orgID).
|
||||
Msg("User is not a member of the organization")
|
||||
}
|
||||
|
||||
return canAccess
|
||||
}
|
||||
|
||||
// AuthorizationResult contains the result of an authorization check.
|
||||
type AuthorizationResult struct {
|
||||
// Allowed indicates if access is allowed.
|
||||
Allowed bool
|
||||
|
||||
// Reason provides a human-readable reason for the decision.
|
||||
Reason string
|
||||
|
||||
// IsLegacyToken indicates if the access was granted via a legacy wildcard token.
|
||||
IsLegacyToken bool
|
||||
}
|
||||
|
||||
// CheckAccess performs a comprehensive authorization check for a request.
|
||||
func (c *DefaultAuthorizationChecker) CheckAccess(token *config.APITokenRecord, userID, orgID string) AuthorizationResult {
|
||||
// Check token-based access first
|
||||
if token != nil {
|
||||
if !token.CanAccessOrg(orgID) {
|
||||
return AuthorizationResult{
|
||||
Allowed: false,
|
||||
Reason: "Token is not authorized for this organization",
|
||||
}
|
||||
}
|
||||
return AuthorizationResult{
|
||||
Allowed: true,
|
||||
Reason: "Token authorized for organization",
|
||||
IsLegacyToken: token.IsLegacyToken(),
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to user-based access
|
||||
if userID != "" {
|
||||
if c.UserCanAccessOrg(userID, orgID) {
|
||||
return AuthorizationResult{
|
||||
Allowed: true,
|
||||
Reason: "User is a member of the organization",
|
||||
}
|
||||
}
|
||||
return AuthorizationResult{
|
||||
Allowed: false,
|
||||
Reason: "User is not a member of the organization",
|
||||
}
|
||||
}
|
||||
|
||||
// No token and no user - deny access
|
||||
return AuthorizationResult{
|
||||
Allowed: false,
|
||||
Reason: "No authentication context provided",
|
||||
}
|
||||
}
|
||||
|
||||
// CanAccessOrg implements websocket.OrgAuthChecker for use with the WebSocket hub.
|
||||
func (c *DefaultAuthorizationChecker) CanAccessOrg(userID string, tokenInterface interface{}, orgID string) bool {
|
||||
// Default org is always accessible
|
||||
if orgID == "default" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Convert token interface to APITokenRecord
|
||||
var token *config.APITokenRecord
|
||||
if tokenInterface != nil {
|
||||
if t, ok := tokenInterface.(*config.APITokenRecord); ok {
|
||||
token = t
|
||||
}
|
||||
}
|
||||
|
||||
result := c.CheckAccess(token, userID, orgID)
|
||||
return result.Allowed
|
||||
}
|
||||
230
internal/api/middleware_license.go
Normal file
230
internal/api/middleware_license.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/license"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/websocket"
|
||||
)
|
||||
|
||||
// Multi-tenant feature flag (default: disabled)
|
||||
// Set PULSE_MULTI_TENANT_ENABLED=true to enable multi-tenant functionality.
|
||||
// This is separate from licensing - the feature must be explicitly enabled
|
||||
// AND properly licensed for non-default organizations to work.
|
||||
var multiTenantEnabled = strings.EqualFold(os.Getenv("PULSE_MULTI_TENANT_ENABLED"), "true")
|
||||
|
||||
// IsMultiTenantEnabled returns whether multi-tenant functionality is enabled.
|
||||
func IsMultiTenantEnabled() bool {
|
||||
return multiTenantEnabled
|
||||
}
|
||||
|
||||
// DefaultMultiTenantChecker implements websocket.MultiTenantChecker for use with the WebSocket hub.
|
||||
type DefaultMultiTenantChecker struct{}
|
||||
|
||||
// CheckMultiTenant checks if multi-tenant is enabled (feature flag) and licensed for the org.
|
||||
// Uses the LicenseServiceProvider for proper per-tenant license lookup.
|
||||
func (c *DefaultMultiTenantChecker) CheckMultiTenant(ctx context.Context, orgID string) websocket.MultiTenantCheckResult {
|
||||
// Default org is always allowed
|
||||
if orgID == "" || orgID == "default" {
|
||||
return websocket.MultiTenantCheckResult{
|
||||
Allowed: true,
|
||||
FeatureEnabled: true,
|
||||
Licensed: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Check feature flag first
|
||||
if !multiTenantEnabled {
|
||||
return websocket.MultiTenantCheckResult{
|
||||
Allowed: false,
|
||||
FeatureEnabled: false,
|
||||
Licensed: false,
|
||||
Reason: "Multi-tenant functionality is not enabled",
|
||||
}
|
||||
}
|
||||
|
||||
// Feature is enabled, check license using the provider
|
||||
service := getLicenseServiceForContext(ctx)
|
||||
if !service.HasFeature(license.FeatureMultiTenant) {
|
||||
return websocket.MultiTenantCheckResult{
|
||||
Allowed: false,
|
||||
FeatureEnabled: true,
|
||||
Licensed: false,
|
||||
Reason: "Multi-tenant access requires an Enterprise license",
|
||||
}
|
||||
}
|
||||
|
||||
return websocket.MultiTenantCheckResult{
|
||||
Allowed: true,
|
||||
FeatureEnabled: true,
|
||||
Licensed: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMultiTenantChecker creates a new DefaultMultiTenantChecker.
|
||||
func NewMultiTenantChecker() *DefaultMultiTenantChecker {
|
||||
return &DefaultMultiTenantChecker{}
|
||||
}
|
||||
|
||||
// SetMultiTenantEnabled allows programmatic control of the feature flag (for testing).
|
||||
func SetMultiTenantEnabled(enabled bool) {
|
||||
multiTenantEnabled = enabled
|
||||
}
|
||||
|
||||
// LicenseServiceProvider provides license service for a given context.
|
||||
// This allows the middleware to use the properly initialized per-tenant services.
|
||||
type LicenseServiceProvider interface {
|
||||
Service(ctx context.Context) *license.Service
|
||||
}
|
||||
|
||||
var (
|
||||
licenseServiceProvider LicenseServiceProvider
|
||||
licenseServiceMu sync.RWMutex
|
||||
)
|
||||
|
||||
// SetLicenseServiceProvider sets the provider for license services.
|
||||
// This should be called during router initialization with LicenseHandlers.
|
||||
func SetLicenseServiceProvider(provider LicenseServiceProvider) {
|
||||
licenseServiceMu.Lock()
|
||||
defer licenseServiceMu.Unlock()
|
||||
licenseServiceProvider = provider
|
||||
}
|
||||
|
||||
// getLicenseServiceForContext returns the license service for the given context.
|
||||
// Falls back to a new service if no provider is set (shouldn't happen in production).
|
||||
func getLicenseServiceForContext(ctx context.Context) *license.Service {
|
||||
licenseServiceMu.RLock()
|
||||
provider := licenseServiceProvider
|
||||
licenseServiceMu.RUnlock()
|
||||
|
||||
if provider != nil {
|
||||
return provider.Service(ctx)
|
||||
}
|
||||
// Fallback: create a new service (won't have persisted license)
|
||||
return license.NewService()
|
||||
}
|
||||
|
||||
// hasMultiTenantFeatureForContext checks if the multi-tenant feature is licensed for the context.
|
||||
func hasMultiTenantFeatureForContext(ctx context.Context) bool {
|
||||
service := getLicenseServiceForContext(ctx)
|
||||
return service.HasFeature(license.FeatureMultiTenant)
|
||||
}
|
||||
|
||||
// RequireMultiTenant returns a middleware that checks if the multi-tenant feature is licensed.
|
||||
// It allows access to the "default" organization without a license, but requires
|
||||
// an Enterprise license for non-default organizations.
|
||||
func RequireMultiTenant(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := GetOrgID(r.Context())
|
||||
|
||||
// Default org is always allowed (backward compatibility)
|
||||
if orgID == "" || orgID == "default" {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Feature flag check - multi-tenant must be explicitly enabled
|
||||
if !multiTenantEnabled {
|
||||
writeMultiTenantDisabledError(w)
|
||||
return
|
||||
}
|
||||
|
||||
// Non-default orgs require multi-tenant license
|
||||
if !hasMultiTenantFeatureForContext(r.Context()) {
|
||||
writeMultiTenantRequiredError(w)
|
||||
return
|
||||
}
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// RequireMultiTenantHandler returns middleware for http.Handler.
|
||||
func RequireMultiTenantHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := GetOrgID(r.Context())
|
||||
|
||||
// Default org is always allowed (backward compatibility)
|
||||
if orgID == "" || orgID == "default" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Feature flag check - multi-tenant must be explicitly enabled
|
||||
if !multiTenantEnabled {
|
||||
writeMultiTenantDisabledError(w)
|
||||
return
|
||||
}
|
||||
|
||||
// Non-default orgs require multi-tenant license
|
||||
if !hasMultiTenantFeatureForContext(r.Context()) {
|
||||
writeMultiTenantRequiredError(w)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// writeMultiTenantRequiredError writes a 402 Payment Required response
|
||||
// indicating that multi-tenant requires an Enterprise license.
|
||||
func writeMultiTenantRequiredError(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusPaymentRequired)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"error": "license_required",
|
||||
"message": "Multi-tenant access requires an Enterprise license",
|
||||
"feature": license.FeatureMultiTenant,
|
||||
"tier": "enterprise",
|
||||
})
|
||||
}
|
||||
|
||||
// writeMultiTenantDisabledError writes a 501 Not Implemented response
|
||||
// indicating that multi-tenant functionality is not enabled.
|
||||
func writeMultiTenantDisabledError(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"error": "feature_disabled",
|
||||
"message": "Multi-tenant functionality is not enabled. Set PULSE_MULTI_TENANT_ENABLED=true to enable.",
|
||||
})
|
||||
}
|
||||
|
||||
// CheckMultiTenantLicense checks if multi-tenant is licensed for the given org ID.
|
||||
// Returns true if:
|
||||
// - The org ID is "default" or empty (always allowed)
|
||||
// - The feature flag is enabled AND the multi-tenant feature is licensed
|
||||
// Deprecated: Use CheckMultiTenantLicenseWithContext for proper per-tenant license checking.
|
||||
func CheckMultiTenantLicense(orgID string) bool {
|
||||
if orgID == "" || orgID == "default" {
|
||||
return true
|
||||
}
|
||||
// Feature flag must be enabled
|
||||
if !multiTenantEnabled {
|
||||
return false
|
||||
}
|
||||
// Without context, we can't look up the per-tenant license service properly.
|
||||
// Fall back to a new service (won't have persisted license).
|
||||
return license.NewService().HasFeature(license.FeatureMultiTenant)
|
||||
}
|
||||
|
||||
// CheckMultiTenantLicenseWithContext checks if multi-tenant is enabled and licensed
|
||||
// using the proper per-tenant license service from the context.
|
||||
// Returns true if:
|
||||
// - The org ID is "default" or empty (always allowed)
|
||||
// - The feature flag is enabled AND the multi-tenant feature is licensed
|
||||
func CheckMultiTenantLicenseWithContext(ctx context.Context, orgID string) bool {
|
||||
if orgID == "" || orgID == "default" {
|
||||
return true
|
||||
}
|
||||
// Feature flag must be enabled
|
||||
if !multiTenantEnabled {
|
||||
return false
|
||||
}
|
||||
return hasMultiTenantFeatureForContext(ctx)
|
||||
}
|
||||
@@ -2,29 +2,53 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
||||
"github.com/rcourtman/pulse-go-rewrite/pkg/auth"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type OrganizationContextKey string
|
||||
|
||||
const (
|
||||
OrgIDContextKey OrganizationContextKey = "org_id"
|
||||
OrgContextKey OrganizationContextKey = "org_object"
|
||||
OrgIDContextKey OrganizationContextKey = "org_id"
|
||||
OrgContextKey OrganizationContextKey = "org_object"
|
||||
APITokenContextKey OrganizationContextKey = "api_token_record"
|
||||
)
|
||||
|
||||
// TenantMiddleware extracts the organization ID from the request and
|
||||
// sets up the context for multi-tenant isolation.
|
||||
type TenantMiddleware struct {
|
||||
persistence *config.MultiTenantPersistence
|
||||
authChecker AuthorizationChecker
|
||||
}
|
||||
|
||||
// TenantMiddlewareConfig holds configuration for the tenant middleware.
|
||||
type TenantMiddlewareConfig struct {
|
||||
Persistence *config.MultiTenantPersistence
|
||||
AuthChecker AuthorizationChecker
|
||||
}
|
||||
|
||||
func NewTenantMiddleware(p *config.MultiTenantPersistence) *TenantMiddleware {
|
||||
return &TenantMiddleware{persistence: p}
|
||||
}
|
||||
|
||||
// NewTenantMiddlewareWithConfig creates a new TenantMiddleware with full configuration.
|
||||
func NewTenantMiddlewareWithConfig(cfg TenantMiddlewareConfig) *TenantMiddleware {
|
||||
return &TenantMiddleware{
|
||||
persistence: cfg.Persistence,
|
||||
authChecker: cfg.AuthChecker,
|
||||
}
|
||||
}
|
||||
|
||||
// SetAuthChecker sets the authorization checker for the middleware.
|
||||
func (m *TenantMiddleware) SetAuthChecker(checker AuthorizationChecker) {
|
||||
m.authChecker = checker
|
||||
}
|
||||
|
||||
func (m *TenantMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 1. Extract Org ID
|
||||
@@ -46,19 +70,74 @@ func (m *TenantMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
orgID = "default"
|
||||
}
|
||||
|
||||
// 2. Validate/Load Organization
|
||||
// In a real implementation, we would check if the user has access to this org.
|
||||
// For Phase 1 (Persistence), we just ensure the org is valid in the persistence layer.
|
||||
|
||||
// Ensure the organization persistence is initialized
|
||||
// This creates the directory if it doesn't exist for valid IDs
|
||||
_, err := m.persistence.GetPersistence(orgID)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid Organization ID", http.StatusBadRequest)
|
||||
return
|
||||
// 2. Validate Organization Exists (only for non-default orgs)
|
||||
// Default org is always valid for backward compatibility
|
||||
if orgID != "default" && m.persistence != nil {
|
||||
_, err := m.persistence.GetPersistence(orgID)
|
||||
if err != nil {
|
||||
writeJSONError(w, http.StatusBadRequest, "invalid_org", "Invalid Organization ID")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Inject into Context
|
||||
// 2.5 Feature flag and License Check for multi-tenant access
|
||||
// Non-default orgs require:
|
||||
// 1. Feature flag enabled (PULSE_MULTI_TENANT_ENABLED=true) - returns 501 if disabled
|
||||
// 2. Enterprise license - returns 402 if unlicensed
|
||||
if orgID != "default" {
|
||||
// Check feature flag first - 501 Not Implemented if disabled
|
||||
if !IsMultiTenantEnabled() {
|
||||
writeMultiTenantDisabledError(w)
|
||||
return
|
||||
}
|
||||
// Feature is enabled, check license - 402 Payment Required if unlicensed
|
||||
checkCtx := context.WithValue(r.Context(), OrgIDContextKey, orgID)
|
||||
if !hasMultiTenantFeatureForContext(checkCtx) {
|
||||
writeMultiTenantRequiredError(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Authorization Check
|
||||
// Check if the authenticated user/token is allowed to access this organization
|
||||
// Note: This runs AFTER AuthContextMiddleware, so auth context is available
|
||||
if m.authChecker != nil && orgID != "default" {
|
||||
// Get API token from context (set by AuthContextMiddleware)
|
||||
var token *config.APITokenRecord
|
||||
if tokenVal := auth.GetAPIToken(r.Context()); tokenVal != nil {
|
||||
if t, ok := tokenVal.(*config.APITokenRecord); ok {
|
||||
token = t
|
||||
}
|
||||
}
|
||||
|
||||
// Get user ID from context (set by AuthContextMiddleware)
|
||||
userID := auth.GetUser(r.Context())
|
||||
|
||||
// Only perform authorization check if we have auth context
|
||||
// If no auth context, the route's RequireAuth will handle authentication errors
|
||||
if token != nil || userID != "" {
|
||||
// Perform authorization check using the interface method
|
||||
result := m.authChecker.CheckAccess(token, userID, orgID)
|
||||
if !result.Allowed {
|
||||
log.Warn().
|
||||
Str("org_id", orgID).
|
||||
Str("user_id", userID).
|
||||
Str("reason", result.Reason).
|
||||
Msg("Unauthorized access attempt to organization")
|
||||
writeJSONError(w, http.StatusForbidden, "access_denied", result.Reason)
|
||||
return
|
||||
}
|
||||
|
||||
// Log warning for legacy tokens accessing non-default orgs
|
||||
if result.IsLegacyToken {
|
||||
log.Warn().
|
||||
Str("org_id", orgID).
|
||||
Msg("Legacy token with wildcard access used - consider binding to specific org")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Inject into Context
|
||||
ctx := context.WithValue(r.Context(), OrgIDContextKey, orgID)
|
||||
|
||||
// Also store a mock organization object for now
|
||||
@@ -69,6 +148,16 @@ func (m *TenantMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// writeJSONError writes a JSON error response.
|
||||
func writeJSONError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": code,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// Helper to get OrgID from context
|
||||
func GetOrgID(ctx context.Context) string {
|
||||
if id, ok := ctx.Value(OrgIDContextKey).(string); ok {
|
||||
|
||||
@@ -1,18 +1,27 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/license"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTenantMiddleware(t *testing.T) {
|
||||
prevMultiTenant := IsMultiTenantEnabled()
|
||||
t.Cleanup(func() {
|
||||
SetMultiTenantEnabled(prevMultiTenant)
|
||||
SetLicenseServiceProvider(nil)
|
||||
})
|
||||
|
||||
// Setup temporary directory for testing
|
||||
tmpDir, err := os.MkdirTemp("", "pulse-tenant-test-*")
|
||||
require.NoError(t, err)
|
||||
@@ -40,6 +49,7 @@ func TestTenantMiddleware(t *testing.T) {
|
||||
handler := middleware.Middleware(testHandler)
|
||||
|
||||
t.Run("Default Org (No Header)", func(t *testing.T) {
|
||||
SetMultiTenantEnabled(false)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -48,12 +58,57 @@ func TestTenantMiddleware(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "OrgID: default", rec.Body.String())
|
||||
|
||||
// Verify default directory was created
|
||||
_, err := os.Stat(filepath.Join(tmpDir, "orgs", "default"))
|
||||
assert.NoError(t, err)
|
||||
// Default org no longer initializes tenant persistence; no directory expectation.
|
||||
})
|
||||
|
||||
t.Run("Custom Org (Header)", func(t *testing.T) {
|
||||
t.Run("Custom Org (Feature Disabled)", func(t *testing.T) {
|
||||
SetMultiTenantEnabled(false)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "customer-a")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotImplemented, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Multi-tenant functionality is not enabled")
|
||||
})
|
||||
|
||||
t.Run("Custom Org (Feature Enabled, Unlicensed)", func(t *testing.T) {
|
||||
SetMultiTenantEnabled(true)
|
||||
SetLicenseServiceProvider(nil)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "customer-a")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusPaymentRequired, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Multi-tenant access requires an Enterprise license")
|
||||
})
|
||||
|
||||
t.Run("Custom Org (Feature Enabled, Licensed)", func(t *testing.T) {
|
||||
SetMultiTenantEnabled(true)
|
||||
// Enable license dev mode for test keys
|
||||
prevDevMode := os.Getenv("PULSE_LICENSE_DEV_MODE")
|
||||
os.Setenv("PULSE_LICENSE_DEV_MODE", "true")
|
||||
t.Cleanup(func() {
|
||||
if prevDevMode == "" {
|
||||
os.Unsetenv("PULSE_LICENSE_DEV_MODE")
|
||||
} else {
|
||||
os.Setenv("PULSE_LICENSE_DEV_MODE", prevDevMode)
|
||||
}
|
||||
})
|
||||
license.SetPublicKey(nil)
|
||||
|
||||
licenseKey, err := license.GenerateLicenseForTesting("test@example.com", license.TierEnterprise, 24*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
service := license.NewService()
|
||||
_, err = service.Activate(licenseKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
SetLicenseServiceProvider(staticLicenseProvider{svc: service})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "customer-a")
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -64,11 +119,12 @@ func TestTenantMiddleware(t *testing.T) {
|
||||
assert.Equal(t, "OrgID: customer-a", rec.Body.String())
|
||||
|
||||
// Verify custom directory was created
|
||||
_, err := os.Stat(filepath.Join(tmpDir, "orgs", "customer-a"))
|
||||
_, err = os.Stat(filepath.Join(tmpDir, "orgs", "customer-a"))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid Org ID (Directory Traversal Attempt)", func(t *testing.T) {
|
||||
SetMultiTenantEnabled(false)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "../../../etc/passwd")
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -78,3 +134,11 @@ func TestTenantMiddleware(t *testing.T) {
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
})
|
||||
}
|
||||
|
||||
type staticLicenseProvider struct {
|
||||
svc *license.Service
|
||||
}
|
||||
|
||||
func (p staticLicenseProvider) Service(ctx context.Context) *license.Service {
|
||||
return p.svc
|
||||
}
|
||||
|
||||
@@ -75,12 +75,32 @@ func (mtp *MultiTenantPersistence) GetPersistence(orgID string) (*ConfigPersiste
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
// LoadOrganizationMetadata loads basic metadata for an organization.
|
||||
// This is separate from the tenant's internal config.
|
||||
// LoadOrganization loads the organization metadata including members.
|
||||
// Org metadata is stored in <orgDir>/org.json.
|
||||
func (mtp *MultiTenantPersistence) LoadOrganization(orgID string) (*models.Organization, error) {
|
||||
// TODO: implementing organization metadata storage in system.json later
|
||||
return &models.Organization{
|
||||
ID: orgID,
|
||||
DisplayName: orgID, // Placeholder
|
||||
}, nil
|
||||
persistence, err := mtp.GetPersistence(orgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
org, err := persistence.LoadOrganization()
|
||||
if err != nil {
|
||||
// If org.json doesn't exist, return a default org
|
||||
return &models.Organization{
|
||||
ID: orgID,
|
||||
DisplayName: orgID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return org, nil
|
||||
}
|
||||
|
||||
// SaveOrganization saves the organization metadata.
|
||||
func (mtp *MultiTenantPersistence) SaveOrganization(org *models.Organization) error {
|
||||
persistence, err := mtp.GetPersistence(org.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return persistence.SaveOrganization(org)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,33 @@ package models
|
||||
|
||||
import "time"
|
||||
|
||||
// OrganizationRole represents a user's role within an organization.
|
||||
type OrganizationRole string
|
||||
|
||||
const (
|
||||
// OrgRoleOwner has full access and can manage all aspects of the organization.
|
||||
OrgRoleOwner OrganizationRole = "owner"
|
||||
// OrgRoleAdmin can manage resources but cannot delete the organization.
|
||||
OrgRoleAdmin OrganizationRole = "admin"
|
||||
// OrgRoleMember has read-only access to organization resources.
|
||||
OrgRoleMember OrganizationRole = "member"
|
||||
)
|
||||
|
||||
// OrganizationMember represents a user's membership in an organization.
|
||||
type OrganizationMember struct {
|
||||
// UserID is the unique identifier of the member.
|
||||
UserID string `json:"userId"`
|
||||
|
||||
// Role is the member's role within the organization.
|
||||
Role OrganizationRole `json:"role"`
|
||||
|
||||
// AddedAt is when the member was added to the organization.
|
||||
AddedAt time.Time `json:"addedAt"`
|
||||
|
||||
// AddedBy is the user ID of who added this member (empty for owner).
|
||||
AddedBy string `json:"addedBy,omitempty"`
|
||||
}
|
||||
|
||||
// Organization represents a distinct tenant in the system.
|
||||
type Organization struct {
|
||||
// ID is the unique identifier for the organization (e.g., "customer-a").
|
||||
@@ -14,7 +41,58 @@ type Organization struct {
|
||||
// CreatedAt is when the organization was registered.
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
|
||||
// OwnerUserID is the primary owner of this organization.
|
||||
// The owner has full administrative rights and cannot be removed.
|
||||
OwnerUserID string `json:"ownerUserId,omitempty"`
|
||||
|
||||
// Members is the list of users who have access to this organization.
|
||||
// This includes the owner (with OrgRoleOwner) and any additional members.
|
||||
Members []OrganizationMember `json:"members,omitempty"`
|
||||
|
||||
// EncryptionKeyID refers to the specific encryption key used for this org's data
|
||||
// (Future proofing for per-tenant encryption keys)
|
||||
EncryptionKeyID string `json:"encryptionKeyId,omitempty"`
|
||||
}
|
||||
|
||||
// HasMember checks if a user is a member of the organization.
|
||||
func (o *Organization) HasMember(userID string) bool {
|
||||
for _, member := range o.Members {
|
||||
if member.UserID == userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetMemberRole returns the role of a user in the organization.
|
||||
// Returns empty string if the user is not a member.
|
||||
func (o *Organization) GetMemberRole(userID string) OrganizationRole {
|
||||
for _, member := range o.Members {
|
||||
if member.UserID == userID {
|
||||
return member.Role
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsOwner checks if a user is the owner of the organization.
|
||||
func (o *Organization) IsOwner(userID string) bool {
|
||||
return o.OwnerUserID == userID
|
||||
}
|
||||
|
||||
// CanUserAccess checks if a user has any level of access to the organization.
|
||||
func (o *Organization) CanUserAccess(userID string) bool {
|
||||
if o.OwnerUserID == userID {
|
||||
return true
|
||||
}
|
||||
return o.HasMember(userID)
|
||||
}
|
||||
|
||||
// CanUserManage checks if a user can manage the organization (owner or admin).
|
||||
func (o *Organization) CanUserManage(userID string) bool {
|
||||
if o.OwnerUserID == userID {
|
||||
return true
|
||||
}
|
||||
role := o.GetMemberRole(userID)
|
||||
return role == OrgRoleOwner || role == OrgRoleAdmin
|
||||
}
|
||||
|
||||
@@ -57,10 +57,16 @@ func (mtm *MultiTenantMonitor) GetMonitor(orgID string) (*Monitor, error) {
|
||||
log.Info().Str("org_id", orgID).Msg("Initializing tenant monitor")
|
||||
|
||||
// 1. Load Tenant Config
|
||||
// We need a specific config for this tenant.
|
||||
// For now, we clone the base config (assuming shared defaults)
|
||||
// In the future, we'll load overrides from persistence.GetPersistence(orgID)
|
||||
tenantConfig := *mtm.baseConfig // Shallow copy
|
||||
// Deep copy the base config to ensure tenant isolation.
|
||||
// Each tenant gets its own independent config that won't share
|
||||
// credential slices or other mutable state with other tenants.
|
||||
tenantConfig := mtm.baseConfig.DeepCopy()
|
||||
|
||||
// Clear inherited credentials - tenants must load their own
|
||||
// This prevents credential leakage between tenants
|
||||
tenantConfig.PVEInstances = nil
|
||||
tenantConfig.PBSInstances = nil
|
||||
tenantConfig.PMGInstances = nil
|
||||
|
||||
// Ensure the DataPath is correct for this tenant to isolate storage (sqlite, etc)
|
||||
tenantPersistence, err := mtm.persistence.GetPersistence(orgID)
|
||||
@@ -69,13 +75,34 @@ func (mtm *MultiTenantMonitor) GetMonitor(orgID string) (*Monitor, error) {
|
||||
}
|
||||
tenantConfig.DataPath = tenantPersistence.GetConfigDir()
|
||||
|
||||
// Load tenant-specific nodes from <orgDir>/nodes.enc
|
||||
nodesConfig, err := tenantPersistence.LoadNodesConfig()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("org_id", orgID).Msg("Failed to load tenant nodes config, starting with empty config")
|
||||
// Not a fatal error - tenant may not have configured any nodes yet
|
||||
} else if nodesConfig != nil {
|
||||
tenantConfig.PVEInstances = nodesConfig.PVEInstances
|
||||
tenantConfig.PBSInstances = nodesConfig.PBSInstances
|
||||
tenantConfig.PMGInstances = nodesConfig.PMGInstances
|
||||
log.Info().
|
||||
Str("org_id", orgID).
|
||||
Int("pve_count", len(nodesConfig.PVEInstances)).
|
||||
Int("pbs_count", len(nodesConfig.PBSInstances)).
|
||||
Int("pmg_count", len(nodesConfig.PMGInstances)).
|
||||
Msg("Loaded tenant nodes config")
|
||||
}
|
||||
|
||||
// 2. Create Monitor
|
||||
// Usage of internal New constructor
|
||||
monitor, err = New(&tenantConfig)
|
||||
monitor, err = New(tenantConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create monitor for org %s: %w", orgID, err)
|
||||
}
|
||||
|
||||
// Set org ID for tenant isolation
|
||||
// This enables tenant-scoped WebSocket broadcasts
|
||||
monitor.SetOrgID(orgID)
|
||||
|
||||
// 3. Start Monitor
|
||||
// We pass the global context, but maybe we should give it a derived one?
|
||||
// Using globalCtx ensures all monitors stop when MultiTenantMonitor stops.
|
||||
|
||||
60
internal/websocket/hub_multitenant_test.go
Normal file
60
internal/websocket/hub_multitenant_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeMultiTenantChecker struct {
|
||||
result MultiTenantCheckResult
|
||||
}
|
||||
|
||||
func (f fakeMultiTenantChecker) CheckMultiTenant(ctx context.Context, orgID string) MultiTenantCheckResult {
|
||||
return f.result
|
||||
}
|
||||
|
||||
func TestHandleWebSocket_MultiTenantDisabled(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
hub.SetMultiTenantChecker(fakeMultiTenantChecker{
|
||||
result: MultiTenantCheckResult{
|
||||
Allowed: false,
|
||||
FeatureEnabled: false,
|
||||
Licensed: false,
|
||||
Reason: "disabled",
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "tenant-a")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
hub.HandleWebSocket(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusNotImplemented, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocket_MultiTenantUnlicensed(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
hub.SetMultiTenantChecker(fakeMultiTenantChecker{
|
||||
result: MultiTenantCheckResult{
|
||||
Allowed: false,
|
||||
FeatureEnabled: true,
|
||||
Licensed: false,
|
||||
Reason: "unlicensed",
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil)
|
||||
req.Header.Set("X-Pulse-Org-ID", "tenant-a")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
hub.HandleWebSocket(rec, req)
|
||||
|
||||
if rec.Code != http.StatusPaymentRequired {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusPaymentRequired, rec.Code)
|
||||
}
|
||||
}
|
||||
171
pkg/audit/tenant_logger.go
Normal file
171
pkg/audit/tenant_logger.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// TenantLoggerManager manages per-tenant audit loggers.
|
||||
// Each tenant gets their own isolated audit database at <orgDir>/audit.db
|
||||
type TenantLoggerManager struct {
|
||||
mu sync.RWMutex
|
||||
loggers map[string]Logger
|
||||
dataPath string // Base data path
|
||||
factory LoggerFactory // Factory for creating tenant loggers
|
||||
}
|
||||
|
||||
// LoggerFactory creates audit loggers for specific paths.
|
||||
type LoggerFactory interface {
|
||||
// CreateLogger creates a new audit logger at the specified path.
|
||||
CreateLogger(dbPath string) (Logger, error)
|
||||
}
|
||||
|
||||
// DefaultLoggerFactory creates console loggers (for OSS).
|
||||
type DefaultLoggerFactory struct{}
|
||||
|
||||
// CreateLogger creates a console logger (doesn't use the path).
|
||||
func (f *DefaultLoggerFactory) CreateLogger(dbPath string) (Logger, error) {
|
||||
return NewConsoleLogger(), nil
|
||||
}
|
||||
|
||||
// NewTenantLoggerManager creates a new tenant logger manager.
|
||||
func NewTenantLoggerManager(dataPath string, factory LoggerFactory) *TenantLoggerManager {
|
||||
if factory == nil {
|
||||
factory = &DefaultLoggerFactory{}
|
||||
}
|
||||
return &TenantLoggerManager{
|
||||
loggers: make(map[string]Logger),
|
||||
dataPath: dataPath,
|
||||
factory: factory,
|
||||
}
|
||||
}
|
||||
|
||||
// GetLogger returns the audit logger for a specific organization.
|
||||
// It lazily initializes the logger if it doesn't exist.
|
||||
// For the "default" org, it returns the global logger.
|
||||
func (m *TenantLoggerManager) GetLogger(orgID string) Logger {
|
||||
// Default org uses the global logger
|
||||
if orgID == "" || orgID == "default" {
|
||||
return GetLogger()
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
logger, exists := m.loggers[orgID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return logger
|
||||
}
|
||||
|
||||
// Create new logger for tenant
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if logger, exists = m.loggers[orgID]; exists {
|
||||
return logger
|
||||
}
|
||||
|
||||
// Create tenant-specific logger
|
||||
dbPath := filepath.Join(m.dataPath, "orgs", orgID, "audit.db")
|
||||
logger, err := m.factory.CreateLogger(dbPath)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("org_id", orgID).
|
||||
Str("db_path", dbPath).
|
||||
Msg("Failed to create tenant audit logger, using console logger")
|
||||
logger = NewConsoleLogger()
|
||||
}
|
||||
|
||||
m.loggers[orgID] = logger
|
||||
log.Info().
|
||||
Str("org_id", orgID).
|
||||
Str("db_path", dbPath).
|
||||
Msg("Created tenant audit logger")
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// Log logs an audit event for a specific organization.
|
||||
func (m *TenantLoggerManager) Log(orgID, eventType, user, ip, path string, success bool, details string) error {
|
||||
logger := m.GetLogger(orgID)
|
||||
event := Event{
|
||||
EventType: eventType,
|
||||
User: user,
|
||||
IP: ip,
|
||||
Path: path,
|
||||
Success: success,
|
||||
Details: details,
|
||||
}
|
||||
return logger.Log(event)
|
||||
}
|
||||
|
||||
// Query queries audit events for a specific organization.
|
||||
func (m *TenantLoggerManager) Query(orgID string, filter QueryFilter) ([]Event, error) {
|
||||
logger := m.GetLogger(orgID)
|
||||
return logger.Query(filter)
|
||||
}
|
||||
|
||||
// Count counts audit events for a specific organization.
|
||||
func (m *TenantLoggerManager) Count(orgID string, filter QueryFilter) (int, error) {
|
||||
logger := m.GetLogger(orgID)
|
||||
return logger.Count(filter)
|
||||
}
|
||||
|
||||
// Close closes all tenant loggers.
|
||||
func (m *TenantLoggerManager) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for orgID, logger := range m.loggers {
|
||||
if closer, ok := logger.(interface{ Close() error }); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("org_id", orgID).
|
||||
Msg("Failed to close tenant audit logger")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.loggers = make(map[string]Logger)
|
||||
}
|
||||
|
||||
// GetAllLoggers returns all initialized loggers (for administrative purposes).
|
||||
func (m *TenantLoggerManager) GetAllLoggers() map[string]Logger {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]Logger, len(m.loggers))
|
||||
for k, v := range m.loggers {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// RemoveTenantLogger removes a specific tenant's logger.
|
||||
// Useful when an organization is deleted.
|
||||
func (m *TenantLoggerManager) RemoveTenantLogger(orgID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
logger, exists := m.loggers[orgID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
if closer, ok := logger.(interface{ Close() error }); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("org_id", orgID).
|
||||
Msg("Failed to close tenant audit logger during removal")
|
||||
}
|
||||
}
|
||||
|
||||
delete(m.loggers, orgID)
|
||||
log.Info().Str("org_id", orgID).Msg("Removed tenant audit logger")
|
||||
}
|
||||
Reference in New Issue
Block a user