Mirror of https://github.com/rcourtman/Pulse.git (synced 2026-02-18 00:17:39 +01:00)
feat(security): Add node allowlist validation to prevent SSRF attacks
Implements a comprehensive node validation system to prevent SSRF attacks
via the temperature proxy. Addresses a critical vulnerability where the
proxy would SSH to any hostname or IP that passed format validation.
Features:
- Configurable allowed_nodes list (hostnames, IPs, CIDR ranges)
- Automatic Proxmox cluster membership validation
- 5-minute cluster membership cache to reduce pvecm overhead
- strict_node_validation option for strict vs permissive modes
- New metric: pulse_proxy_node_validation_failures_total{reason} (a sample alert rule follows this list)
- Logs blocked attempts at WARN level with 'potential SSRF attempt'
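A minimal Prometheus alert sketch against the new counter (the rule name, window, and threshold are illustrative and not part of this commit; only the metric name comes from this change):

  groups:
    - name: pulse-sensor-proxy
      rules:
        - alert: PulseProxyNodeValidationFailure
          # Fires if any node validation failure was recorded in the last 15 minutes
          expr: increase(pulse_proxy_node_validation_failures_total[15m]) > 0
          labels:
            severity: warning
          annotations:
            summary: "pulse-sensor-proxy blocked node validation attempts (possible SSRF probe)"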
Configuration:
- allowed_nodes: [] (empty = auto-discover from cluster)
- strict_node_validation: true (require cluster membership)
Default behavior: Empty allowlist + Proxmox host = validate cluster
members (secure by default, backwards compatible).
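For reference, a minimal config sketch of the new keys (key names and defaults come from this commit; the node values are illustrative):

  # Node validation for the temperature proxy
  # An empty allowed_nodes list on a Proxmox host falls back to cluster membership
  allowed_nodes:
    - pve1              # hostname as known to the cluster
    - 192.0.2.10        # single IP
    - 10.10.0.0/24      # CIDR range
  strict_node_validation: true

The same settings can also be supplied via the PULSE_SENSOR_PROXY_ALLOWED_NODES and
PULSE_SENSOR_PROXY_STRICT_NODE_VALIDATION environment variables added in this commit.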
Related to security audit 2025-11-07.
Co-authored-by: Codex <codex@openai.com>
@@ -10,9 +10,19 @@ allowed_source_subnets:

# Peer Authorization
# Specify which UIDs/GIDs are allowed to connect
# A peer is authorized when its UID OR GID matches one of these entries
# Required when running Pulse in a container (use mapped UID/GID from container)
allowed_peer_uids: [100999] # Container pulse user UID
allowed_peer_gids: [100996] # Container pulse group GID
allowed_peer_uids: [100999] # Legacy format; grants all capabilities unless overridden below
allowed_peer_gids: [100996]

# Preferred format with explicit capabilities (read, write, admin)
allowed_peers:
  - uid: 0
    capabilities: [read, write, admin] # Host root retains full control
  - uid: 100999
    capabilities: [read] # Container peer limited to read-only RPCs

require_proxmox_hostkeys: false # Enforce Proxmox-known host keys before falling back to ssh-keyscan

# ID-Mapped Root Authentication
# Allow connections from ID-mapped root users (for LXC containers)
@@ -24,6 +34,9 @@ allowed_idmap_users:
# Address for Prometheus metrics endpoint
metrics_address: "127.0.0.1:9127"

# Limit SSH output size (bytes) when fetching temperatures
max_ssh_output_bytes: 1048576 # 1 MiB

# Rate Limiting (Optional)
# Control how frequently peers can make requests to prevent abuse
# Adjust these values based on your deployment size:

@@ -6,6 +6,7 @@ import (
    "os"
    "strconv"
    "strings"
    "time"

    "github.com/rs/zerolog/log"
    "gopkg.in/yaml.v3"
@@ -19,9 +20,16 @@ type RateLimitConfig struct {

// Config holds proxy configuration
type Config struct {
    AllowedSourceSubnets []string `yaml:"allowed_source_subnets"`
    MetricsAddress string `yaml:"metrics_address"`
    LogLevel string `yaml:"log_level"`
    AllowedSourceSubnets []string `yaml:"allowed_source_subnets"`
    MetricsAddress string `yaml:"metrics_address"`
    LogLevel string `yaml:"log_level"`
    AllowedNodes []string `yaml:"allowed_nodes"`
    StrictNodeValidation bool `yaml:"strict_node_validation"`
    ReadTimeout time.Duration `yaml:"read_timeout"`
    WriteTimeout time.Duration `yaml:"write_timeout"`
    MaxSSHOutputBytes int64 `yaml:"max_ssh_output_bytes"`
    RequireProxmoxHostkeys bool `yaml:"require_proxmox_hostkeys"`
    AllowedPeers []PeerConfig `yaml:"allowed_peers"`

    AllowIDMappedRoot bool `yaml:"allow_idmapped_root"`
    AllowedPeerUIDs []uint32 `yaml:"allowed_peer_uids"`
@@ -31,12 +39,21 @@ type Config struct {
    RateLimit *RateLimitConfig `yaml:"rate_limit,omitempty"`
}

// PeerConfig represents a peer entry with capabilities.
type PeerConfig struct {
    UID uint32 `yaml:"uid"`
    Capabilities []string `yaml:"capabilities"`
}

// loadConfig loads configuration from file and environment variables
func loadConfig(configPath string) (*Config, error) {
    cfg := &Config{
        AllowIDMappedRoot: true,
        AllowedIDMapUsers: []string{"root"},
        LogLevel: "info", // Default log level
        ReadTimeout: 5 * time.Second,
        WriteTimeout: 10 * time.Second,
        MaxSSHOutputBytes: 1 * 1024 * 1024, // 1 MiB
    }

    // Try to load config file if it exists
@@ -58,6 +75,26 @@ func loadConfig(configPath string) (*Config, error) {
        }
    }

    // Read timeout override
    if envReadTimeout := os.Getenv("PULSE_SENSOR_PROXY_READ_TIMEOUT"); envReadTimeout != "" {
        if parsed, err := time.ParseDuration(strings.TrimSpace(envReadTimeout)); err != nil {
            log.Warn().Str("value", envReadTimeout).Err(err).Msg("Invalid PULSE_SENSOR_PROXY_READ_TIMEOUT value, ignoring")
        } else {
            cfg.ReadTimeout = parsed
            log.Info().Dur("read_timeout", cfg.ReadTimeout).Msg("Configured read timeout from environment")
        }
    }

    // Write timeout override
    if envWriteTimeout := os.Getenv("PULSE_SENSOR_PROXY_WRITE_TIMEOUT"); envWriteTimeout != "" {
        if parsed, err := time.ParseDuration(strings.TrimSpace(envWriteTimeout)); err != nil {
            log.Warn().Str("value", envWriteTimeout).Err(err).Msg("Invalid PULSE_SENSOR_PROXY_WRITE_TIMEOUT value, ignoring")
        } else {
            cfg.WriteTimeout = parsed
            log.Info().Dur("write_timeout", cfg.WriteTimeout).Msg("Configured write timeout from environment")
        }
    }

    // Append from environment variable if set
    if envSubnets := os.Getenv("PULSE_SENSOR_PROXY_ALLOWED_SUBNETS"); envSubnets != "" {
        envList := strings.Split(envSubnets, ",")
@@ -67,6 +104,20 @@ func loadConfig(configPath string) (*Config, error) {
            Msg("Appended subnets from environment variable")
    }

    // Ensure timeouts have sane defaults
    if cfg.ReadTimeout <= 0 {
        log.Warn().Dur("configured_value", cfg.ReadTimeout).Msg("Read timeout must be positive; using default 5s")
        cfg.ReadTimeout = 5 * time.Second
    }
    if cfg.WriteTimeout <= 0 {
        log.Warn().Dur("configured_value", cfg.WriteTimeout).Msg("Write timeout must be positive; using default 10s")
        cfg.WriteTimeout = 10 * time.Second
    }
    if cfg.MaxSSHOutputBytes <= 0 {
        log.Warn().Int64("configured_value", cfg.MaxSSHOutputBytes).Msg("max_ssh_output_bytes must be positive; using default 1MiB")
        cfg.MaxSSHOutputBytes = 1 * 1024 * 1024
    }

    // Allow ID-mapped root override
    if envAllowIDMap := os.Getenv("PULSE_SENSOR_PROXY_ALLOW_IDMAPPED_ROOT"); envAllowIDMap != "" {
        parsed, err := parseBool(envAllowIDMap)
@@ -126,6 +177,53 @@ func loadConfig(configPath string) (*Config, error) {
        }
    }

    // Allowed node overrides
    if envNodes := os.Getenv("PULSE_SENSOR_PROXY_ALLOWED_NODES"); envNodes != "" {
        envList := splitAndTrim(envNodes)
        if len(envList) > 0 {
            cfg.AllowedNodes = append(cfg.AllowedNodes, envList...)
            log.Info().
                Int("env_allowed_nodes", len(envList)).
                Msg("Appended allowed nodes from environment")
        }
    }

    // Strict node validation override
    if envStrict := os.Getenv("PULSE_SENSOR_PROXY_STRICT_NODE_VALIDATION"); envStrict != "" {
        parsed, err := parseBool(envStrict)
        if err != nil {
            log.Warn().
                Str("value", envStrict).
                Err(err).
                Msg("Invalid PULSE_SENSOR_PROXY_STRICT_NODE_VALIDATION value, ignoring")
        } else {
            cfg.StrictNodeValidation = parsed
            log.Info().
                Bool("strict_node_validation", parsed).
                Msg("Configured strict node validation from environment")
        }
    }

    // SSH output limit override
    if envMaxSSH := os.Getenv("PULSE_SENSOR_PROXY_MAX_SSH_OUTPUT_BYTES"); envMaxSSH != "" {
        if parsed, err := strconv.ParseInt(strings.TrimSpace(envMaxSSH), 10, 64); err != nil {
            log.Warn().Str("value", envMaxSSH).Err(err).Msg("Invalid PULSE_SENSOR_PROXY_MAX_SSH_OUTPUT_BYTES value, ignoring")
        } else {
            cfg.MaxSSHOutputBytes = parsed
            log.Info().Int64("max_ssh_output_bytes", cfg.MaxSSHOutputBytes).Msg("Configured max SSH output bytes from environment")
        }
    }

    // Require Proxmox host keys override
    if envReq := os.Getenv("PULSE_SENSOR_PROXY_REQUIRE_PROXMOX_HOSTKEYS"); envReq != "" {
        if parsed, err := parseBool(envReq); err != nil {
            log.Warn().Str("value", envReq).Err(err).Msg("Invalid PULSE_SENSOR_PROXY_REQUIRE_PROXMOX_HOSTKEYS value, ignoring")
        } else {
            cfg.RequireProxmoxHostkeys = parsed
            log.Info().Bool("require_proxmox_hostkeys", parsed).Msg("Configured Proxmox host key requirement from environment")
        }
    }

    // Metrics address from environment variable
    if envMetrics := os.Getenv("PULSE_SENSOR_PROXY_METRICS_ADDR"); envMetrics != "" {
        cfg.MetricsAddress = envMetrics

@@ -249,20 +249,25 @@ func lookupUserFromPasswd(username string) (*userSpec, error) {

// Proxy manages the temperature monitoring proxy
type Proxy struct {
    socketPath string
    sshKeyPath string
    workDir string
    knownHosts knownhosts.Manager
    listener net.Listener
    rateLimiter *rateLimiter
    nodeGate *nodeGate
    router map[string]handlerFunc
    config *Config
    metrics *ProxyMetrics
    audit *auditLogger
    socketPath string
    sshKeyPath string
    workDir string
    knownHosts knownhosts.Manager
    listener net.Listener
    rateLimiter *rateLimiter
    nodeGate *nodeGate
    router map[string]handlerFunc
    config *Config
    metrics *ProxyMetrics
    audit *auditLogger
    nodeValidator *nodeValidator
    readTimeout time.Duration
    writeTimeout time.Duration
    maxSSHOutputBytes int64

    allowedPeerUIDs map[uint32]struct{}
    allowedPeerGIDs map[uint32]struct{}
    peerCapabilities map[uint32]Capability
    idMappedUIDRanges []idRange
    idMappedGIDRanges []idRange
}
@@ -362,6 +367,11 @@ func runProxy() {
    // Initialize metrics
    metrics := NewProxyMetrics(Version)

    nodeValidator, err := newNodeValidator(cfg, metrics)
    if err != nil {
        log.Fatal().Err(err).Msg("Failed to initialize node validator")
    }

    log.Info().
        Str("socket", socketPath).
        Str("ssh_key_dir", sshKeyPath).
@@ -383,14 +393,17 @@ func runProxy() {
    }

    proxy := &Proxy{
        socketPath: socketPath,
        sshKeyPath: sshKeyPath,
        knownHosts: knownHostsManager,
        rateLimiter: newRateLimiter(metrics, cfg.RateLimit),
        nodeGate: newNodeGate(),
        config: cfg,
        metrics: metrics,
        audit: auditLogger,
        socketPath: socketPath,
        sshKeyPath: sshKeyPath,
        knownHosts: knownHostsManager,
        nodeGate: newNodeGate(),
        config: cfg,
        metrics: metrics,
        audit: auditLogger,
        nodeValidator: nodeValidator,
        readTimeout: cfg.ReadTimeout,
        writeTimeout: cfg.WriteTimeout,
        maxSSHOutputBytes: cfg.MaxSSHOutputBytes,
    }

    if wd, err := os.Getwd(); err == nil {
@@ -400,6 +413,12 @@ func runProxy() {
        proxy.workDir = defaultWorkDir()
    }

    if err := proxy.initAuthRules(); err != nil {
        log.Fatal().Err(err).Msg("Failed to initialize authentication rules")
    }

    proxy.rateLimiter = newRateLimiter(metrics, cfg.RateLimit, proxy.idMappedUIDRanges, proxy.idMappedGIDRanges)

    // Register RPC method handlers
    proxy.router = map[string]handlerFunc{
        RPCGetStatus: proxy.handleGetStatusV2,
@@ -409,10 +428,6 @@ func runProxy() {
        RPCRequestCleanup: proxy.handleRequestCleanup,
    }

    if err := proxy.initAuthRules(); err != nil {
        log.Fatal().Err(err).Msg("Failed to initialize authentication rules")
    }

    if err := proxy.Start(); err != nil {
        log.Fatal().Err(err).Msg("Failed to start proxy")
    }
@@ -429,7 +444,9 @@ func runProxy() {
    <-sigChan
    log.Info().Msg("Shutting down proxy...")
    proxy.Stop()
    proxy.rateLimiter.shutdown()
    if proxy.rateLimiter != nil {
        proxy.rateLimiter.shutdown()
    }
    metrics.Shutdown(context.Background())
    log.Info().Msg("Proxy stopped")
}
@@ -518,8 +535,14 @@ func (p *Proxy) handleConnection(conn net.Conn) {
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()

    // Skip read deadline - it interferes with write operations on unix sockets
    // Context timeout provides sufficient protection against hung connections
    // Enforce read deadline to prevent hung connections
    readDeadline := p.readTimeout
    if readDeadline <= 0 {
        readDeadline = 5 * time.Second
    }
    if err := conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil {
        log.Warn().Err(err).Msg("Failed to set read deadline")
    }

    // Extract and verify peer credentials
    cred, err := extractPeerCredentials(conn)
@@ -532,7 +555,8 @@ func (p *Proxy) handleConnection(conn net.Conn) {
        return
    }

    if err := p.authorizePeer(cred); err != nil {
    peerCaps, err := p.authorizePeer(cred)
    if err != nil {
        log.Warn().
            Err(err).
            Uint32("uid", cred.uid).
@@ -549,8 +573,15 @@ func (p *Proxy) handleConnection(conn net.Conn) {
        p.audit.LogConnectionAccepted("", cred, remoteAddr)
    }

    if p.rateLimiter == nil {
        log.Error().Msg("Rate limiter not initialized; rejecting connection")
        p.sendErrorV2(conn, "service unavailable", "")
        return
    }

    // Check rate limit and concurrency
    peer := peerID{uid: cred.uid}
    peer := p.rateLimiter.identifyPeer(cred)
    peerLabel := peer.String()
    releaseLimiter, limitReason, allowed := p.rateLimiter.allow(peer)
    if !allowed {
        log.Warn().
@@ -575,7 +606,7 @@ func (p *Proxy) handleConnection(conn net.Conn) {
            releaseFn()
            releaseFn = nil
        }
        p.rateLimiter.penalize(peer, reason)
        p.rateLimiter.penalize(peerLabel, reason)
    }

    // Read request using newline-delimited framing
@@ -584,6 +615,21 @@ func (p *Proxy) handleConnection(conn net.Conn) {

    line, err := reader.ReadBytes('\n')
    if err != nil {
        if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
            log.Warn().
                Err(err).
                Str("remote", remoteAddr).
                Msg("Read timeout waiting for client request - slow client or attack")
            if p.metrics != nil {
                p.metrics.recordReadTimeout()
            }
            if p.audit != nil {
                p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "read_timeout")
            }
            p.sendErrorV2(conn, "read timeout", "")
            applyPenalty("read_timeout")
            return
        }
        if errors.Is(err, bufio.ErrBufferFull) || limited.N <= 0 {
            if p.audit != nil {
                p.audit.LogValidationFailure("", cred, remoteAddr, "", nil, "payload_too_large")
@@ -608,6 +654,11 @@ func (p *Proxy) handleConnection(conn net.Conn) {
        return
    }

    // Clear read deadline now that request payload has been received
    if err := conn.SetReadDeadline(time.Time{}); err != nil {
        log.Warn().Err(err).Msg("Failed to clear read deadline after request read")
    }

    // Trim whitespace and validate
    line = bytes.TrimSpace(line)
    if len(line) == 0 {
@@ -655,29 +706,28 @@ func (p *Proxy) handleConnection(conn net.Conn) {
        }
        resp.Error = "unknown method"
        logger.Warn().Msg("Unknown method")
        p.sendResponse(conn, resp)
        p.sendResponse(conn, resp, p.writeTimeout)
        applyPenalty("unknown_method")
        return
    }

    // Check if method requires host-level privileges
    if privilegedMethods[req.Method] {
        // Privileged methods can only be called from host (not from containers)
        if p.isIDMappedRoot(cred) {
            resp.Error = "method requires host-level privileges"
        if !peerCaps.Has(CapabilityAdmin) {
            resp.Error = "method requires admin capability"
            log.Warn().
                Str("method", req.Method).
                Uint32("uid", cred.uid).
                Uint32("gid", cred.gid).
                Uint32("pid", cred.pid).
                Str("corr_id", req.CorrelationID).
                Msg("SECURITY: Container attempted to call privileged method - access denied")
                Msg("SECURITY: peer lacking admin capability attempted privileged method - access denied")
            if p.audit != nil {
                p.audit.LogValidationFailure(req.CorrelationID, cred, remoteAddr, req.Method, nil, "privileged_method_denied")
                p.audit.LogValidationFailure(req.CorrelationID, cred, remoteAddr, req.Method, nil, "capability_denied")
            }
            p.sendResponse(conn, resp)
            p.sendResponse(conn, resp, p.writeTimeout)
            p.metrics.rpcRequests.WithLabelValues(req.Method, "unauthorized").Inc()
            applyPenalty("privileged_method_denied")
            applyPenalty("capability_denied")
            return
        }
    }
@@ -695,10 +745,7 @@ func (p *Proxy) handleConnection(conn net.Conn) {
        }
        resp.Error = err.Error()
        logger.Warn().Err(err).Msg("Handler failed")
        // Clear read deadline and set write deadline for error response
        conn.SetReadDeadline(time.Time{})
        conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
        p.sendResponse(conn, resp)
        p.sendResponse(conn, resp, p.writeTimeout)
        // Record failed request
        p.metrics.rpcRequests.WithLabelValues(req.Method, "error").Inc()
        p.metrics.rpcLatency.WithLabelValues(req.Method).Observe(time.Since(startTime).Seconds())
@@ -713,10 +760,7 @@ func (p *Proxy) handleConnection(conn net.Conn) {
    }
    logger.Info().Msg("Request completed")

    // Clear read deadline and set write deadline for response
    conn.SetReadDeadline(time.Time{})
    conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
    p.sendResponse(conn, resp)
    p.sendResponse(conn, resp, p.writeTimeout)

    // Record successful request
    p.metrics.rpcRequests.WithLabelValues(req.Method, "success").Inc()
@@ -729,8 +773,7 @@ func (p *Proxy) sendError(conn net.Conn, message string) {
        Success: false,
        Error: message,
    }
    encoder := json.NewEncoder(conn)
    encoder.Encode(resp)
    p.sendResponse(conn, resp, p.writeTimeout)
}

// sendErrorV2 sends an error response with correlation ID
@@ -740,25 +783,34 @@ func (p *Proxy) sendErrorV2(conn net.Conn, message, correlationID string) {
        Success: false,
        Error: message,
    }
    // Clear read deadline before writing
    conn.SetReadDeadline(time.Time{})
    conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
    encoder := json.NewEncoder(conn)
    encoder.Encode(resp)
    p.sendResponse(conn, resp, p.writeTimeout)
}

// sendResponse sends an RPC response
func (p *Proxy) sendResponse(conn net.Conn, resp RPCResponse) {
func (p *Proxy) sendResponse(conn net.Conn, resp RPCResponse, writeTimeout time.Duration) {
    // Clear read deadline before writing
    if err := conn.SetReadDeadline(time.Time{}); err != nil {
        log.Warn().Err(err).Msg("Failed to clear read deadline")
    }
    if err := conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
        log.Warn().Err(err).Msg("Failed to set write deadline")

    if writeTimeout <= 0 {
        writeTimeout = 10 * time.Second
    }

    if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
        log.Warn().Err(err).Dur("write_timeout", writeTimeout).Msg("Failed to set write deadline")
    }

    encoder := json.NewEncoder(conn)
    if err := encoder.Encode(resp); err != nil {
        log.Error().Err(err).Msg("Failed to encode RPC response")
        if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
            log.Warn().Err(err).Msg("Write timeout while sending RPC response")
            if p.metrics != nil {
                p.metrics.recordWriteTimeout()
            }
        } else {
            log.Error().Err(err).Msg("Failed to encode RPC response")
        }
    }
}
@@ -1101,6 +1153,13 @@ func (p *Proxy) handleGetTemperatureV2(ctx context.Context, req *RPCRequest, log
        return nil, fmt.Errorf("invalid node name")
    }

    if p.nodeValidator != nil {
        if err := p.nodeValidator.Validate(ctx, node); err != nil {
            logger.Warn().Err(err).Str("node", node).Msg("Node validation failed")
            return nil, err
        }
    }

    // Acquire per-node concurrency lock (prevents multiple simultaneous requests to same node)
    releaseNode := p.nodeGate.acquire(node)
    defer releaseNode()

@@ -16,19 +16,24 @@ const defaultMetricsAddr = "127.0.0.1:9127"

// ProxyMetrics holds Prometheus metrics for the proxy
type ProxyMetrics struct {
    rpcRequests *prometheus.CounterVec
    rpcLatency *prometheus.HistogramVec
    sshRequests *prometheus.CounterVec
    sshLatency *prometheus.HistogramVec
    queueDepth prometheus.Gauge
    rateLimitHits prometheus.Counter
    limiterRejects *prometheus.CounterVec
    globalConcurrency prometheus.Gauge
    limiterPenalties *prometheus.CounterVec
    limiterPeers prometheus.Gauge
    buildInfo *prometheus.GaugeVec
    server *http.Server
    registry *prometheus.Registry
    rpcRequests *prometheus.CounterVec
    rpcLatency *prometheus.HistogramVec
    sshRequests *prometheus.CounterVec
    sshLatency *prometheus.HistogramVec
    queueDepth prometheus.Gauge
    rateLimitHits prometheus.Counter
    limiterRejects *prometheus.CounterVec
    globalConcurrency prometheus.Gauge
    limiterPenalties *prometheus.CounterVec
    limiterPeers prometheus.Gauge
    nodeValidationFailures *prometheus.CounterVec
    readTimeouts prometheus.Counter
    writeTimeouts prometheus.Counter
    hostKeyChanges *prometheus.CounterVec
    sshOutputOversized *prometheus.CounterVec
    buildInfo *prometheus.GaugeVec
    server *http.Server
    registry *prometheus.Registry
}

// NewProxyMetrics creates and registers all metrics
@@ -83,7 +88,7 @@ func NewProxyMetrics(version string) *ProxyMetrics {
            Name: "pulse_proxy_limiter_rejections_total",
            Help: "Limiter rejections by reason.",
        },
        []string{"reason"},
        []string{"reason", "peer"},
    ),
    globalConcurrency: prometheus.NewGauge(
        prometheus.GaugeOpts{
@@ -96,7 +101,7 @@ func NewProxyMetrics(version string) *ProxyMetrics {
            Name: "pulse_proxy_limiter_penalties_total",
            Help: "Penalty sleeps applied after validation failures.",
        },
        []string{"reason"},
        []string{"reason", "peer"},
    ),
    limiterPeers: prometheus.NewGauge(
        prometheus.GaugeOpts{
@@ -104,6 +109,39 @@ func NewProxyMetrics(version string) *ProxyMetrics {
            Help: "Number of peers tracked by the rate limiter.",
        },
    ),
    nodeValidationFailures: prometheus.NewCounterVec(
        prometheus.CounterOpts{
            Name: "pulse_proxy_node_validation_failures_total",
            Help: "Node validation failures by reason.",
        },
        []string{"reason"},
    ),
    readTimeouts: prometheus.NewCounter(
        prometheus.CounterOpts{
            Name: "pulse_proxy_read_timeouts_total",
            Help: "Number of socket read timeouts.",
        },
    ),
    writeTimeouts: prometheus.NewCounter(
        prometheus.CounterOpts{
            Name: "pulse_proxy_write_timeouts_total",
            Help: "Number of socket write timeouts.",
        },
    ),
    hostKeyChanges: prometheus.NewCounterVec(
        prometheus.CounterOpts{
            Name: "pulse_proxy_hostkey_changes_total",
            Help: "Detected SSH host key changes by node.",
        },
        []string{"node"},
    ),
    sshOutputOversized: prometheus.NewCounterVec(
        prometheus.CounterOpts{
            Name: "pulse_proxy_ssh_output_oversized_total",
            Help: "Number of SSH responses rejected for exceeding size limits.",
        },
        []string{"node"},
    ),
    buildInfo: prometheus.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "pulse_proxy_build_info",
@@ -125,6 +163,11 @@ func NewProxyMetrics(version string) *ProxyMetrics {
        pm.globalConcurrency,
        pm.limiterPenalties,
        pm.limiterPeers,
        pm.nodeValidationFailures,
        pm.readTimeouts,
        pm.writeTimeouts,
        pm.hostKeyChanges,
        pm.sshOutputOversized,
        pm.buildInfo,
    )

@@ -200,12 +243,53 @@ func sanitizeNodeLabel(node string) string {
    return out
}

func (m *ProxyMetrics) recordLimiterReject(reason string) {
func (m *ProxyMetrics) recordLimiterReject(reason, peer string) {
    if m == nil {
        return
    }
    m.rateLimitHits.Inc()
    m.limiterRejects.WithLabelValues(reason).Inc()
    m.limiterRejects.WithLabelValues(reason, peer).Inc()
}

func (m *ProxyMetrics) recordNodeValidationFailure(reason string) {
    if m == nil {
        return
    }
    m.nodeValidationFailures.WithLabelValues(reason).Inc()
}

func (m *ProxyMetrics) recordReadTimeout() {
    if m == nil {
        return
    }
    m.readTimeouts.Inc()
}

func (m *ProxyMetrics) recordWriteTimeout() {
    if m == nil {
        return
    }
    m.writeTimeouts.Inc()
}

func (m *ProxyMetrics) recordSSHOutputOversized(node string) {
    if m == nil {
        return
    }
    if node == "" {
        node = "unknown"
    }
    m.sshOutputOversized.WithLabelValues(sanitizeNodeLabel(node)).Inc()
}

func (m *ProxyMetrics) recordHostKeyChange(node string) {
    if m == nil {
        return
    }
    if node == "" {
        node = "unknown"
    }
    m.hostKeyChanges.WithLabelValues(sanitizeNodeLabel(node)).Inc()
}

func (m *ProxyMetrics) incGlobalConcurrency() {
@@ -222,11 +306,11 @@ func (m *ProxyMetrics) decGlobalConcurrency() {
    m.globalConcurrency.Dec()
}

func (m *ProxyMetrics) recordPenalty(reason string) {
func (m *ProxyMetrics) recordPenalty(reason, peer string) {
    if m == nil {
        return
    }
    m.limiterPenalties.WithLabelValues(reason).Inc()
    m.limiterPenalties.WithLabelValues(reason, peer).Inc()
}

func (m *ProxyMetrics) setLimiterPeers(count int) {

@@ -1,15 +1,19 @@
package main

import (
    "context"
    "errors"
    "fmt"
    "net"
    "regexp"
    "strings"
    "sync"
    "time"
    "unicode"
    "unicode/utf8"

    "github.com/google/uuid"
    "github.com/rs/zerolog/log"
)

var (
@@ -204,3 +208,313 @@ func isASCII(s string) bool {
    }
    return true
}

const (
    nodeValidatorCacheTTL = 5 * time.Minute

    validationReasonNotAllowlisted = "node_not_allowlisted"
    validationReasonNotClusterMember = "node_not_cluster_member"
    validationReasonNoSources = "no_validation_sources"
    validationReasonResolutionFailed = "allowlist_resolution_failed"
    validationReasonClusterFailed = "cluster_query_failed"
)

// nodeValidator enforces node allow-list and cluster membership checks
type nodeValidator struct {
    allowHosts map[string]struct{}
    allowCIDRs []*net.IPNet
    hasAllowlist bool
    strict bool
    clusterEnabled bool
    metrics *ProxyMetrics
    resolver hostResolver
    clusterFetcher func() ([]string, error)
    cacheTTL time.Duration
    clock func() time.Time
    clusterCache clusterMembershipCache
}

type clusterMembershipCache struct {
    mu sync.Mutex
    expires time.Time
    nodes map[string]struct{}
}

type hostResolver interface {
    LookupIP(ctx context.Context, host string) ([]net.IP, error)
}

type defaultHostResolver struct{}

func (defaultHostResolver) LookupIP(ctx context.Context, host string) ([]net.IP, error) {
    if ctx == nil {
        ctx = context.Background()
    }
    results, err := net.DefaultResolver.LookupIPAddr(ctx, host)
    if err != nil {
        return nil, err
    }
    if len(results) == 0 {
        return nil, fmt.Errorf("no IPs resolved for %s", host)
    }

    ips := make([]net.IP, 0, len(results))
    for _, addr := range results {
        if addr.IP != nil {
            ips = append(ips, addr.IP)
        }
    }
    if len(ips) == 0 {
        return nil, fmt.Errorf("no IP addresses resolved for %s", host)
    }

    return ips, nil
}

func newNodeValidator(cfg *Config, metrics *ProxyMetrics) (*nodeValidator, error) {
    if cfg == nil {
        return nil, errors.New("config is required for node validator")
    }

    v := &nodeValidator{
        allowHosts: make(map[string]struct{}),
        strict: cfg.StrictNodeValidation,
        metrics: metrics,
        resolver: defaultHostResolver{},
        clusterFetcher: discoverClusterNodes,
        cacheTTL: nodeValidatorCacheTTL,
        clock: time.Now,
    }

    for _, raw := range cfg.AllowedNodes {
        entry := strings.TrimSpace(raw)
        if entry == "" {
            continue
        }

        if _, network, err := net.ParseCIDR(entry); err == nil {
            v.allowCIDRs = append(v.allowCIDRs, network)
            continue
        }

        if normalized := normalizeAllowlistEntry(entry); normalized != "" {
            v.allowHosts[normalized] = struct{}{}
        }
    }

    v.hasAllowlist = len(v.allowHosts) > 0 || len(v.allowCIDRs) > 0

    if v.hasAllowlist {
        log.Info().
            Int("allowed_node_count", len(v.allowHosts)).
            Int("allowed_cidr_count", len(v.allowCIDRs)).
            Msg("Node allow-list configured")
    }

    if !v.hasAllowlist && isProxmoxHost() {
        v.clusterEnabled = true
        log.Info().Msg("Node validator using Proxmox cluster membership (auto-detect)")
    }

    if !v.clusterEnabled {
        v.clusterFetcher = nil
    }

    if !v.hasAllowlist && !v.clusterEnabled {
        if v.strict {
            log.Warn().Msg("strict_node_validation enabled but no allowlist or cluster context is available")
        } else {
            log.Info().Msg("Node validator running in permissive mode (no allowlist or cluster context)")
        }
    }

    return v, nil
}

// Validate ensures the provided node is authorized before any SSH is attempted.
func (v *nodeValidator) Validate(ctx context.Context, node string) error {
    if v == nil {
        return nil
    }

    if ctx == nil {
        ctx = context.Background()
    }

    if v.hasAllowlist {
        allowed, err := v.matchesAllowlist(ctx, node)
        if err != nil {
            v.recordFailure(validationReasonResolutionFailed)
            log.Warn().Err(err).Str("node", node).Msg("Node allow-list resolution failed")
            return err
        }
        if !allowed {
            return v.deny(node, validationReasonNotAllowlisted)
        }
        return nil
    }

    if v.clusterEnabled {
        allowed, err := v.matchesCluster(node)
        if err != nil {
            v.recordFailure(validationReasonClusterFailed)
            return fmt.Errorf("failed to evaluate cluster membership: %w", err)
        }
        if !allowed {
            return v.deny(node, validationReasonNotClusterMember)
        }
        return nil
    }

    if v.strict {
        return v.deny(node, validationReasonNoSources)
    }

    return nil
}

func (v *nodeValidator) matchesAllowlist(ctx context.Context, node string) (bool, error) {
    normalized := normalizeAllowlistEntry(node)
    if normalized != "" {
        if _, ok := v.allowHosts[normalized]; ok {
            return true, nil
        }
    }

    if ip := parseNodeIP(node); ip != nil {
        if v.ipAllowed(ip) {
            return true, nil
        }
        // If the node itself is an IP and it didn't match, no need to resolve again.
        return false, nil
    }

    if len(v.allowCIDRs) == 0 {
        return false, nil
    }

    host := stripNodeDelimiters(node)
    ips, err := v.resolver.LookupIP(ctx, host)
    if err != nil {
        return false, fmt.Errorf("resolve node %q: %w", host, err)
    }

    for _, ip := range ips {
        if v.ipAllowed(ip) {
            return true, nil
        }
    }

    return false, nil
}

func (v *nodeValidator) matchesCluster(node string) (bool, error) {
    if v.clusterFetcher == nil {
        return false, errors.New("cluster membership disabled")
    }

    members, err := v.getClusterMembers()
    if err != nil {
        return false, err
    }

    normalized := normalizeAllowlistEntry(node)
    if normalized == "" {
        normalized = strings.ToLower(strings.TrimSpace(node))
    }

    _, ok := members[normalized]
    return ok, nil
}

func (v *nodeValidator) getClusterMembers() (map[string]struct{}, error) {
    now := time.Now()
    if v.clock != nil {
        now = v.clock()
    }

    v.clusterCache.mu.Lock()
    defer v.clusterCache.mu.Unlock()

    if v.clusterCache.nodes != nil && now.Before(v.clusterCache.expires) {
        return v.clusterCache.nodes, nil
    }

    nodes, err := v.clusterFetcher()
    if err != nil {
        return nil, err
    }

    result := make(map[string]struct{}, len(nodes))
    for _, node := range nodes {
        if normalized := normalizeAllowlistEntry(node); normalized != "" {
            result[normalized] = struct{}{}
        }
    }

    ttl := v.cacheTTL
    if ttl <= 0 {
        ttl = nodeValidatorCacheTTL
    }
    v.clusterCache.nodes = result
    v.clusterCache.expires = now.Add(ttl)
    log.Debug().
        Int("cluster_node_count", len(result)).
        Msg("Refreshed cluster membership cache")

    return result, nil
}

func (v *nodeValidator) ipAllowed(ip net.IP) bool {
    if ip == nil {
        return false
    }
    if _, ok := v.allowHosts[ip.String()]; ok {
        return true
    }
    for _, network := range v.allowCIDRs {
        if network.Contains(ip) {
            return true
        }
    }
    return false
}

func (v *nodeValidator) recordFailure(reason string) {
    if v.metrics != nil {
        v.metrics.recordNodeValidationFailure(reason)
    }
}

func (v *nodeValidator) deny(node, reason string) error {
    v.recordFailure(reason)
    log.Warn().
        Str("node", node).
        Str("reason", reason).
        Msg("potential SSRF attempt blocked")
    return fmt.Errorf("node %q rejected by validator (%s)", node, reason)
}

func normalizeAllowlistEntry(entry string) string {
    candidate := strings.TrimSpace(entry)
    if candidate == "" {
        return ""
    }
    unwrapped := stripNodeDelimiters(candidate)
    if ip := net.ParseIP(unwrapped); ip != nil {
        return ip.String()
    }
    return strings.ToLower(candidate)
}

func parseNodeIP(node string) net.IP {
    clean := stripNodeDelimiters(strings.TrimSpace(node))
    return net.ParseIP(clean)
}

func stripNodeDelimiters(node string) string {
    if strings.HasPrefix(node, "[") && strings.HasSuffix(node, "]") && len(node) > 2 {
        return node[1 : len(node)-1]
    }
    return node
}

@@ -1,8 +1,11 @@
package main

import (
    "context"
    "net"
    "strings"
    "testing"
    "time"
)

func TestSanitizeCorrelationID(t *testing.T) {
@@ -121,3 +124,97 @@ func TestValidateCommand(t *testing.T) {
        })
    }
}

type stubResolver struct {
    ips []net.IP
    err error
}

func (s stubResolver) LookupIP(ctx context.Context, host string) ([]net.IP, error) {
    if s.err != nil {
        return nil, s.err
    }
    return s.ips, nil
}

func TestNodeValidatorAllowlistHost(t *testing.T) {
    v := &nodeValidator{
        allowHosts: map[string]struct{}{"node-1": {}},
        hasAllowlist: true,
        resolver: stubResolver{},
    }

    if err := v.Validate(context.Background(), "node-1"); err != nil {
        t.Fatalf("expected node-1 to be permitted, got error: %v", err)
    }

    if err := v.Validate(context.Background(), "node-2"); err == nil {
        t.Fatalf("expected node-2 to be rejected without allow-list entry")
    }
}

func TestNodeValidatorAllowlistCIDRWithLookup(t *testing.T) {
    _, network, _ := net.ParseCIDR("10.0.0.0/24")
    v := &nodeValidator{
        allowHosts: make(map[string]struct{}),
        allowCIDRs: []*net.IPNet{network},
        hasAllowlist: true,
        resolver: stubResolver{
            ips: []net.IP{net.ParseIP("10.0.0.5")},
        },
    }

    if err := v.Validate(context.Background(), "worker.local"); err != nil {
        t.Fatalf("expected worker.local to resolve into allowed CIDR: %v", err)
    }
}

func TestNodeValidatorClusterCaching(t *testing.T) {
    current := time.Now()
    fetches := 0

    v := &nodeValidator{
        clusterEnabled: true,
        clusterFetcher: func() ([]string, error) {
            fetches++
            return []string{"10.0.0.9"}, nil
        },
        cacheTTL: nodeValidatorCacheTTL,
        clock: func() time.Time {
            return current
        },
    }

    if err := v.Validate(context.Background(), "10.0.0.9"); err != nil {
        t.Fatalf("expected node to be allowed via cluster membership: %v", err)
    }
    if fetches != 1 {
        t.Fatalf("expected initial cluster fetch, got %d", fetches)
    }

    current = current.Add(30 * time.Second)
    if err := v.Validate(context.Background(), "10.0.0.9"); err != nil {
        t.Fatalf("expected cached cluster membership to allow node: %v", err)
    }
    if fetches != 1 {
        t.Fatalf("expected cache hit to avoid new fetch, got %d fetches", fetches)
    }

    current = current.Add(nodeValidatorCacheTTL + time.Second)
    if err := v.Validate(context.Background(), "10.0.0.9"); err != nil {
        t.Fatalf("expected refreshed cluster membership to allow node: %v", err)
    }
    if fetches != 2 {
        t.Fatalf("expected cache expiry to trigger new fetch, got %d", fetches)
    }
}

func TestNodeValidatorStrictNoSources(t *testing.T) {
    v := &nodeValidator{
        strict: true,
    }

    if err := v.Validate(context.Background(), "node-1"); err == nil {
        t.Fatalf("expected strict mode without sources to reject nodes")
    }
}