mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-18 00:17:39 +01:00
feat: remove Enterprise badges, simplify Pro upgrade prompts
- Replace barrel import in AuditLogPanel.tsx to fix ad-blocker crash - Remove all Enterprise/Pro badges from nav and feature headers - Simplify upgrade CTAs to clean 'Upgrade to Pro' links - Update docs: PULSE_PRO.md, API.md, README.md, SECURITY.md - Align terminology: single Pro tier, no separate Enterprise tier Also includes prior refactoring: - Move auth package to pkg/auth for enterprise reuse - Export server functions for testability - Stabilize CLI tests
This commit is contained in:
265
pkg/auth/auth_test.go
Normal file
265
pkg/auth/auth_test.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateAPIToken(t *testing.T) {
|
||||
token, err := GenerateAPIToken()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAPIToken() error: %v", err)
|
||||
}
|
||||
|
||||
// Should be 64 hex characters (32 bytes)
|
||||
if len(token) != 64 {
|
||||
t.Errorf("GenerateAPIToken() length = %d, want 64", len(token))
|
||||
}
|
||||
|
||||
// Should be valid hex
|
||||
for _, c := range token {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("GenerateAPIToken() contains invalid hex character: %c", c)
|
||||
}
|
||||
}
|
||||
|
||||
// Should generate unique tokens
|
||||
token2, err := GenerateAPIToken()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAPIToken() second call error: %v", err)
|
||||
}
|
||||
if token == token2 {
|
||||
t.Error("GenerateAPIToken() generated duplicate tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashAPIToken(t *testing.T) {
|
||||
token := "test-token-12345"
|
||||
hash := HashAPIToken(token)
|
||||
|
||||
// Should be 64 hex characters (SHA3-256)
|
||||
if len(hash) != 64 {
|
||||
t.Errorf("HashAPIToken() length = %d, want 64", len(hash))
|
||||
}
|
||||
|
||||
// Should be deterministic
|
||||
hash2 := HashAPIToken(token)
|
||||
if hash != hash2 {
|
||||
t.Error("HashAPIToken() is not deterministic")
|
||||
}
|
||||
|
||||
// Different tokens should produce different hashes
|
||||
hash3 := HashAPIToken("different-token")
|
||||
if hash == hash3 {
|
||||
t.Error("HashAPIToken() produced same hash for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareAPIToken(t *testing.T) {
|
||||
token := "my-secret-token-abc123"
|
||||
hash := HashAPIToken(token)
|
||||
|
||||
// Correct token should match
|
||||
if !CompareAPIToken(token, hash) {
|
||||
t.Error("CompareAPIToken() returned false for correct token")
|
||||
}
|
||||
|
||||
// Wrong token should not match
|
||||
if CompareAPIToken("wrong-token", hash) {
|
||||
t.Error("CompareAPIToken() returned true for wrong token")
|
||||
}
|
||||
|
||||
// Empty token should not match
|
||||
if CompareAPIToken("", hash) {
|
||||
t.Error("CompareAPIToken() returned true for empty token")
|
||||
}
|
||||
|
||||
// Token against wrong hash should not match
|
||||
if CompareAPIToken(token, "0000000000000000000000000000000000000000000000000000000000000000") {
|
||||
t.Error("CompareAPIToken() returned true for wrong hash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareAPIToken_TimingSafe(t *testing.T) {
|
||||
// This test verifies the comparison uses constant-time comparison
|
||||
// by checking that it uses subtle.ConstantTimeCompare internally
|
||||
// (verified by code inspection - this test just ensures the function works)
|
||||
token := "timing-test-token"
|
||||
hash := HashAPIToken(token)
|
||||
|
||||
// Multiple comparisons should all succeed
|
||||
for i := 0; i < 100; i++ {
|
||||
if !CompareAPIToken(token, hash) {
|
||||
t.Errorf("CompareAPIToken() failed on iteration %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAPITokenHashed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid 64-char hex",
|
||||
input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "actual hash",
|
||||
input: HashAPIToken("test"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "too short",
|
||||
input: "0123456789abcdef",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef00",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid hex characters",
|
||||
input: "ghijklmnopqrstuv0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "plain token (not hashed)",
|
||||
input: "my-api-token",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "uppercase hex",
|
||||
input: "0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := IsAPITokenHashed(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("IsAPITokenHashed(%q) = %v, want %v", tc.input, result, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
password := "mysecretpassword123"
|
||||
|
||||
hash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error: %v", err)
|
||||
}
|
||||
|
||||
// Should be a bcrypt hash (starts with $2)
|
||||
if !strings.HasPrefix(hash, "$2") {
|
||||
t.Errorf("HashPassword() did not produce bcrypt hash, got: %s", hash[:10])
|
||||
}
|
||||
|
||||
// Should be 60 characters
|
||||
if len(hash) != 60 {
|
||||
t.Errorf("HashPassword() length = %d, want 60", len(hash))
|
||||
}
|
||||
|
||||
// Same password should produce different hashes (due to salt)
|
||||
hash2, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() second call error: %v", err)
|
||||
}
|
||||
if hash == hash2 {
|
||||
t.Error("HashPassword() produced same hash twice (missing salt?)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPasswordHash(t *testing.T) {
|
||||
password := "testpassword456"
|
||||
hash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error: %v", err)
|
||||
}
|
||||
|
||||
// Correct password should verify
|
||||
if !CheckPasswordHash(password, hash) {
|
||||
t.Error("CheckPasswordHash() returned false for correct password")
|
||||
}
|
||||
|
||||
// Wrong password should not verify
|
||||
if CheckPasswordHash("wrongpassword", hash) {
|
||||
t.Error("CheckPasswordHash() returned true for wrong password")
|
||||
}
|
||||
|
||||
// Empty password should not verify
|
||||
if CheckPasswordHash("", hash) {
|
||||
t.Error("CheckPasswordHash() returned true for empty password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePasswordComplexity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid long password",
|
||||
password: "thisisaverylongpassword",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "exactly minimum length",
|
||||
password: "123456789012", // 12 chars
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "too short",
|
||||
password: "12345678901", // 11 chars
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
password: "",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "single character",
|
||||
password: "a",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidatePasswordComplexity(tc.password)
|
||||
if tc.wantError && err == nil {
|
||||
t.Errorf("ValidatePasswordComplexity(%q) expected error, got nil", tc.password)
|
||||
}
|
||||
if !tc.wantError && err != nil {
|
||||
t.Errorf("ValidatePasswordComplexity(%q) unexpected error: %v", tc.password, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBcryptCost(t *testing.T) {
|
||||
// Verify bcrypt cost is reasonable (10-14 is typical)
|
||||
if BcryptCost < 10 || BcryptCost > 14 {
|
||||
t.Errorf("BcryptCost = %d, should be between 10 and 14 for reasonable security/performance", BcryptCost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinPasswordLength(t *testing.T) {
|
||||
// Verify minimum password length is reasonable (at least 8, ideally 12+)
|
||||
if MinPasswordLength < 8 {
|
||||
t.Errorf("MinPasswordLength = %d, should be at least 8", MinPasswordLength)
|
||||
}
|
||||
}
|
||||
65
pkg/auth/authorizer.go
Normal file
65
pkg/auth/authorizer.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
|
||||
// Authorizer defines the interface for making access control decisions.
|
||||
type Authorizer interface {
|
||||
// Authorize checks if a subject (from context) can perform an action on a resource.
|
||||
// Returns true if allowed, false if denied, and an error if the check failed due to a system issue.
|
||||
Authorize(ctx context.Context, action string, resource string) (bool, error)
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
contextKeyUser contextKey = "user"
|
||||
)
|
||||
|
||||
// WithUser adds a username to the context
|
||||
func WithUser(ctx context.Context, username string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyUser, username)
|
||||
}
|
||||
|
||||
// GetUser extracts the username from the context
|
||||
func GetUser(ctx context.Context) string {
|
||||
if user, ok := ctx.Value(contextKeyUser).(string); ok {
|
||||
return user
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// DefaultAuthorizer is a pass-through implementation that allows everything.
|
||||
// Used in OSS version and when enterprise features are disabled.
|
||||
type DefaultAuthorizer struct{}
|
||||
|
||||
func (d *DefaultAuthorizer) Authorize(ctx context.Context, action string, resource string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var globalAuthorizer Authorizer = &DefaultAuthorizer{}
|
||||
|
||||
// SetAuthorizer sets the global authorizer instance.
|
||||
// This is used by pulse-enterprise to register the real RBAC implementation.
|
||||
func SetAuthorizer(auth Authorizer) {
|
||||
globalAuthorizer = auth
|
||||
}
|
||||
|
||||
// AdminConfigurable is an optional interface for authorizers that can have an admin user set.
|
||||
type AdminConfigurable interface {
|
||||
SetAdminUser(username string)
|
||||
}
|
||||
|
||||
// SetAdminUser sets the admin user on the global authorizer if it supports it.
|
||||
func SetAdminUser(username string) {
|
||||
if username == "" {
|
||||
return
|
||||
}
|
||||
if configurable, ok := globalAuthorizer.(AdminConfigurable); ok {
|
||||
configurable.SetAdminUser(username)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthorizer returns the global authorizer instance.
|
||||
func GetAuthorizer() Authorizer {
|
||||
return globalAuthorizer
|
||||
}
|
||||
31
pkg/auth/coverage_test.go
Normal file
31
pkg/auth/coverage_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHashPassword_Error(t *testing.T) {
|
||||
// bcrypt has a max length limit (usually 72 bytes).
|
||||
// Passing a very long password should trigger an error.
|
||||
longPassword := strings.Repeat("A", 80)
|
||||
_, err := HashPassword(longPassword)
|
||||
if err == nil {
|
||||
t.Error("HashPassword() expected error for long password, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAPIToken_Error(t *testing.T) {
|
||||
originalRandRead := randRead
|
||||
defer func() { randRead = originalRandRead }()
|
||||
|
||||
randRead = func(b []byte) (n int, err error) {
|
||||
return 0, errors.New("forced error")
|
||||
}
|
||||
|
||||
_, err := GenerateAPIToken()
|
||||
if err == nil {
|
||||
t.Error("GenerateAPIToken() expected error when rand.Read fails, got nil")
|
||||
}
|
||||
}
|
||||
42
pkg/auth/password.go
Normal file
42
pkg/auth/password.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// BcryptCost is the cost factor for bcrypt hashing
|
||||
// Higher values are more secure but slower
|
||||
BcryptCost = 12
|
||||
|
||||
// MinPasswordLength is the minimum required password length
|
||||
MinPasswordLength = 12
|
||||
)
|
||||
|
||||
// HashPassword generates a bcrypt hash from a plain text password
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
// CheckPasswordHash compares a plain text password with a hash
|
||||
func CheckPasswordHash(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ValidatePasswordComplexity checks if a password meets complexity requirements
|
||||
func ValidatePasswordComplexity(password string) error {
|
||||
if len(password) < MinPasswordLength {
|
||||
return fmt.Errorf("password must be at least %d characters long", MinPasswordLength)
|
||||
}
|
||||
|
||||
// Let users choose their own passwords beyond length.
|
||||
// No character type requirements.
|
||||
return nil
|
||||
}
|
||||
18
pkg/auth/permissions.go
Normal file
18
pkg/auth/permissions.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package auth
|
||||
|
||||
// Standard Actions
|
||||
const (
|
||||
ActionRead = "read"
|
||||
ActionWrite = "write"
|
||||
ActionDelete = "delete"
|
||||
ActionAdmin = "admin"
|
||||
)
|
||||
|
||||
// Standard Resources
|
||||
const (
|
||||
ResourceSettings = "settings"
|
||||
ResourceAuditLogs = "audit_logs"
|
||||
ResourceNodes = "nodes"
|
||||
ResourceUsers = "users"
|
||||
ResourceLicense = "license"
|
||||
)
|
||||
45
pkg/auth/token.go
Normal file
45
pkg/auth/token.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
// randRead is a variable to allow mocking in tests
|
||||
var randRead = rand.Read
|
||||
|
||||
// GenerateAPIToken generates a secure random API token
|
||||
func GenerateAPIToken() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := randRead(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// HashAPIToken creates a one-way hash of an API token for storage
|
||||
// We use SHA3-256 for API tokens since we need to compare exact values
|
||||
func HashAPIToken(token string) string {
|
||||
hash := sha3.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// CompareAPIToken compares a provided token with a stored hash
|
||||
func CompareAPIToken(token, hash string) bool {
|
||||
tokenHash := HashAPIToken(token)
|
||||
return subtle.ConstantTimeCompare([]byte(tokenHash), []byte(hash)) == 1
|
||||
}
|
||||
|
||||
// IsAPITokenHashed checks if a string looks like a hashed API token
|
||||
func IsAPITokenHashed(token string) bool {
|
||||
// SHA3-256 produces 64 character hex strings
|
||||
if len(token) != 64 {
|
||||
return false
|
||||
}
|
||||
// Check if it's valid hex
|
||||
_, err := hex.DecodeString(token)
|
||||
return err == nil
|
||||
}
|
||||
390
pkg/server/server.go
Normal file
390
pkg/server/server.go
Normal file
@@ -0,0 +1,390 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/agentbinaries"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/api"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/license"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/logging"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/metrics"
|
||||
_ "github.com/rcourtman/pulse-go-rewrite/internal/mock" // Import for init() to run
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/monitoring"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Version information
|
||||
var (
|
||||
MetricsPort = 9091
|
||||
)
|
||||
|
||||
// Run starts the Pulse monitoring server.
|
||||
func Run(ctx context.Context, version string) error {
|
||||
// Initialize logger with baseline defaults for early startup logs
|
||||
logging.Init(logging.Config{
|
||||
Format: "auto",
|
||||
Level: "info",
|
||||
Component: "pulse",
|
||||
})
|
||||
|
||||
// Check for auto-import on first startup
|
||||
if ShouldAutoImport() {
|
||||
if err := PerformAutoImport(); err != nil {
|
||||
log.Error().Err(err).Msg("Auto-import failed, continuing with normal startup")
|
||||
}
|
||||
}
|
||||
|
||||
// Load unified configuration
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load configuration: %w", err)
|
||||
}
|
||||
|
||||
// Re-initialize logging with configuration-driven settings
|
||||
logging.Init(logging.Config{
|
||||
Format: cfg.LogFormat,
|
||||
Level: cfg.LogLevel,
|
||||
Component: "pulse",
|
||||
FilePath: cfg.LogFile,
|
||||
MaxSizeMB: cfg.LogMaxSize,
|
||||
MaxAgeDays: cfg.LogMaxAge,
|
||||
Compress: cfg.LogCompress,
|
||||
})
|
||||
|
||||
// Initialize license public key for Pro feature validation
|
||||
license.InitPublicKey()
|
||||
|
||||
log.Info().Msg("Starting Pulse monitoring server")
|
||||
|
||||
// Validate agent binaries are available for download
|
||||
agentbinaries.EnsureHostAgentBinaries(version)
|
||||
|
||||
// Create derived context that cancels on interrupt
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Metrics port is configurable via MetricsPort variable
|
||||
metricsAddr := fmt.Sprintf("%s:%d", cfg.BackendHost, MetricsPort)
|
||||
startMetricsServer(ctx, metricsAddr)
|
||||
|
||||
// Initialize WebSocket hub first
|
||||
wsHub := websocket.NewHub(nil)
|
||||
// Set allowed origins from configuration
|
||||
if cfg.AllowedOrigins != "" {
|
||||
if cfg.AllowedOrigins == "*" {
|
||||
// Explicit wildcard - allow all origins (less secure)
|
||||
wsHub.SetAllowedOrigins([]string{"*"})
|
||||
} else {
|
||||
// Use configured origins
|
||||
wsHub.SetAllowedOrigins(strings.Split(cfg.AllowedOrigins, ","))
|
||||
}
|
||||
} else {
|
||||
// Default: don't set any specific origins
|
||||
wsHub.SetAllowedOrigins([]string{})
|
||||
}
|
||||
go wsHub.Run()
|
||||
|
||||
// Initialize reloadable monitoring system
|
||||
reloadableMonitor, err := monitoring.NewReloadableMonitor(cfg, wsHub)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize monitoring system: %w", err)
|
||||
}
|
||||
|
||||
// Set state getter for WebSocket hub
|
||||
wsHub.SetStateGetter(func() interface{} {
|
||||
state := reloadableMonitor.GetMonitor().GetState()
|
||||
return state.ToFrontend()
|
||||
})
|
||||
|
||||
// Wire up Prometheus metrics for alert lifecycle
|
||||
alerts.SetMetricHooks(
|
||||
metrics.RecordAlertFired,
|
||||
metrics.RecordAlertResolved,
|
||||
metrics.RecordAlertSuppressed,
|
||||
metrics.RecordAlertAcknowledged,
|
||||
)
|
||||
log.Info().Msg("Alert metrics hooks registered")
|
||||
|
||||
// Start monitoring
|
||||
reloadableMonitor.Start(ctx)
|
||||
|
||||
// Initialize API server with reload function
|
||||
var router *api.Router
|
||||
reloadFunc := func() error {
|
||||
if err := reloadableMonitor.Reload(); err != nil {
|
||||
return err
|
||||
}
|
||||
if router != nil {
|
||||
router.SetMonitor(reloadableMonitor.GetMonitor())
|
||||
if cfg := reloadableMonitor.GetConfig(); cfg != nil {
|
||||
router.SetConfig(cfg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
router = api.NewRouter(cfg, reloadableMonitor.GetMonitor(), wsHub, reloadFunc, version)
|
||||
|
||||
// Inject resource store into monitor for WebSocket broadcasts
|
||||
router.SetMonitor(reloadableMonitor.GetMonitor())
|
||||
|
||||
// Start AI patrol service for background infrastructure monitoring
|
||||
router.StartPatrol(ctx)
|
||||
|
||||
// Wire alert-triggered AI analysis
|
||||
router.WireAlertTriggeredAI()
|
||||
|
||||
// Create HTTP server with unified configuration
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.BackendHost, cfg.FrontendPort),
|
||||
Handler: router.Handler(),
|
||||
ReadHeaderTimeout: 15 * time.Second,
|
||||
WriteTimeout: 0, // Disabled to support SSE/streaming
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
// Start config watcher for .env file changes
|
||||
configWatcher, err := config.NewConfigWatcher(cfg)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to create config watcher, .env changes will require restart")
|
||||
} else {
|
||||
configWatcher.SetMockReloadCallback(func() {
|
||||
log.Info().Msg("mock.env changed, reloading monitor")
|
||||
if err := reloadableMonitor.Reload(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reload monitor after mock.env change")
|
||||
} else if router != nil {
|
||||
router.SetMonitor(reloadableMonitor.GetMonitor())
|
||||
if cfg := reloadableMonitor.GetConfig(); cfg != nil {
|
||||
router.SetConfig(cfg)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
configWatcher.SetAPITokenReloadCallback(func() {
|
||||
if monitor := reloadableMonitor.GetMonitor(); monitor != nil {
|
||||
monitor.RebuildTokenBindings()
|
||||
}
|
||||
})
|
||||
|
||||
if err := configWatcher.Start(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to start config watcher")
|
||||
}
|
||||
defer configWatcher.Stop()
|
||||
}
|
||||
|
||||
// Start server
|
||||
go func() {
|
||||
if cfg.HTTPSEnabled && cfg.TLSCertFile != "" && cfg.TLSKeyFile != "" {
|
||||
log.Info().
|
||||
Str("host", cfg.BackendHost).
|
||||
Int("port", cfg.FrontendPort).
|
||||
Str("protocol", "HTTPS").
|
||||
Msg("Server listening")
|
||||
if err := srv.ListenAndServeTLS(cfg.TLSCertFile, cfg.TLSKeyFile); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("Failed to start HTTPS server")
|
||||
}
|
||||
} else {
|
||||
if cfg.HTTPSEnabled {
|
||||
log.Warn().Msg("HTTPS_ENABLED is true but TLS_CERT_FILE or TLS_KEY_FILE not configured, falling back to HTTP")
|
||||
}
|
||||
log.Info().
|
||||
Str("host", cfg.BackendHost).
|
||||
Int("port", cfg.FrontendPort).
|
||||
Str("protocol", "HTTP").
|
||||
Msg("Server listening")
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("Failed to start HTTP server")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Setup signal handlers
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
reloadChan := make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
signal.Notify(reloadChan, syscall.SIGHUP)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info().Msg("Context cancelled, shutting down...")
|
||||
goto shutdown
|
||||
|
||||
case <-reloadChan:
|
||||
log.Info().Msg("Received SIGHUP, reloading configuration...")
|
||||
if configWatcher != nil {
|
||||
configWatcher.ReloadConfig()
|
||||
}
|
||||
|
||||
if err := reloadFunc(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reload monitor after SIGHUP")
|
||||
} else {
|
||||
log.Info().Msg("Runtime configuration reloaded")
|
||||
}
|
||||
|
||||
case <-sigChan:
|
||||
log.Info().Msg("Shutting down server...")
|
||||
goto shutdown
|
||||
}
|
||||
}
|
||||
|
||||
shutdown:
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Error().Err(err).Msg("Server shutdown error")
|
||||
}
|
||||
|
||||
cancel()
|
||||
reloadableMonitor.Stop()
|
||||
|
||||
if configWatcher != nil {
|
||||
configWatcher.Stop()
|
||||
}
|
||||
|
||||
log.Info().Msg("Server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// startMetricsServer is moved from main.go
|
||||
func startMetricsServer(ctx context.Context, addr string) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Info().Str("addr", addr).Msg("Metrics server listening")
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("Metrics server failed")
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(shutdownCtx)
|
||||
}()
|
||||
}
|
||||
|
||||
func ShouldAutoImport() bool {
|
||||
configPath := os.Getenv("PULSE_DATA_DIR")
|
||||
if configPath == "" {
|
||||
configPath = "/etc/pulse"
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(configPath, "nodes.enc")); err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return os.Getenv("PULSE_INIT_CONFIG_DATA") != "" ||
|
||||
os.Getenv("PULSE_INIT_CONFIG_FILE") != ""
|
||||
}
|
||||
|
||||
func PerformAutoImport() error {
|
||||
configData := os.Getenv("PULSE_INIT_CONFIG_DATA")
|
||||
configFile := os.Getenv("PULSE_INIT_CONFIG_FILE")
|
||||
configPass := os.Getenv("PULSE_INIT_CONFIG_PASSPHRASE")
|
||||
|
||||
if configPass == "" {
|
||||
return fmt.Errorf("PULSE_INIT_CONFIG_PASSPHRASE is required for auto-import")
|
||||
}
|
||||
|
||||
var encryptedData string
|
||||
|
||||
if configFile != "" {
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
payload, err := NormalizeImportPayload(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
encryptedData = payload
|
||||
} else if configData != "" {
|
||||
payload, err := NormalizeImportPayload([]byte(configData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
encryptedData = payload
|
||||
} else {
|
||||
return fmt.Errorf("no config data provided")
|
||||
}
|
||||
|
||||
configPath := os.Getenv("PULSE_DATA_DIR")
|
||||
if configPath == "" {
|
||||
configPath = "/etc/pulse"
|
||||
}
|
||||
|
||||
persistence := config.NewConfigPersistence(configPath)
|
||||
if err := persistence.ImportConfig(encryptedData, configPass); err != nil {
|
||||
return fmt.Errorf("failed to import configuration: %w", err)
|
||||
}
|
||||
|
||||
log.Info().Msg("Configuration auto-imported successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeImportPayload(raw []byte) (string, error) {
|
||||
trimmed := strings.TrimSpace(string(raw))
|
||||
if trimmed == "" {
|
||||
return "", fmt.Errorf("configuration payload is empty")
|
||||
}
|
||||
|
||||
if decoded, err := base64.StdEncoding.DecodeString(trimmed); err == nil {
|
||||
decodedTrimmed := strings.TrimSpace(string(decoded))
|
||||
if LooksLikeBase64(decodedTrimmed) {
|
||||
return decodedTrimmed, nil
|
||||
}
|
||||
return trimmed, nil
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(raw), nil
|
||||
}
|
||||
|
||||
func LooksLikeBase64(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
compact := strings.Map(func(r rune) rune {
|
||||
switch r {
|
||||
case '\n', '\r', '\t', ' ':
|
||||
return -1
|
||||
default:
|
||||
return r
|
||||
}
|
||||
}, s)
|
||||
|
||||
if compact == "" || len(compact)%4 != 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(compact); i++ {
|
||||
c := compact[i]
|
||||
isAlphaNum := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9')
|
||||
if isAlphaNum || c == '+' || c == '/' || c == '=' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user