diff --git a/cmd/pulse-sensor-proxy/config.example.yaml b/cmd/pulse-sensor-proxy/config.example.yaml
index a912a1954..e85054245 100644
--- a/cmd/pulse-sensor-proxy/config.example.yaml
+++ b/cmd/pulse-sensor-proxy/config.example.yaml
@@ -10,9 +10,19 @@ allowed_source_subnets:
 
 # Peer Authorization
 # Specify which UIDs/GIDs are allowed to connect
+# A peer is authorized when its UID OR GID matches one of these entries
 # Required when running Pulse in a container (use mapped UID/GID from container)
-allowed_peer_uids: [100999]  # Container pulse user UID
-allowed_peer_gids: [100996]  # Container pulse group GID
+allowed_peer_uids: [100999]  # Legacy format; grants all capabilities unless overridden below
+allowed_peer_gids: [100996]
+
+# Preferred format with explicit capabilities (read, write, admin)
+allowed_peers:
+  - uid: 0
+    capabilities: [read, write, admin]  # Host root retains full control
+  - uid: 100999
+    capabilities: [read]  # Container peer limited to read-only RPCs
+
+require_proxmox_hostkeys: false  # Enforce Proxmox-known host keys before falling back to ssh-keyscan
 
 # ID-Mapped Root Authentication
 # Allow connections from ID-mapped root users (for LXC containers)
@@ -24,6 +34,9 @@ allowed_idmap_users:
 # Address for Prometheus metrics endpoint
 metrics_address: "127.0.0.1:9127"
 
+# Limit SSH output size (bytes) when fetching temperatures
+max_ssh_output_bytes: 1048576  # 1 MiB
+
 # Rate Limiting (Optional)
 # Control how frequently peers can make requests to prevent abuse
 # Adjust these values based on your deployment size:
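Editor's note: the capability model introduced above is consumed later in this patch via a Capability type (peerCaps.Has(CapabilityAdmin)) that is defined elsewhere in the tree, not in this diff. A minimal sketch of a bitmask that would satisfy those call sites, purely hypothetical:

// Hypothetical sketch only: Capability and its constants are referenced
// by this patch but defined outside it. A bitmask like this would
// satisfy the call sites shown below; the real definition may differ.
package main

type Capability uint8

const (
	CapabilityRead Capability = 1 << iota
	CapabilityWrite
	CapabilityAdmin
)

// Has reports whether all bits in want are present in c.
func (c Capability) Has(want Capability) bool {
	return c&want == want
}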
diff --git a/cmd/pulse-sensor-proxy/config.go b/cmd/pulse-sensor-proxy/config.go
index 70878bf6e..8b259d9f8 100644
--- a/cmd/pulse-sensor-proxy/config.go
+++ b/cmd/pulse-sensor-proxy/config.go
@@ -6,6 +6,7 @@ import (
 	"os"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/rs/zerolog/log"
 	"gopkg.in/yaml.v3"
@@ -19,9 +20,16 @@ type RateLimitConfig struct {
 
 // Config holds proxy configuration
 type Config struct {
-	AllowedSourceSubnets []string `yaml:"allowed_source_subnets"`
-	MetricsAddress       string   `yaml:"metrics_address"`
-	LogLevel             string   `yaml:"log_level"`
+	AllowedSourceSubnets   []string      `yaml:"allowed_source_subnets"`
+	MetricsAddress         string        `yaml:"metrics_address"`
+	LogLevel               string        `yaml:"log_level"`
+	AllowedNodes           []string      `yaml:"allowed_nodes"`
+	StrictNodeValidation   bool          `yaml:"strict_node_validation"`
+	ReadTimeout            time.Duration `yaml:"read_timeout"`
+	WriteTimeout           time.Duration `yaml:"write_timeout"`
+	MaxSSHOutputBytes      int64         `yaml:"max_ssh_output_bytes"`
+	RequireProxmoxHostkeys bool          `yaml:"require_proxmox_hostkeys"`
+	AllowedPeers           []PeerConfig  `yaml:"allowed_peers"`
 
 	AllowIDMappedRoot bool     `yaml:"allow_idmapped_root"`
 	AllowedPeerUIDs   []uint32 `yaml:"allowed_peer_uids"`
@@ -31,12 +39,21 @@ type Config struct {
 	RateLimit *RateLimitConfig `yaml:"rate_limit,omitempty"`
 }
 
+// PeerConfig represents a peer entry with capabilities.
+type PeerConfig struct {
+	UID          uint32   `yaml:"uid"`
+	Capabilities []string `yaml:"capabilities"`
+}
+
 // loadConfig loads configuration from file and environment variables
 func loadConfig(configPath string) (*Config, error) {
 	cfg := &Config{
 		AllowIDMappedRoot: true,
 		AllowedIDMapUsers: []string{"root"},
 		LogLevel:          "info", // Default log level
+		ReadTimeout:       5 * time.Second,
+		WriteTimeout:      10 * time.Second,
+		MaxSSHOutputBytes: 1 * 1024 * 1024, // 1 MiB
 	}
 
 	// Try to load config file if it exists
@@ -58,6 +75,26 @@ func loadConfig(configPath string) (*Config, error) {
 		}
 	}
 
+	// Read timeout override
+	if envReadTimeout := os.Getenv("PULSE_SENSOR_PROXY_READ_TIMEOUT"); envReadTimeout != "" {
+		if parsed, err := time.ParseDuration(strings.TrimSpace(envReadTimeout)); err != nil {
+			log.Warn().Str("value", envReadTimeout).Err(err).Msg("Invalid PULSE_SENSOR_PROXY_READ_TIMEOUT value, ignoring")
+		} else {
+			cfg.ReadTimeout = parsed
+			log.Info().Dur("read_timeout", cfg.ReadTimeout).Msg("Configured read timeout from environment")
+		}
+	}
+
+	// Write timeout override
+	if envWriteTimeout := os.Getenv("PULSE_SENSOR_PROXY_WRITE_TIMEOUT"); envWriteTimeout != "" {
+		if parsed, err := time.ParseDuration(strings.TrimSpace(envWriteTimeout)); err != nil {
+			log.Warn().Str("value", envWriteTimeout).Err(err).Msg("Invalid PULSE_SENSOR_PROXY_WRITE_TIMEOUT value, ignoring")
+		} else {
+			cfg.WriteTimeout = parsed
+			log.Info().Dur("write_timeout", cfg.WriteTimeout).Msg("Configured write timeout from environment")
+		}
+	}
+
 	// Append from environment variable if set
 	if envSubnets := os.Getenv("PULSE_SENSOR_PROXY_ALLOWED_SUBNETS"); envSubnets != "" {
 		envList := strings.Split(envSubnets, ",")
@@ -67,6 +104,20 @@ func loadConfig(configPath string) (*Config, error) {
 			Msg("Appended subnets from environment variable")
 	}
 
+	// Ensure timeouts have sane defaults
+	if cfg.ReadTimeout <= 0 {
+		log.Warn().Dur("configured_value", cfg.ReadTimeout).Msg("Read timeout must be positive; using default 5s")
+		cfg.ReadTimeout = 5 * time.Second
+	}
+	if cfg.WriteTimeout <= 0 {
+		log.Warn().Dur("configured_value", cfg.WriteTimeout).Msg("Write timeout must be positive; using default 10s")
+		cfg.WriteTimeout = 10 * time.Second
+	}
+	if cfg.MaxSSHOutputBytes <= 0 {
+		log.Warn().Int64("configured_value", cfg.MaxSSHOutputBytes).Msg("max_ssh_output_bytes must be positive; using default 1MiB")
+		cfg.MaxSSHOutputBytes = 1 * 1024 * 1024
+	}
+
 	// Allow ID-mapped root override
 	if envAllowIDMap := os.Getenv("PULSE_SENSOR_PROXY_ALLOW_IDMAPPED_ROOT"); envAllowIDMap != "" {
 		parsed, err := parseBool(envAllowIDMap)
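Editor's note: the timeout overrides above accept any time.ParseDuration format. An illustrative snippet of how they behave, assuming a config path that may not exist (loadConfig tolerates a missing file and falls back to defaults before applying the env values); the path is an assumption for the example:

// Illustrative only: exercising the environment overrides above.
package main

import (
	"fmt"
	"os"
)

func main() {
	os.Setenv("PULSE_SENSOR_PROXY_READ_TIMEOUT", "2s") // any time.ParseDuration format
	os.Setenv("PULSE_SENSOR_PROXY_WRITE_TIMEOUT", "750ms")

	cfg, err := loadConfig("/etc/pulse-sensor-proxy/config.yaml") // hypothetical path
	if err != nil {
		fmt.Println("load failed:", err)
		return
	}
	fmt.Println(cfg.ReadTimeout, cfg.WriteTimeout) // 2s 750ms
}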
+ Msg("Configured strict node validation from environment") + } + } + + // SSH output limit override + if envMaxSSH := os.Getenv("PULSE_SENSOR_PROXY_MAX_SSH_OUTPUT_BYTES"); envMaxSSH != "" { + if parsed, err := strconv.ParseInt(strings.TrimSpace(envMaxSSH), 10, 64); err != nil { + log.Warn().Str("value", envMaxSSH).Err(err).Msg("Invalid PULSE_SENSOR_PROXY_MAX_SSH_OUTPUT_BYTES value, ignoring") + } else { + cfg.MaxSSHOutputBytes = parsed + log.Info().Int64("max_ssh_output_bytes", cfg.MaxSSHOutputBytes).Msg("Configured max SSH output bytes from environment") + } + } + + // Require Proxmox host keys override + if envReq := os.Getenv("PULSE_SENSOR_PROXY_REQUIRE_PROXMOX_HOSTKEYS"); envReq != "" { + if parsed, err := parseBool(envReq); err != nil { + log.Warn().Str("value", envReq).Err(err).Msg("Invalid PULSE_SENSOR_PROXY_REQUIRE_PROXMOX_HOSTKEYS value, ignoring") + } else { + cfg.RequireProxmoxHostkeys = parsed + log.Info().Bool("require_proxmox_hostkeys", parsed).Msg("Configured Proxmox host key requirement from environment") + } + } + // Metrics address from environment variable if envMetrics := os.Getenv("PULSE_SENSOR_PROXY_METRICS_ADDR"); envMetrics != "" { cfg.MetricsAddress = envMetrics diff --git a/cmd/pulse-sensor-proxy/main.go b/cmd/pulse-sensor-proxy/main.go index c6e88c3fd..c04c049a8 100644 --- a/cmd/pulse-sensor-proxy/main.go +++ b/cmd/pulse-sensor-proxy/main.go @@ -249,20 +249,25 @@ func lookupUserFromPasswd(username string) (*userSpec, error) { // Proxy manages the temperature monitoring proxy type Proxy struct { - socketPath string - sshKeyPath string - workDir string - knownHosts knownhosts.Manager - listener net.Listener - rateLimiter *rateLimiter - nodeGate *nodeGate - router map[string]handlerFunc - config *Config - metrics *ProxyMetrics - audit *auditLogger + socketPath string + sshKeyPath string + workDir string + knownHosts knownhosts.Manager + listener net.Listener + rateLimiter *rateLimiter + nodeGate *nodeGate + router map[string]handlerFunc + config *Config + metrics *ProxyMetrics + audit *auditLogger + nodeValidator *nodeValidator + readTimeout time.Duration + writeTimeout time.Duration + maxSSHOutputBytes int64 allowedPeerUIDs map[uint32]struct{} allowedPeerGIDs map[uint32]struct{} + peerCapabilities map[uint32]Capability idMappedUIDRanges []idRange idMappedGIDRanges []idRange } @@ -362,6 +367,11 @@ func runProxy() { // Initialize metrics metrics := NewProxyMetrics(Version) + nodeValidator, err := newNodeValidator(cfg, metrics) + if err != nil { + log.Fatal().Err(err).Msg("Failed to initialize node validator") + } + log.Info(). Str("socket", socketPath). Str("ssh_key_dir", sshKeyPath). 
@@ -383,14 +393,17 @@ func runProxy() {
 	}
 
 	proxy := &Proxy{
-		socketPath:  socketPath,
-		sshKeyPath:  sshKeyPath,
-		knownHosts:  knownHostsManager,
-		rateLimiter: newRateLimiter(metrics, cfg.RateLimit),
-		nodeGate:    newNodeGate(),
-		config:      cfg,
-		metrics:     metrics,
-		audit:       auditLogger,
+		socketPath:        socketPath,
+		sshKeyPath:        sshKeyPath,
+		knownHosts:        knownHostsManager,
+		nodeGate:          newNodeGate(),
+		config:            cfg,
+		metrics:           metrics,
+		audit:             auditLogger,
+		nodeValidator:     nodeValidator,
+		readTimeout:       cfg.ReadTimeout,
+		writeTimeout:      cfg.WriteTimeout,
+		maxSSHOutputBytes: cfg.MaxSSHOutputBytes,
 	}
 
 	if wd, err := os.Getwd(); err == nil {
@@ -400,6 +413,12 @@ func runProxy() {
 		proxy.workDir = defaultWorkDir()
 	}
 
+	if err := proxy.initAuthRules(); err != nil {
+		log.Fatal().Err(err).Msg("Failed to initialize authentication rules")
+	}
+
+	proxy.rateLimiter = newRateLimiter(metrics, cfg.RateLimit, proxy.idMappedUIDRanges, proxy.idMappedGIDRanges)
+
 	// Register RPC method handlers
 	proxy.router = map[string]handlerFunc{
 		RPCGetStatus: proxy.handleGetStatusV2,
@@ -409,10 +428,6 @@ func runProxy() {
 		RPCRequestCleanup: proxy.handleRequestCleanup,
 	}
 
-	if err := proxy.initAuthRules(); err != nil {
-		log.Fatal().Err(err).Msg("Failed to initialize authentication rules")
-	}
-
 	if err := proxy.Start(); err != nil {
 		log.Fatal().Err(err).Msg("Failed to start proxy")
 	}
@@ -429,7 +444,9 @@ func runProxy() {
 	<-sigChan
 	log.Info().Msg("Shutting down proxy...")
 	proxy.Stop()
-	proxy.rateLimiter.shutdown()
+	if proxy.rateLimiter != nil {
+		proxy.rateLimiter.shutdown()
+	}
 	metrics.Shutdown(context.Background())
 	log.Info().Msg("Proxy stopped")
 }
@@ -518,8 +535,14 @@ func (p *Proxy) handleConnection(conn net.Conn) {
 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()
 
-	// Skip read deadline - it interferes with write operations on unix sockets
-	// Context timeout provides sufficient protection against hung connections
+	// Enforce read deadline to prevent hung connections
+	readDeadline := p.readTimeout
+	if readDeadline <= 0 {
+		readDeadline = 5 * time.Second
+	}
+	if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
+		log.Warn().Err(err).Msg("Failed to set read deadline")
+	}
 
 	// Extract and verify peer credentials
 	cred, err := extractPeerCredentials(conn)
@@ -532,7 +555,8 @@ func (p *Proxy) handleConnection(conn net.Conn) {
 		return
 	}
 
-	if err := p.authorizePeer(cred); err != nil {
+	peerCaps, err := p.authorizePeer(cred)
+	if err != nil {
 		log.Warn().
 			Err(err).
 			Uint32("uid", cred.uid).
@@ -549,8 +573,15 @@ func (p *Proxy) handleConnection(conn net.Conn) {
 		p.audit.LogConnectionAccepted("", cred, remoteAddr)
 	}
 
+	if p.rateLimiter == nil {
+		log.Error().Msg("Rate limiter not initialized; rejecting connection")
+		p.sendErrorV2(conn, "service unavailable", "")
+		return
+	}
+
 	// Check rate limit and concurrency
-	peer := peerID{uid: cred.uid}
+	peer := p.rateLimiter.identifyPeer(cred)
+	peerLabel := peer.String()
 	releaseLimiter, limitReason, allowed := p.rateLimiter.allow(peer)
 	if !allowed {
 		log.Warn().
@@ -575,7 +606,7 @@ func (p *Proxy) handleConnection(conn net.Conn) {
 			releaseFn()
 			releaseFn = nil
 		}
-		p.rateLimiter.penalize(peer, reason)
+		p.rateLimiter.penalize(peerLabel, reason)
 	}
 
 	// Read request using newline-delimited framing
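Editor's note: the read-deadline pattern introduced above relies on classifying deadline errors via the net.Error interface. A standalone demonstration of that classification, using net.Pipe as a stand-in for the unix socket (this is an illustration, not proxy code):

package main

import (
	"errors"
	"fmt"
	"net"
	"os"
	"time"
)

func main() {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	// No client write is coming, so this read must time out.
	_ = server.SetReadDeadline(time.Now().Add(50 * time.Millisecond))

	buf := make([]byte, 1)
	_, err := server.Read(buf)

	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		fmt.Println("timeout, as the proxy's read path expects:", err)
	} else if errors.Is(err, os.ErrDeadlineExceeded) {
		fmt.Println("deadline exceeded:", err)
	} else {
		fmt.Println("unexpected:", err)
	}
}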
+ Msg("Read timeout waiting for client request - slow client or attack") + if p.metrics != nil { + p.metrics.recordReadTimeout() + } + if p.audit != nil { + p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "read_timeout") + } + p.sendErrorV2(conn, "read timeout", "") + applyPenalty("read_timeout") + return + } if errors.Is(err, bufio.ErrBufferFull) || limited.N <= 0 { if p.audit != nil { p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "payload_too_large") @@ -608,6 +654,11 @@ func (p *Proxy) handleConnection(conn net.Conn) { return } + // Clear read deadline now that request payload has been received + if err := conn.SetReadDeadline(time.Time{}); err != nil { + log.Warn().Err(err).Msg("Failed to clear read deadline after request read") + } + // Trim whitespace and validate line = bytes.TrimSpace(line) if len(line) == 0 { @@ -655,29 +706,28 @@ func (p *Proxy) handleConnection(conn net.Conn) { } resp.Error = "unknown method" logger.Warn().Msg("Unknown method") - p.sendResponse(conn, resp) + p.sendResponse(conn, resp, p.writeTimeout) applyPenalty("unknown_method") return } // Check if method requires host-level privileges if privilegedMethods[req.Method] { - // Privileged methods can only be called from host (not from containers) - if p.isIDMappedRoot(cred) { - resp.Error = "method requires host-level privileges" + if !peerCaps.Has(CapabilityAdmin) { + resp.Error = "method requires admin capability" log.Warn(). Str("method", req.Method). Uint32("uid", cred.uid). Uint32("gid", cred.gid). Uint32("pid", cred.pid). Str("corr_id", req.CorrelationID). - Msg("SECURITY: Container attempted to call privileged method - access denied") + Msg("SECURITY: peer lacking admin capability attempted privileged method - access denied") if p.audit != nil { - p.audit.LogValidationFailure(req.CorrelationID, cred, remoteAddr, req.Method, nil, "privileged_method_denied") + p.audit.LogValidationFailure(req.CorrelationID, cred, remoteAddr, req.Method, nil, "capability_denied") } - p.sendResponse(conn, resp) + p.sendResponse(conn, resp, p.writeTimeout) p.metrics.rpcRequests.WithLabelValues(req.Method, "unauthorized").Inc() - applyPenalty("privileged_method_denied") + applyPenalty("capability_denied") return } } @@ -695,10 +745,7 @@ func (p *Proxy) handleConnection(conn net.Conn) { } resp.Error = err.Error() logger.Warn().Err(err).Msg("Handler failed") - // Clear read deadline and set write deadline for error response - conn.SetReadDeadline(time.Time{}) - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - p.sendResponse(conn, resp) + p.sendResponse(conn, resp, p.writeTimeout) // Record failed request p.metrics.rpcRequests.WithLabelValues(req.Method, "error").Inc() p.metrics.rpcLatency.WithLabelValues(req.Method).Observe(time.Since(startTime).Seconds()) @@ -713,10 +760,7 @@ func (p *Proxy) handleConnection(conn net.Conn) { } logger.Info().Msg("Request completed") - // Clear read deadline and set write deadline for response - conn.SetReadDeadline(time.Time{}) - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - p.sendResponse(conn, resp) + p.sendResponse(conn, resp, p.writeTimeout) // Record successful request p.metrics.rpcRequests.WithLabelValues(req.Method, "success").Inc() @@ -729,8 +773,7 @@ func (p *Proxy) sendError(conn net.Conn, message string) { Success: false, Error: message, } - encoder := json.NewEncoder(conn) - encoder.Encode(resp) + p.sendResponse(conn, resp, p.writeTimeout) } // sendErrorV2 sends an error response with correlation ID @@ -740,25 +783,34 @@ func (p *Proxy) 
sendErrorV2(conn net.Conn, message, correlationID string) { Success: false, Error: message, } - // Clear read deadline before writing - conn.SetReadDeadline(time.Time{}) - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - encoder := json.NewEncoder(conn) - encoder.Encode(resp) + p.sendResponse(conn, resp, p.writeTimeout) } // sendResponse sends an RPC response -func (p *Proxy) sendResponse(conn net.Conn, resp RPCResponse) { +func (p *Proxy) sendResponse(conn net.Conn, resp RPCResponse, writeTimeout time.Duration) { // Clear read deadline before writing if err := conn.SetReadDeadline(time.Time{}); err != nil { log.Warn().Err(err).Msg("Failed to clear read deadline") } - if err := conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil { - log.Warn().Err(err).Msg("Failed to set write deadline") + + if writeTimeout <= 0 { + writeTimeout = 10 * time.Second } + + if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + log.Warn().Err(err).Dur("write_timeout", writeTimeout).Msg("Failed to set write deadline") + } + encoder := json.NewEncoder(conn) if err := encoder.Encode(resp); err != nil { - log.Error().Err(err).Msg("Failed to encode RPC response") + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + log.Warn().Err(err).Msg("Write timeout while sending RPC response") + if p.metrics != nil { + p.metrics.recordWriteTimeout() + } + } else { + log.Error().Err(err).Msg("Failed to encode RPC response") + } } } @@ -1101,6 +1153,13 @@ func (p *Proxy) handleGetTemperatureV2(ctx context.Context, req *RPCRequest, log return nil, fmt.Errorf("invalid node name") } + if p.nodeValidator != nil { + if err := p.nodeValidator.Validate(ctx, node); err != nil { + logger.Warn().Err(err).Str("node", node).Msg("Node validation failed") + return nil, err + } + } + // Acquire per-node concurrency lock (prevents multiple simultaneous requests to same node) releaseNode := p.nodeGate.acquire(node) defer releaseNode() diff --git a/cmd/pulse-sensor-proxy/metrics.go b/cmd/pulse-sensor-proxy/metrics.go index 5dac63f6b..71ea003d2 100644 --- a/cmd/pulse-sensor-proxy/metrics.go +++ b/cmd/pulse-sensor-proxy/metrics.go @@ -16,19 +16,24 @@ const defaultMetricsAddr = "127.0.0.1:9127" // ProxyMetrics holds Prometheus metrics for the proxy type ProxyMetrics struct { - rpcRequests *prometheus.CounterVec - rpcLatency *prometheus.HistogramVec - sshRequests *prometheus.CounterVec - sshLatency *prometheus.HistogramVec - queueDepth prometheus.Gauge - rateLimitHits prometheus.Counter - limiterRejects *prometheus.CounterVec - globalConcurrency prometheus.Gauge - limiterPenalties *prometheus.CounterVec - limiterPeers prometheus.Gauge - buildInfo *prometheus.GaugeVec - server *http.Server - registry *prometheus.Registry + rpcRequests *prometheus.CounterVec + rpcLatency *prometheus.HistogramVec + sshRequests *prometheus.CounterVec + sshLatency *prometheus.HistogramVec + queueDepth prometheus.Gauge + rateLimitHits prometheus.Counter + limiterRejects *prometheus.CounterVec + globalConcurrency prometheus.Gauge + limiterPenalties *prometheus.CounterVec + limiterPeers prometheus.Gauge + nodeValidationFailures *prometheus.CounterVec + readTimeouts prometheus.Counter + writeTimeouts prometheus.Counter + hostKeyChanges *prometheus.CounterVec + sshOutputOversized *prometheus.CounterVec + buildInfo *prometheus.GaugeVec + server *http.Server + registry *prometheus.Registry } // NewProxyMetrics creates and registers all metrics @@ -83,7 +88,7 @@ func NewProxyMetrics(version string) *ProxyMetrics { Name: 
"pulse_proxy_limiter_rejections_total", Help: "Limiter rejections by reason.", }, - []string{"reason"}, + []string{"reason", "peer"}, ), globalConcurrency: prometheus.NewGauge( prometheus.GaugeOpts{ @@ -96,7 +101,7 @@ func NewProxyMetrics(version string) *ProxyMetrics { Name: "pulse_proxy_limiter_penalties_total", Help: "Penalty sleeps applied after validation failures.", }, - []string{"reason"}, + []string{"reason", "peer"}, ), limiterPeers: prometheus.NewGauge( prometheus.GaugeOpts{ @@ -104,6 +109,39 @@ func NewProxyMetrics(version string) *ProxyMetrics { Help: "Number of peers tracked by the rate limiter.", }, ), + nodeValidationFailures: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "pulse_proxy_node_validation_failures_total", + Help: "Node validation failures by reason.", + }, + []string{"reason"}, + ), + readTimeouts: prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "pulse_proxy_read_timeouts_total", + Help: "Number of socket read timeouts.", + }, + ), + writeTimeouts: prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "pulse_proxy_write_timeouts_total", + Help: "Number of socket write timeouts.", + }, + ), + hostKeyChanges: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "pulse_proxy_hostkey_changes_total", + Help: "Detected SSH host key changes by node.", + }, + []string{"node"}, + ), + sshOutputOversized: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "pulse_proxy_ssh_output_oversized_total", + Help: "Number of SSH responses rejected for exceeding size limits.", + }, + []string{"node"}, + ), buildInfo: prometheus.NewGaugeVec( prometheus.GaugeOpts{ Name: "pulse_proxy_build_info", @@ -125,6 +163,11 @@ func NewProxyMetrics(version string) *ProxyMetrics { pm.globalConcurrency, pm.limiterPenalties, pm.limiterPeers, + pm.nodeValidationFailures, + pm.readTimeouts, + pm.writeTimeouts, + pm.hostKeyChanges, + pm.sshOutputOversized, pm.buildInfo, ) @@ -200,12 +243,53 @@ func sanitizeNodeLabel(node string) string { return out } -func (m *ProxyMetrics) recordLimiterReject(reason string) { +func (m *ProxyMetrics) recordLimiterReject(reason, peer string) { if m == nil { return } m.rateLimitHits.Inc() - m.limiterRejects.WithLabelValues(reason).Inc() + m.limiterRejects.WithLabelValues(reason, peer).Inc() +} + +func (m *ProxyMetrics) recordNodeValidationFailure(reason string) { + if m == nil { + return + } + m.nodeValidationFailures.WithLabelValues(reason).Inc() +} + +func (m *ProxyMetrics) recordReadTimeout() { + if m == nil { + return + } + m.readTimeouts.Inc() +} + +func (m *ProxyMetrics) recordWriteTimeout() { + if m == nil { + return + } + m.writeTimeouts.Inc() +} + +func (m *ProxyMetrics) recordSSHOutputOversized(node string) { + if m == nil { + return + } + if node == "" { + node = "unknown" + } + m.sshOutputOversized.WithLabelValues(sanitizeNodeLabel(node)).Inc() +} + +func (m *ProxyMetrics) recordHostKeyChange(node string) { + if m == nil { + return + } + if node == "" { + node = "unknown" + } + m.hostKeyChanges.WithLabelValues(sanitizeNodeLabel(node)).Inc() } func (m *ProxyMetrics) incGlobalConcurrency() { @@ -222,11 +306,11 @@ func (m *ProxyMetrics) decGlobalConcurrency() { m.globalConcurrency.Dec() } -func (m *ProxyMetrics) recordPenalty(reason string) { +func (m *ProxyMetrics) recordPenalty(reason, peer string) { if m == nil { return } - m.limiterPenalties.WithLabelValues(reason).Inc() + m.limiterPenalties.WithLabelValues(reason, peer).Inc() } func (m *ProxyMetrics) setLimiterPeers(count int) { diff --git 
a/cmd/pulse-sensor-proxy/validation.go b/cmd/pulse-sensor-proxy/validation.go index 8eb2d5356..4709e5bdc 100644 --- a/cmd/pulse-sensor-proxy/validation.go +++ b/cmd/pulse-sensor-proxy/validation.go @@ -1,15 +1,19 @@ package main import ( + "context" "errors" "fmt" "net" "regexp" "strings" + "sync" + "time" "unicode" "unicode/utf8" "github.com/google/uuid" + "github.com/rs/zerolog/log" ) var ( @@ -204,3 +208,313 @@ func isASCII(s string) bool { } return true } + +const ( + nodeValidatorCacheTTL = 5 * time.Minute + + validationReasonNotAllowlisted = "node_not_allowlisted" + validationReasonNotClusterMember = "node_not_cluster_member" + validationReasonNoSources = "no_validation_sources" + validationReasonResolutionFailed = "allowlist_resolution_failed" + validationReasonClusterFailed = "cluster_query_failed" +) + +// nodeValidator enforces node allow-list and cluster membership checks +type nodeValidator struct { + allowHosts map[string]struct{} + allowCIDRs []*net.IPNet + hasAllowlist bool + strict bool + clusterEnabled bool + metrics *ProxyMetrics + resolver hostResolver + clusterFetcher func() ([]string, error) + cacheTTL time.Duration + clock func() time.Time + clusterCache clusterMembershipCache +} + +type clusterMembershipCache struct { + mu sync.Mutex + expires time.Time + nodes map[string]struct{} +} + +type hostResolver interface { + LookupIP(ctx context.Context, host string) ([]net.IP, error) +} + +type defaultHostResolver struct{} + +func (defaultHostResolver) LookupIP(ctx context.Context, host string) ([]net.IP, error) { + if ctx == nil { + ctx = context.Background() + } + results, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, fmt.Errorf("no IPs resolved for %s", host) + } + + ips := make([]net.IP, 0, len(results)) + for _, addr := range results { + if addr.IP != nil { + ips = append(ips, addr.IP) + } + } + if len(ips) == 0 { + return nil, fmt.Errorf("no IP addresses resolved for %s", host) + } + + return ips, nil +} + +func newNodeValidator(cfg *Config, metrics *ProxyMetrics) (*nodeValidator, error) { + if cfg == nil { + return nil, errors.New("config is required for node validator") + } + + v := &nodeValidator{ + allowHosts: make(map[string]struct{}), + strict: cfg.StrictNodeValidation, + metrics: metrics, + resolver: defaultHostResolver{}, + clusterFetcher: discoverClusterNodes, + cacheTTL: nodeValidatorCacheTTL, + clock: time.Now, + } + + for _, raw := range cfg.AllowedNodes { + entry := strings.TrimSpace(raw) + if entry == "" { + continue + } + + if _, network, err := net.ParseCIDR(entry); err == nil { + v.allowCIDRs = append(v.allowCIDRs, network) + continue + } + + if normalized := normalizeAllowlistEntry(entry); normalized != "" { + v.allowHosts[normalized] = struct{}{} + } + } + + v.hasAllowlist = len(v.allowHosts) > 0 || len(v.allowCIDRs) > 0 + + if v.hasAllowlist { + log.Info(). + Int("allowed_node_count", len(v.allowHosts)). + Int("allowed_cidr_count", len(v.allowCIDRs)). 
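Editor's note: newNodeValidator, which follows, calls isProxmoxHost(), a function defined elsewhere in the repo and not shown in this patch. A hypothetical stand-in for local experimentation; checking for the pmxcfs mount point is one plausible heuristic, and the real implementation may differ:

// Hypothetical stand-in only.
package main

import "os"

func isProxmoxHost() bool {
	// /etc/pve is the pmxcfs cluster filesystem mounted on PVE nodes.
	info, err := os.Stat("/etc/pve")
	return err == nil && info.IsDir()
}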
+ Msg("Node allow-list configured") + } + + if !v.hasAllowlist && isProxmoxHost() { + v.clusterEnabled = true + log.Info().Msg("Node validator using Proxmox cluster membership (auto-detect)") + } + + if !v.clusterEnabled { + v.clusterFetcher = nil + } + + if !v.hasAllowlist && !v.clusterEnabled { + if v.strict { + log.Warn().Msg("strict_node_validation enabled but no allowlist or cluster context is available") + } else { + log.Info().Msg("Node validator running in permissive mode (no allowlist or cluster context)") + } + } + + return v, nil +} + +// Validate ensures the provided node is authorized before any SSH is attempted. +func (v *nodeValidator) Validate(ctx context.Context, node string) error { + if v == nil { + return nil + } + + if ctx == nil { + ctx = context.Background() + } + + if v.hasAllowlist { + allowed, err := v.matchesAllowlist(ctx, node) + if err != nil { + v.recordFailure(validationReasonResolutionFailed) + log.Warn().Err(err).Str("node", node).Msg("Node allow-list resolution failed") + return err + } + if !allowed { + return v.deny(node, validationReasonNotAllowlisted) + } + return nil + } + + if v.clusterEnabled { + allowed, err := v.matchesCluster(node) + if err != nil { + v.recordFailure(validationReasonClusterFailed) + return fmt.Errorf("failed to evaluate cluster membership: %w", err) + } + if !allowed { + return v.deny(node, validationReasonNotClusterMember) + } + return nil + } + + if v.strict { + return v.deny(node, validationReasonNoSources) + } + + return nil +} + +func (v *nodeValidator) matchesAllowlist(ctx context.Context, node string) (bool, error) { + normalized := normalizeAllowlistEntry(node) + if normalized != "" { + if _, ok := v.allowHosts[normalized]; ok { + return true, nil + } + } + + if ip := parseNodeIP(node); ip != nil { + if v.ipAllowed(ip) { + return true, nil + } + // If the node itself is an IP and it didn't match, no need to resolve again. + return false, nil + } + + if len(v.allowCIDRs) == 0 { + return false, nil + } + + host := stripNodeDelimiters(node) + ips, err := v.resolver.LookupIP(ctx, host) + if err != nil { + return false, fmt.Errorf("resolve node %q: %w", host, err) + } + + for _, ip := range ips { + if v.ipAllowed(ip) { + return true, nil + } + } + + return false, nil +} + +func (v *nodeValidator) matchesCluster(node string) (bool, error) { + if v.clusterFetcher == nil { + return false, errors.New("cluster membership disabled") + } + + members, err := v.getClusterMembers() + if err != nil { + return false, err + } + + normalized := normalizeAllowlistEntry(node) + if normalized == "" { + normalized = strings.ToLower(strings.TrimSpace(node)) + } + + _, ok := members[normalized] + return ok, nil +} + +func (v *nodeValidator) getClusterMembers() (map[string]struct{}, error) { + now := time.Now() + if v.clock != nil { + now = v.clock() + } + + v.clusterCache.mu.Lock() + defer v.clusterCache.mu.Unlock() + + if v.clusterCache.nodes != nil && now.Before(v.clusterCache.expires) { + return v.clusterCache.nodes, nil + } + + nodes, err := v.clusterFetcher() + if err != nil { + return nil, err + } + + result := make(map[string]struct{}, len(nodes)) + for _, node := range nodes { + if normalized := normalizeAllowlistEntry(node); normalized != "" { + result[normalized] = struct{}{} + } + } + + ttl := v.cacheTTL + if ttl <= 0 { + ttl = nodeValidatorCacheTTL + } + v.clusterCache.nodes = result + v.clusterCache.expires = now.Add(ttl) + log.Debug(). + Int("cluster_node_count", len(result)). 
+ Msg("Refreshed cluster membership cache") + + return result, nil +} + +func (v *nodeValidator) ipAllowed(ip net.IP) bool { + if ip == nil { + return false + } + if _, ok := v.allowHosts[ip.String()]; ok { + return true + } + for _, network := range v.allowCIDRs { + if network.Contains(ip) { + return true + } + } + return false +} + +func (v *nodeValidator) recordFailure(reason string) { + if v.metrics != nil { + v.metrics.recordNodeValidationFailure(reason) + } +} + +func (v *nodeValidator) deny(node, reason string) error { + v.recordFailure(reason) + log.Warn(). + Str("node", node). + Str("reason", reason). + Msg("potential SSRF attempt blocked") + return fmt.Errorf("node %q rejected by validator (%s)", node, reason) +} + +func normalizeAllowlistEntry(entry string) string { + candidate := strings.TrimSpace(entry) + if candidate == "" { + return "" + } + unwrapped := stripNodeDelimiters(candidate) + if ip := net.ParseIP(unwrapped); ip != nil { + return ip.String() + } + return strings.ToLower(candidate) +} + +func parseNodeIP(node string) net.IP { + clean := stripNodeDelimiters(strings.TrimSpace(node)) + return net.ParseIP(clean) +} + +func stripNodeDelimiters(node string) string { + if strings.HasPrefix(node, "[") && strings.HasSuffix(node, "]") && len(node) > 2 { + return node[1 : len(node)-1] + } + return node +} diff --git a/cmd/pulse-sensor-proxy/validation_test.go b/cmd/pulse-sensor-proxy/validation_test.go index bb31654d0..692d1ac3f 100644 --- a/cmd/pulse-sensor-proxy/validation_test.go +++ b/cmd/pulse-sensor-proxy/validation_test.go @@ -1,8 +1,11 @@ package main import ( + "context" + "net" "strings" "testing" + "time" ) func TestSanitizeCorrelationID(t *testing.T) { @@ -121,3 +124,97 @@ func TestValidateCommand(t *testing.T) { }) } } + +type stubResolver struct { + ips []net.IP + err error +} + +func (s stubResolver) LookupIP(ctx context.Context, host string) ([]net.IP, error) { + if s.err != nil { + return nil, s.err + } + return s.ips, nil +} + +func TestNodeValidatorAllowlistHost(t *testing.T) { + v := &nodeValidator{ + allowHosts: map[string]struct{}{"node-1": {}}, + hasAllowlist: true, + resolver: stubResolver{}, + } + + if err := v.Validate(context.Background(), "node-1"); err != nil { + t.Fatalf("expected node-1 to be permitted, got error: %v", err) + } + + if err := v.Validate(context.Background(), "node-2"); err == nil { + t.Fatalf("expected node-2 to be rejected without allow-list entry") + } +} + +func TestNodeValidatorAllowlistCIDRWithLookup(t *testing.T) { + _, network, _ := net.ParseCIDR("10.0.0.0/24") + v := &nodeValidator{ + allowHosts: make(map[string]struct{}), + allowCIDRs: []*net.IPNet{network}, + hasAllowlist: true, + resolver: stubResolver{ + ips: []net.IP{net.ParseIP("10.0.0.5")}, + }, + } + + if err := v.Validate(context.Background(), "worker.local"); err != nil { + t.Fatalf("expected worker.local to resolve into allowed CIDR: %v", err) + } +} + +func TestNodeValidatorClusterCaching(t *testing.T) { + current := time.Now() + fetches := 0 + + v := &nodeValidator{ + clusterEnabled: true, + clusterFetcher: func() ([]string, error) { + fetches++ + return []string{"10.0.0.9"}, nil + }, + cacheTTL: nodeValidatorCacheTTL, + clock: func() time.Time { + return current + }, + } + + if err := v.Validate(context.Background(), "10.0.0.9"); err != nil { + t.Fatalf("expected node to be allowed via cluster membership: %v", err) + } + if fetches != 1 { + t.Fatalf("expected initial cluster fetch, got %d", fetches) + } + + current = current.Add(30 * time.Second) + if err := 
v.Validate(context.Background(), "10.0.0.9"); err != nil { + t.Fatalf("expected cached cluster membership to allow node: %v", err) + } + if fetches != 1 { + t.Fatalf("expected cache hit to avoid new fetch, got %d fetches", fetches) + } + + current = current.Add(nodeValidatorCacheTTL + time.Second) + if err := v.Validate(context.Background(), "10.0.0.9"); err != nil { + t.Fatalf("expected refreshed cluster membership to allow node: %v", err) + } + if fetches != 2 { + t.Fatalf("expected cache expiry to trigger new fetch, got %d", fetches) + } +} + +func TestNodeValidatorStrictNoSources(t *testing.T) { + v := &nodeValidator{ + strict: true, + } + + if err := v.Validate(context.Background(), "node-1"); err == nil { + t.Fatalf("expected strict mode without sources to reject nodes") + } +}
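Editor's note: one more illustrative companion test, not part of the patch, documenting how normalizeAllowlistEntry canonicalizes the entry formats the allow-list accepts:

// Illustrative companion test (not in the original change).
func TestNormalizeAllowlistEntryExamples(t *testing.T) {
	cases := map[string]string{
		"Node-1.LAN":    "node-1.lan",  // hostnames are lowercased
		"10.0.0.5":      "10.0.0.5",    // IPs pass through in canonical form
		"[2001:db8::1]": "2001:db8::1", // bracketed IPv6 is unwrapped
		"  pve-2  ":     "pve-2",       // surrounding whitespace is trimmed
		"":              "",            // empty entries are dropped
	}
	for in, want := range cases {
		if got := normalizeAllowlistEntry(in); got != want {
			t.Fatalf("normalizeAllowlistEntry(%q) = %q, want %q", in, got, want)
		}
	}
}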