diff --git a/cmd/pulse-sensor-proxy/auth.go b/cmd/pulse-sensor-proxy/auth.go
index fdaf7d8c1..667660666 100644
--- a/cmd/pulse-sensor-proxy/auth.go
+++ b/cmd/pulse-sensor-proxy/auth.go
@@ -66,16 +66,41 @@ func extractPeerCredentials(conn net.Conn) (*peerCredentials, error) {
 func (p *Proxy) initAuthRules() error {
 	p.allowedPeerUIDs = make(map[uint32]struct{})
 	p.allowedPeerGIDs = make(map[uint32]struct{})
+	p.peerCapabilities = make(map[uint32]Capability)
+
+	addCapability := func(uid uint32, caps Capability) {
+		if caps == 0 {
+			caps = CapabilityRead
+		}
+		if existing, ok := p.peerCapabilities[uid]; ok {
+			p.peerCapabilities[uid] = existing | caps
+		} else {
+			p.peerCapabilities[uid] = caps
+		}
+	}
 
 	// Always allow root and the proxy's own user
 	p.allowedPeerUIDs[0] = struct{}{}
+	addCapability(0, capabilityLegacyAll)
 	p.allowedPeerUIDs[uint32(os.Getuid())] = struct{}{}
+	addCapability(uint32(os.Getuid()), capabilityLegacyAll)
 	p.allowedPeerGIDs[0] = struct{}{}
 	p.allowedPeerGIDs[uint32(os.Getgid())] = struct{}{}
 
+	if len(p.config.AllowedPeers) > 0 {
+		for _, peer := range p.config.AllowedPeers {
+			p.allowedPeerUIDs[peer.UID] = struct{}{}
+			addCapability(peer.UID, parseCapabilityList(peer.Capabilities))
+		}
+		log.Info().Int("peer_capability_entries", len(p.config.AllowedPeers)).Msg("Loaded capability entries for peers")
+	}
+
 	if len(p.config.AllowedPeerUIDs) > 0 {
 		for _, uid := range dedupeUint32(p.config.AllowedPeerUIDs) {
 			p.allowedPeerUIDs[uid] = struct{}{}
+			if _, ok := p.peerCapabilities[uid]; !ok {
+				addCapability(uid, capabilityLegacyAll)
+			}
 		}
 		log.Info().
 			Int("explicit_uid_allow_count", len(p.config.AllowedPeerUIDs)).
@@ -101,13 +126,9 @@ func (p *Proxy) initAuthRules() error {
 		users = []string{"root"}
 	}
 
-	uidRanges, err := loadSubIDRanges("/etc/subuid", users)
+	uidRanges, gidRanges, err := loadIDMappingRanges(users)
 	if err != nil {
-		return fmt.Errorf("loading subordinate UID ranges: %w", err)
-	}
-	gidRanges, err := loadSubIDRanges("/etc/subgid", users)
-	if err != nil {
-		return fmt.Errorf("loading subordinate GID ranges: %w", err)
+		return err
 	}
 
 	p.idMappedUIDRanges = uidRanges
@@ -128,17 +149,37 @@ func (p *Proxy) initAuthRules() error {
 	return nil
 }
 
-// authorizePeer verifies the peer credentials against configured allow lists
-func (p *Proxy) authorizePeer(cred *peerCredentials) error {
-	if _, ok := p.allowedPeerUIDs[cred.uid]; ok {
-		return nil
+// authorizePeer verifies the peer credentials against configured allow lists and returns capabilities.
+func (p *Proxy) authorizePeer(cred *peerCredentials) (Capability, error) {
+	if cred == nil {
+		return 0, fmt.Errorf("missing peer credentials")
+	}
+
+	if caps, ok := p.peerCapabilities[cred.uid]; ok {
+		log.Debug().
+			Uint32("uid", cred.uid).
+			Msg("Peer authorized via UID allow-list")
+		return caps, nil
+	}
+
+	if len(p.allowedPeerGIDs) > 0 {
+		if _, ok := p.allowedPeerGIDs[cred.gid]; ok {
+			log.Debug().
+				Uint32("gid", cred.gid).
+				Msg("Peer authorized via GID allow-list")
+			return capabilityLegacyAll, nil
+		}
 	}
 
 	if p.config.AllowIDMappedRoot && p.isIDMappedRoot(cred) {
-		return nil
+		log.Debug().
+			Uint32("uid", cred.uid).
+			Uint32("gid", cred.gid).
+			Msg("Peer authorized via ID-mapped root range")
+		return CapabilityRead, nil
 	}
 
-	return fmt.Errorf("unauthorized: uid=%d gid=%d", cred.uid, cred.gid)
+	return 0, fmt.Errorf("unauthorized: uid=%d gid=%d", cred.uid, cred.gid)
 }
 
 func (p *Proxy) isIDMappedRoot(cred *peerCredentials) bool {
@@ -253,3 +294,15 @@ func loadSubIDRanges(path string, users []string) ([]idRange, error) {
 
 	return ranges, nil
 }
+
+func loadIDMappingRanges(users []string) ([]idRange, []idRange, error) {
+	uidRanges, err := loadSubIDRanges("/etc/subuid", users)
+	if err != nil {
+		return nil, nil, fmt.Errorf("loading subordinate UID ranges: %w", err)
+	}
+	gidRanges, err := loadSubIDRanges("/etc/subgid", users)
+	if err != nil {
+		return nil, nil, fmt.Errorf("loading subordinate GID ranges: %w", err)
+	}
+	return uidRanges, gidRanges, nil
+}
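Aside: Capability, CapabilityRead, capabilityLegacyAll, and parseCapabilityList are referenced by this patch but defined elsewhere in the package. A minimal sketch of the shape initAuthRules appears to assume — the CapabilityWrite flag and the accepted string names are illustrative guesses, not taken from the patch:

    package main

    import "strings"

    // Capability is assumed to be a bit-flag set; only CapabilityRead is
    // confirmed by the patch, the rest of this block is hypothetical.
    type Capability uint32

    const (
        CapabilityRead  Capability = 1 << iota // read-only sensor queries
        CapabilityWrite                        // hypothetical second flag
    )

    // capabilityLegacyAll preserves pre-capability behavior by granting
    // every flag (assumed definition).
    const capabilityLegacyAll = CapabilityRead | CapabilityWrite

    // parseCapabilityList folds config strings into a bitmask. A zero result
    // is later widened to CapabilityRead by addCapability in the patch above.
    func parseCapabilityList(names []string) Capability {
        var caps Capability
        for _, name := range names {
            switch strings.ToLower(strings.TrimSpace(name)) {
            case "read":
                caps |= CapabilityRead
            case "write":
                caps |= CapabilityWrite
            }
        }
        return caps
    }

Note that addCapability merges with bitwise OR, so a UID listed more than once in AllowedPeers accumulates the union of its capability lists rather than keeping only the last entry.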
+ Msg("Peer authorized via ID-mapped root range") + return CapabilityRead, nil } - return fmt.Errorf("unauthorized: uid=%d gid=%d", cred.uid, cred.gid) + return 0, fmt.Errorf("unauthorized: uid=%d gid=%d", cred.uid, cred.gid) } func (p *Proxy) isIDMappedRoot(cred *peerCredentials) bool { @@ -253,3 +294,15 @@ func loadSubIDRanges(path string, users []string) ([]idRange, error) { return ranges, nil } + +func loadIDMappingRanges(users []string) ([]idRange, []idRange, error) { + uidRanges, err := loadSubIDRanges("/etc/subuid", users) + if err != nil { + return nil, nil, fmt.Errorf("loading subordinate UID ranges: %w", err) + } + gidRanges, err := loadSubIDRanges("/etc/subgid", users) + if err != nil { + return nil, nil, fmt.Errorf("loading subordinate GID ranges: %w", err) + } + return uidRanges, gidRanges, nil +} diff --git a/cmd/pulse-sensor-proxy/throttle.go b/cmd/pulse-sensor-proxy/throttle.go index 8b8b3bf15..11c69125a 100644 --- a/cmd/pulse-sensor-proxy/throttle.go +++ b/cmd/pulse-sensor-proxy/throttle.go @@ -1,15 +1,25 @@ package main import ( + "fmt" "sync" "time" "golang.org/x/time/rate" ) -// peerID identifies a connecting principal (grouped by UID) +// peerID identifies a connecting principal (grouped by UID or ID range) type peerID struct { - uid uint32 + uid uint32 + uidRange *idRange +} + +func (p peerID) String() string { + if p.uidRange != nil { + end := p.uidRange.start + p.uidRange.length - 1 + return fmt.Sprintf("range:%d-%d", p.uidRange.start, end) + } + return fmt.Sprintf("uid:%d", p.uid) } // limiterEntry holds rate limiting and concurrency controls for a peer @@ -30,15 +40,17 @@ type limiterPolicy struct { // rateLimiter manages per-peer rate limits and concurrency type rateLimiter struct { mu sync.Mutex - entries map[peerID]*limiterEntry + entries map[string]*limiterEntry quitChan chan struct{} globalSem chan struct{} policy limiterPolicy metrics *ProxyMetrics + uidRanges []idRange + gidRanges []idRange } const ( - defaultPerPeerBurst = 5 // Allow burst of 5 requests for multi-node polling + defaultPerPeerBurst = 5 // Allow burst of 5 requests for multi-node polling defaultPerPeerConcurrency = 2 defaultGlobalConcurrency = 8 ) @@ -51,7 +63,7 @@ var ( // newRateLimiter creates a new rate limiter with cleanup loop // If rateLimitCfg is provided, it overrides the default rate limit settings -func newRateLimiter(metrics *ProxyMetrics, rateLimitCfg *RateLimitConfig) *rateLimiter { +func newRateLimiter(metrics *ProxyMetrics, rateLimitCfg *RateLimitConfig, uidRanges, gidRanges []idRange) *rateLimiter { // Use defaults perPeerLimit := defaultPerPeerLimit perPeerBurst := defaultPerPeerBurst @@ -68,7 +80,7 @@ func newRateLimiter(metrics *ProxyMetrics, rateLimitCfg *RateLimitConfig) *rateL } rl := &rateLimiter{ - entries: make(map[peerID]*limiterEntry), + entries: make(map[string]*limiterEntry), quitChan: make(chan struct{}), globalSem: make(chan struct{}, defaultGlobalConcurrency), policy: limiterPolicy{ @@ -78,7 +90,9 @@ func newRateLimiter(metrics *ProxyMetrics, rateLimitCfg *RateLimitConfig) *rateL globalConcurrency: defaultGlobalConcurrency, penaltyDuration: defaultPenaltyDuration, }, - metrics: metrics, + metrics: metrics, + uidRanges: append([]idRange(nil), uidRanges...), + gidRanges: append([]idRange(nil), gidRanges...), } if rl.metrics != nil { rl.metrics.setLimiterPeers(0) @@ -90,14 +104,15 @@ func newRateLimiter(metrics *ProxyMetrics, rateLimitCfg *RateLimitConfig) *rateL // allow checks if a peer is allowed to make a request and reserves concurrency. 
@@ -90,14 +104,15 @@ func newRateLimiter(metrics *ProxyMetrics, rateLimitCfg *RateLimitConfig) *rateL
 // allow checks if a peer is allowed to make a request and reserves concurrency.
 // Returns a release function, rejection reason (if any), and whether the request is allowed.
 func (rl *rateLimiter) allow(id peerID) (release func(), reason string, allowed bool) {
+	key := id.String()
 	rl.mu.Lock()
-	entry := rl.entries[id]
+	entry := rl.entries[key]
 	if entry == nil {
 		entry = &limiterEntry{
 			limiter:   rate.NewLimiter(rl.policy.perPeerLimit, rl.policy.perPeerBurst),
 			semaphore: make(chan struct{}, rl.policy.perPeerConcurrency),
 		}
-		rl.entries[id] = entry
+		rl.entries[key] = entry
 		if rl.metrics != nil {
 			rl.metrics.setLimiterPeers(len(rl.entries))
 		}
@@ -107,7 +122,7 @@
 
 	// Check rate limit
 	if !entry.limiter.Allow() {
-		rl.recordRejection("rate")
+		rl.recordRejection("rate", key)
 		return nil, "rate", false
 	}
 
@@ -118,7 +133,7 @@
 		if rl.metrics != nil {
 			rl.metrics.incGlobalConcurrency()
 		}
 	default:
-		rl.recordRejection("global_concurrency")
+		rl.recordRejection("global_concurrency", key)
 		return nil, "global_concurrency", false
 	}
@@ -137,7 +152,7 @@
 		if rl.metrics != nil {
 			rl.metrics.decGlobalConcurrency()
 		}
-		rl.recordRejection("peer_concurrency")
+		rl.recordRejection("peer_concurrency", key)
 		return nil, "peer_concurrency", false
 	}
 }
@@ -150,9 +165,9 @@ func (rl *rateLimiter) cleanupLoop() {
 		select {
 		case <-ticker.C:
 			rl.mu.Lock()
-			for id, entry := range rl.entries {
+			for key, entry := range rl.entries {
 				if time.Since(entry.lastSeen) > 10*time.Minute {
-					delete(rl.entries, id)
+					delete(rl.entries, key)
 				}
 			}
 			if rl.metrics != nil {
@@ -170,22 +185,53 @@ func (rl *rateLimiter) shutdown() {
 	close(rl.quitChan)
 }
 
-func (rl *rateLimiter) penalize(id peerID, reason string) {
+func (rl *rateLimiter) penalize(peerLabel, reason string) {
 	if rl.policy.penaltyDuration <= 0 {
 		return
 	}
 	time.Sleep(rl.policy.penaltyDuration)
 	if rl.metrics != nil {
-		rl.metrics.recordPenalty(reason)
+		rl.metrics.recordPenalty(reason, peerLabel)
 	}
 }
 
-func (rl *rateLimiter) recordRejection(reason string) {
+func (rl *rateLimiter) recordRejection(reason, peerLabel string) {
 	if rl.metrics != nil {
-		rl.metrics.recordLimiterReject(reason)
+		rl.metrics.recordLimiterReject(reason, peerLabel)
 	}
 }
 
+func (rl *rateLimiter) identifyPeer(cred *peerCredentials) peerID {
+	if cred == nil {
+		return peerID{}
+	}
+	if rl == nil {
+		return peerID{uid: cred.uid}
+	}
+
+	if len(rl.uidRanges) == 0 || len(rl.gidRanges) == 0 {
+		return peerID{uid: cred.uid}
+	}
+
+	uidRange := findRange(rl.uidRanges, cred.uid)
+	gidRange := findRange(rl.gidRanges, cred.gid)
+
+	if uidRange != nil && gidRange != nil {
+		return peerID{uid: cred.uid, uidRange: uidRange}
+	}
+
+	return peerID{uid: cred.uid}
+}
+
+func findRange(ranges []idRange, value uint32) *idRange {
+	for i := range ranges {
+		if ranges[i].contains(value) {
+			return &ranges[i]
+		}
+	}
+	return nil
+}
+
 // nodeGate controls per-node concurrency for temperature requests
 type nodeGate struct {
 	mu sync.Mutex
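Aside: identifyPeer and the tests below depend on an idRange type with start and length fields and a contains method; its definition (presumably next to loadSubIDRanges) is outside this patch. A minimal version consistent with how it is used here:

    // Assumed shape of idRange; the real definition is not in this patch.
    type idRange struct {
        start  uint32
        length uint32
    }

    // contains reports whether value falls in the half-open interval
    // [start, start+length); e.g. 110000 is inside {100000, 65536}.
    func (r idRange) contains(value uint32) bool {
        return value >= r.start && value-r.start < r.length
    }

Keying limiter entries by peerID.String() means every UID inside one subordinate range shares a single entry and a single "range:100000-165535"-style metric label, which keeps both the entries map and the Prometheus peer-label cardinality bounded no matter how many ID-mapped container UIDs connect.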
diff --git a/cmd/pulse-sensor-proxy/throttle_test.go b/cmd/pulse-sensor-proxy/throttle_test.go
index d4b647024..e9758ffea 100644
--- a/cmd/pulse-sensor-proxy/throttle_test.go
+++ b/cmd/pulse-sensor-proxy/throttle_test.go
@@ -1,43 +1,74 @@
 package main
 
 import (
-    "testing"
-    "time"
+	"testing"
+	"time"
 )
 
 func TestRateLimiterPenalizeMetrics(t *testing.T) {
-    metrics := NewProxyMetrics("test")
-    rl := newRateLimiter(metrics, nil)
-    rl.policy.penaltyDuration = 10 * time.Millisecond
+	metrics := NewProxyMetrics("test")
+	rl := newRateLimiter(metrics, nil, nil, nil)
+	rl.policy.penaltyDuration = 10 * time.Millisecond
 
-    start := time.Now()
-    rl.penalize(peerID{uid: 42}, "invalid_json")
-    if time.Since(start) < rl.policy.penaltyDuration {
-        t.Fatalf("expected penalize to sleep at least %v", rl.policy.penaltyDuration)
-    }
+	start := time.Now()
+	rl.penalize("uid:42", "invalid_json")
+	if time.Since(start) < rl.policy.penaltyDuration {
+		t.Fatalf("expected penalize to sleep at least %v", rl.policy.penaltyDuration)
+	}
 
-    mf, err := metrics.registry.Gather()
-    if err != nil {
-        t.Fatalf("gather metrics: %v", err)
-    }
+	mf, err := metrics.registry.Gather()
+	if err != nil {
+		t.Fatalf("gather metrics: %v", err)
+	}
 
-    found := false
-    for _, fam := range mf {
-        if fam.GetName() != "pulse_proxy_limiter_penalties_total" {
-            continue
-        }
-        for _, metric := range fam.GetMetric() {
-            if metric.GetCounter().GetValue() == 0 {
-                continue
-            }
-            for _, label := range metric.GetLabel() {
-                if label.GetName() == "reason" && label.GetValue() == "invalid_json" {
-                    found = true
-                }
-            }
-        }
-    }
-    if !found {
-        t.Fatalf("expected limiter penalty metric for invalid_json")
-    }
+	found := false
+	for _, fam := range mf {
+		if fam.GetName() != "pulse_proxy_limiter_penalties_total" {
+			continue
+		}
+		for _, metric := range fam.GetMetric() {
+			if metric.GetCounter().GetValue() == 0 {
+				continue
+			}
+			var reasonLabel, peerLabel string
+			for _, label := range metric.GetLabel() {
+				switch label.GetName() {
+				case "reason":
+					reasonLabel = label.GetValue()
+				case "peer":
+					peerLabel = label.GetValue()
+				}
+			}
+			if reasonLabel == "invalid_json" && peerLabel == "uid:42" {
+				found = true
+			}
+		}
+	}
+	if !found {
+		t.Fatalf("expected limiter penalty metric for invalid_json and peer uid:42")
+	}
 }
+
+func TestIdentifyPeerRangeVsUID(t *testing.T) {
+	uidRanges := []idRange{{start: 100000, length: 65536}}
+	gidRanges := []idRange{{start: 100000, length: 65536}}
+	rl := newRateLimiter(nil, nil, uidRanges, gidRanges)
+
+	containerCred := &peerCredentials{uid: 110000, gid: 110000}
+	containerPeer := rl.identifyPeer(containerCred)
+	if containerPeer.uidRange == nil {
+		t.Fatalf("expected container peer to map to UID range")
+	}
+	if got := containerPeer.String(); got != "range:100000-165535" {
+		t.Fatalf("unexpected container peer label: %s", got)
+	}
+
+	hostCred := &peerCredentials{uid: 1000, gid: 1000}
+	hostPeer := rl.identifyPeer(hostCred)
+	if hostPeer.uidRange != nil {
+		t.Fatalf("expected host peer to use UID label")
+	}
+	if got := hostPeer.String(); got != "uid:1000" {
+		t.Fatalf("unexpected host peer label: %s", got)
+	}
+}
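Taken together, a connection handler would presumably plumb these pieces roughly as follows. Only extractPeerCredentials, authorizePeer, identifyPeer, and allow come from this patch; handleConn, p.limiter, and p.serve are assumed scaffolding for illustration:

    // Hypothetical glue within the same package; names not shown in the
    // patch are assumptions, not the project's actual API.
    func (p *Proxy) handleConn(conn net.Conn) error {
        cred, err := extractPeerCredentials(conn)
        if err != nil {
            return err
        }
        caps, err := p.authorizePeer(cred) // capabilities now come back with authz
        if err != nil {
            return err
        }
        peer := p.limiter.identifyPeer(cred) // groups ID-mapped peers by range
        release, reason, ok := p.limiter.allow(peer)
        if !ok {
            return fmt.Errorf("throttled peer %s: %s", peer, reason)
        }
        defer release()
        return p.serve(conn, cred, caps) // caps would gate individual request types
    }

Here p.limiter is assumed to be built as newRateLimiter(p.metrics, p.config.RateLimit, uidRanges, gidRanges), with the ranges coming from the loadIDMappingRanges helper added in auth.go; those field names are likewise assumptions. With this wiring, the same peer label (peerID.String()) flows into recordRejection inside allow and into penalize at the call sites, so rejection and penalty metrics can be attributed to a single UID or to a whole ID-mapped range instead of being aggregated across all peers.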