// Pulse/internal/dockeragent/agent.go
package dockeragent
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"math/big"
"net/http"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/docker/docker/api/types"
containertypes "github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
systemtypes "github.com/docker/docker/api/types/system"
"github.com/docker/docker/client"
"github.com/rcourtman/pulse-go-rewrite/internal/hostmetrics"
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
agentsdocker "github.com/rcourtman/pulse-go-rewrite/pkg/agents/docker"
"github.com/rs/zerolog"
)
// TargetConfig describes a single Pulse backend the agent should report to.
type TargetConfig struct {
URL string
Token string
InsecureSkipVerify bool
}
// Config describes runtime configuration for the Docker agent.
type Config struct {
PulseURL string
APIToken string
Interval time.Duration
HostnameOverride string
AgentID string
AgentType string // "unified" when running as part of pulse-agent, empty for standalone
InsecureSkipVerify bool
DisableAutoUpdate bool
Targets []TargetConfig
ContainerStates []string
SwarmScope string
Runtime string
IncludeServices bool
IncludeTasks bool
IncludeContainers bool
CollectDiskMetrics bool
LogLevel zerolog.Level
Logger *zerolog.Logger
}
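// A minimal standalone setup, as a hedged sketch (the URL, token, and
// interval below are illustrative placeholders, not defaults):
//
//	agent, err := dockeragent.New(dockeragent.Config{
//		PulseURL: "https://pulse.example.local:7655",
//		APIToken: "pulse-api-token",
//		Interval: 30 * time.Second,
//		LogLevel: zerolog.InfoLevel,
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer agent.Close()
//	_ = agent.Run(context.Background())
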
var allowedContainerStates = map[string]string{
"created": "created",
"restarting": "restarting",
"running": "running",
"removing": "removing",
"paused": "paused",
"exited": "exited",
"dead": "dead",
"stopped": "exited",
}
type RuntimeKind string
const (
RuntimeAuto RuntimeKind = "auto"
RuntimeDocker RuntimeKind = "docker"
RuntimePodman RuntimeKind = "podman"
)
// Agent collects Docker metrics and posts them to Pulse.
type Agent struct {
cfg Config
docker *client.Client
daemonHost string
runtime RuntimeKind
runtimeVer string
supportsSwarm bool
httpClients map[bool]*http.Client
logger zerolog.Logger
machineID string
hostName string
cpuCount int
targets []TargetConfig
allowedStates map[string]struct{}
stateFilters []string
hostID string
prevContainerCPU map[string]cpuSample
preCPUStatsFailures int
}
// ErrStopRequested indicates the agent should terminate gracefully after acknowledging a stop command.
var ErrStopRequested = errors.New("docker host stop requested")
type cpuSample struct {
totalUsage uint64
systemUsage uint64
onlineCPUs uint32
read time.Time
}
// New creates a new Docker agent instance.
func New(cfg Config) (*Agent, error) {
targets, err := normalizeTargets(cfg.Targets)
if err != nil {
return nil, err
}
if len(targets) == 0 {
url := strings.TrimSpace(cfg.PulseURL)
token := strings.TrimSpace(cfg.APIToken)
if url == "" || token == "" {
return nil, errors.New("at least one Pulse target is required")
}
targets, err = normalizeTargets([]TargetConfig{{
URL: url,
Token: token,
InsecureSkipVerify: cfg.InsecureSkipVerify,
}})
if err != nil {
return nil, err
}
}
cfg.Targets = targets
cfg.PulseURL = targets[0].URL
cfg.APIToken = targets[0].Token
cfg.InsecureSkipVerify = targets[0].InsecureSkipVerify
stateFilters, err := normalizeContainerStates(cfg.ContainerStates)
if err != nil {
return nil, err
}
cfg.ContainerStates = stateFilters
scope, err := normalizeSwarmScope(cfg.SwarmScope)
if err != nil {
return nil, err
}
cfg.SwarmScope = scope
if !cfg.IncludeContainers && !cfg.IncludeServices && !cfg.IncludeTasks {
cfg.IncludeContainers = true
cfg.IncludeServices = true
cfg.IncludeTasks = true
}
logger := cfg.Logger
if zerolog.GlobalLevel() == zerolog.DebugLevel && cfg.LogLevel != zerolog.DebugLevel {
zerolog.SetGlobalLevel(cfg.LogLevel)
}
if logger == nil {
defaultLogger := zerolog.New(os.Stdout).Level(cfg.LogLevel).With().Timestamp().Str("component", "pulse-docker-agent").Logger()
logger = &defaultLogger
} else {
scoped := logger.With().Str("component", "pulse-docker-agent").Logger()
logger = &scoped
}
runtimePref, err := normalizeRuntime(cfg.Runtime)
if err != nil {
return nil, err
}
dockerClient, info, runtimeKind, err := connectRuntime(runtimePref, logger)
if err != nil {
return nil, err
}
cfg.Runtime = string(runtimeKind)
if runtimeKind == RuntimePodman {
if cfg.IncludeServices {
logger.Warn().Msg("Podman runtime detected; disabling Swarm service collection")
}
if cfg.IncludeTasks {
logger.Warn().Msg("Podman runtime detected; disabling Swarm task collection")
}
cfg.IncludeServices = false
cfg.IncludeTasks = false
}
logger.Info().
Str("runtime", string(runtimeKind)).
Str("daemon_host", dockerClient.DaemonHost()).
Str("version", info.ServerVersion).
Msg("Connected to container runtime")
hasSecure := false
hasInsecure := false
for _, target := range cfg.Targets {
if target.InsecureSkipVerify {
hasInsecure = true
} else {
hasSecure = true
}
}
httpClients := make(map[bool]*http.Client, 2)
if hasSecure {
httpClients[false] = newHTTPClient(false)
}
if hasInsecure {
httpClients[true] = newHTTPClient(true)
}
machineID, _ := readMachineID()
hostName := cfg.HostnameOverride
if hostName == "" {
if h, err := os.Hostname(); err == nil {
hostName = h
}
}
agent := &Agent{
cfg: cfg,
docker: dockerClient,
daemonHost: dockerClient.DaemonHost(),
runtime: runtimeKind,
runtimeVer: info.ServerVersion,
supportsSwarm: runtimeKind == RuntimeDocker,
httpClients: httpClients,
logger: *logger,
machineID: machineID,
hostName: hostName,
targets: cfg.Targets,
allowedStates: make(map[string]struct{}, len(stateFilters)),
stateFilters: stateFilters,
prevContainerCPU: make(map[string]cpuSample),
}
for _, state := range stateFilters {
agent.allowedStates[state] = struct{}{}
}
return agent, nil
}
func normalizeTargets(raw []TargetConfig) ([]TargetConfig, error) {
if len(raw) == 0 {
return nil, nil
}
normalized := make([]TargetConfig, 0, len(raw))
seen := make(map[string]struct{}, len(raw))
for _, target := range raw {
url := strings.TrimSpace(target.URL)
token := strings.TrimSpace(target.Token)
if url == "" && token == "" {
continue
}
if url == "" {
return nil, errors.New("pulse target URL is required")
}
if token == "" {
return nil, fmt.Errorf("pulse target %s is missing API token", url)
}
url = strings.TrimRight(url, "/")
key := fmt.Sprintf("%s|%s|%t", url, token, target.InsecureSkipVerify)
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
normalized = append(normalized, TargetConfig{
URL: url,
Token: token,
InsecureSkipVerify: target.InsecureSkipVerify,
})
}
return normalized, nil
}
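// For illustration (hypothetical values): targets are deduplicated on the
// (URL, token, insecure) triple after trailing slashes are stripped, so
//
//	[{URL: "https://pulse.local/", Token: "t"}, {URL: "https://pulse.local", Token: "t"}]
//
// collapses to a single target for https://pulse.local.
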
func normalizeContainerStates(raw []string) ([]string, error) {
if len(raw) == 0 {
return nil, nil
}
normalized := make([]string, 0, len(raw))
seen := make(map[string]struct{}, len(raw))
for _, value := range raw {
state := strings.ToLower(strings.TrimSpace(value))
if state == "" {
continue
}
canonical, ok := allowedContainerStates[state]
if !ok {
return nil, fmt.Errorf("unsupported container state %q", value)
}
if _, exists := seen[canonical]; exists {
continue
}
seen[canonical] = struct{}{}
normalized = append(normalized, canonical)
}
return normalized, nil
}
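// For example, []string{"Running", "stopped", "exited"} normalizes to
// []string{"running", "exited"}: input is lower-cased, "stopped" is accepted
// as an alias for "exited", and duplicates collapse.
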
func normalizeRuntime(value string) (RuntimeKind, error) {
runtime := strings.ToLower(strings.TrimSpace(value))
switch runtime {
case "", string(RuntimeAuto), "default":
return RuntimeAuto, nil
case string(RuntimeDocker):
return RuntimeDocker, nil
case string(RuntimePodman):
return RuntimePodman, nil
default:
return "", fmt.Errorf("unsupported runtime %q: must be auto, docker, or podman", value)
}
}
type runtimeCandidate struct {
host string
label string
applyDockerEnv bool
}
func connectRuntime(preference RuntimeKind, logger *zerolog.Logger) (*client.Client, systemtypes.Info, RuntimeKind, error) {
candidates := buildRuntimeCandidates(preference)
var attempts []string
for _, candidate := range candidates {
opts := []client.Opt{client.WithAPIVersionNegotiation()}
if candidate.applyDockerEnv {
opts = append(opts, client.FromEnv)
}
if candidate.host != "" {
opts = append(opts, client.WithHost(candidate.host))
}
cli, info, err := tryRuntimeCandidate(opts)
if err != nil {
attempts = append(attempts, fmt.Sprintf("%s: %v", candidate.label, err))
continue
}
endpoint := cli.DaemonHost()
runtime := detectRuntime(info, endpoint, preference)
if preference != RuntimeAuto && runtime != preference {
attempts = append(attempts, fmt.Sprintf("%s: detected %s runtime", candidate.label, runtime))
_ = cli.Close()
continue
}
if logger != nil {
logger.Debug().Str("host", endpoint).Str("runtime", string(runtime)).Msg("Connected to container runtime")
}
return cli, info, runtime, nil
}
if len(attempts) == 0 {
return nil, systemtypes.Info{}, RuntimeAuto, errors.New("no container runtime endpoints to try")
}
return nil, systemtypes.Info{}, RuntimeAuto, fmt.Errorf("failed to connect to container runtime: %s", strings.Join(attempts, "; "))
}
func tryRuntimeCandidate(opts []client.Opt) (*client.Client, systemtypes.Info, error) {
cli, err := client.NewClientWithOpts(opts...)
if err != nil {
return nil, systemtypes.Info{}, err
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
info, err := cli.Info(ctx)
if err != nil {
_ = cli.Close()
return nil, systemtypes.Info{}, err
}
return cli, info, nil
}
func buildRuntimeCandidates(preference RuntimeKind) []runtimeCandidate {
candidates := make([]runtimeCandidate, 0, 6)
seen := make(map[string]struct{})
add := func(candidate runtimeCandidate) {
hostKey := candidate.host
if hostKey == "" {
hostKey = "__default__"
}
if _, ok := seen[hostKey]; ok {
return
}
seen[hostKey] = struct{}{}
candidates = append(candidates, candidate)
}
add(runtimeCandidate{
label: "environment defaults",
applyDockerEnv: true,
})
if host := utils.GetenvTrim("DOCKER_HOST"); host != "" {
add(runtimeCandidate{
host: host,
label: "DOCKER_HOST",
applyDockerEnv: true,
})
}
if host := utils.GetenvTrim("CONTAINER_HOST"); host != "" {
add(runtimeCandidate{
host: host,
label: "CONTAINER_HOST",
})
}
if host := utils.GetenvTrim("PODMAN_HOST"); host != "" {
add(runtimeCandidate{
host: host,
label: "PODMAN_HOST",
})
}
if preference == RuntimePodman || preference == RuntimeAuto {
rootless := fmt.Sprintf("unix:///run/user/%d/podman/podman.sock", os.Getuid())
add(runtimeCandidate{
host: rootless,
label: "podman rootless socket",
})
add(runtimeCandidate{
host: "unix:///run/podman/podman.sock",
label: "podman system socket",
})
}
if preference == RuntimeDocker || preference == RuntimeAuto {
add(runtimeCandidate{
host: "unix:///var/run/docker.sock",
label: "default docker socket",
applyDockerEnv: true,
})
}
return candidates
}
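// With the default RuntimeAuto preference and no DOCKER_HOST, CONTAINER_HOST,
// or PODMAN_HOST set, the probe order is: environment defaults, the rootless
// Podman socket for the current UID, /run/podman/podman.sock, and finally
// /var/run/docker.sock.
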
func detectRuntime(info systemtypes.Info, endpoint string, preference RuntimeKind) RuntimeKind {
if preference == RuntimePodman {
return RuntimePodman
}
lowerEndpoint := strings.ToLower(endpoint)
if strings.Contains(lowerEndpoint, "podman") || strings.Contains(lowerEndpoint, "libpod") {
return RuntimePodman
}
if strings.Contains(strings.ToLower(info.InitBinary), "podman") {
return RuntimePodman
}
if strings.Contains(strings.ToLower(info.ServerVersion), "podman") {
return RuntimePodman
}
for _, pair := range info.DriverStatus {
if strings.Contains(strings.ToLower(pair[0]), "podman") || strings.Contains(strings.ToLower(pair[1]), "podman") {
return RuntimePodman
}
}
for _, option := range info.SecurityOptions {
if strings.Contains(strings.ToLower(option), "podman") {
return RuntimePodman
}
}
// No Podman indicators were found; default to Docker.
return RuntimeDocker
}
// Run starts the collection loop until the context is cancelled.
func (a *Agent) Run(ctx context.Context) error {
interval := a.cfg.Interval
if interval <= 0 {
interval = 30 * time.Second
a.cfg.Interval = interval
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
const (
updateInterval = 24 * time.Hour
startupJitterWindow = 2 * time.Minute
recurringJitterWindow = 5 * time.Minute
)
initialDelay := 5*time.Second + randomDuration(startupJitterWindow)
updateTimer := time.NewTimer(initialDelay)
defer func() {
if !updateTimer.Stop() {
select {
case <-updateTimer.C:
default:
}
}
}()
if err := a.collectOnce(ctx); err != nil {
if errors.Is(err, ErrStopRequested) {
return nil
}
a.logger.Error().Err(err).Msg("Failed to send initial report")
}
for {
select {
case <-ctx.Done():
if !updateTimer.Stop() {
select {
case <-updateTimer.C:
default:
}
}
return ctx.Err()
case <-ticker.C:
if err := a.collectOnce(ctx); err != nil {
if errors.Is(err, ErrStopRequested) {
return nil
}
a.logger.Error().Err(err).Msg("Failed to send docker report")
}
case <-updateTimer.C:
go a.checkForUpdates(ctx)
nextDelay := updateInterval + randomDuration(recurringJitterWindow)
if nextDelay <= 0 {
nextDelay = updateInterval
}
updateTimer.Reset(nextDelay)
}
}
}
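// Timing sketch for the loop above: the first update check fires between 5s
// and 2m5s after start, and subsequent checks run every 24h plus up to 5m of
// jitter, so a fleet of agents does not hit the update endpoint in lockstep.
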
func (a *Agent) collectOnce(ctx context.Context) error {
report, err := a.buildReport(ctx)
if err != nil {
return err
}
return a.sendReport(ctx, report)
}
func (a *Agent) buildReport(ctx context.Context) (agentsdocker.Report, error) {
info, err := a.docker.Info(ctx)
if err != nil {
return agentsdocker.Report{}, fmt.Errorf("failed to query docker info: %w", err)
}
a.runtimeVer = info.ServerVersion
if a.daemonHost == "" {
a.daemonHost = a.docker.DaemonHost()
}
newRuntime := detectRuntime(info, a.daemonHost, RuntimeAuto)
if newRuntime != a.runtime {
if a.runtime != "" {
a.logger.Info().
Str("runtime_previous", string(a.runtime)).
Str("runtime_current", string(newRuntime)).
Msg("Detected container runtime change")
}
a.runtime = newRuntime
a.supportsSwarm = newRuntime == RuntimeDocker
if newRuntime == RuntimePodman {
if a.cfg.IncludeServices {
a.logger.Warn().Msg("Podman runtime detected during report; disabling Swarm service collection")
}
if a.cfg.IncludeTasks {
a.logger.Warn().Msg("Podman runtime detected during report; disabling Swarm task collection")
}
a.cfg.IncludeServices = false
a.cfg.IncludeTasks = false
}
a.cfg.Runtime = string(newRuntime)
}
a.cpuCount = info.NCPU
agentID := a.cfg.AgentID
if agentID == "" {
agentID = info.ID
}
if agentID == "" {
agentID = a.machineID
}
if agentID == "" {
agentID = a.hostName
}
a.hostID = agentID
hostName := a.hostName
if hostName == "" {
hostName = info.Name
}
uptime := readSystemUptime()
metricsCtx, metricsCancel := context.WithTimeout(ctx, 10*time.Second)
snapshot, err := hostmetrics.Collect(metricsCtx)
metricsCancel()
if err != nil {
return agentsdocker.Report{}, fmt.Errorf("collect host metrics: %w", err)
}
collectContainers := a.cfg.IncludeContainers
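// If only Swarm services/tasks were requested but this node is not a swarm
// manager, fall back to collecting containers so the report is not empty.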
if !collectContainers && (a.cfg.IncludeServices || a.cfg.IncludeTasks) && !info.Swarm.ControlAvailable {
collectContainers = true
}
var containers []agentsdocker.Container
if collectContainers {
var err error
containers, err = a.collectContainers(ctx)
if err != nil {
return agentsdocker.Report{}, err
}
}
services, tasks, swarmInfo := a.collectSwarmData(ctx, info, containers)
report := agentsdocker.Report{
Agent: agentsdocker.AgentInfo{
ID: agentID,
Version: Version,
Type: a.cfg.AgentType,
IntervalSeconds: int(a.cfg.Interval / time.Second),
},
Host: agentsdocker.HostInfo{
Hostname: hostName,
Name: info.Name,
MachineID: a.machineID,
OS: info.OperatingSystem,
Runtime: string(a.runtime),
RuntimeVersion: a.runtimeVer,
KernelVersion: info.KernelVersion,
Architecture: info.Architecture,
DockerVersion: info.ServerVersion,
TotalCPU: info.NCPU,
TotalMemoryBytes: info.MemTotal,
UptimeSeconds: uptime,
CPUUsagePercent: safeFloat(snapshot.CPUUsagePercent),
LoadAverage: append([]float64(nil), snapshot.LoadAverage...),
Memory: snapshot.Memory,
Disks: append([]agentsdocker.Disk(nil), snapshot.Disks...),
Network: append([]agentsdocker.NetworkInterface(nil), snapshot.Network...),
},
Timestamp: time.Now().UTC(),
}
if swarmInfo != nil {
report.Host.Swarm = swarmInfo
}
if a.cfg.IncludeContainers {
report.Containers = containers
}
if a.cfg.IncludeServices && len(services) > 0 {
report.Services = services
}
if a.cfg.IncludeTasks && len(tasks) > 0 {
report.Tasks = tasks
}
if report.Agent.IntervalSeconds <= 0 {
report.Agent.IntervalSeconds = 30 // default reporting interval in seconds
}
return report, nil
}
func (a *Agent) collectContainers(ctx context.Context) ([]agentsdocker.Container, error) {
options := containertypes.ListOptions{All: true}
if len(a.stateFilters) > 0 {
filterArgs := filters.NewArgs()
for _, state := range a.stateFilters {
filterArgs.Add("status", state)
}
options.Filters = filterArgs
}
list, err := a.docker.ContainerList(ctx, options)
if err != nil {
return nil, fmt.Errorf("failed to list containers: %w", err)
}
containers := make([]agentsdocker.Container, 0, len(list))
active := make(map[string]struct{}, len(list))
for _, summary := range list {
if len(a.allowedStates) > 0 {
if _, ok := a.allowedStates[strings.ToLower(summary.State)]; !ok {
continue
}
}
active[summary.ID] = struct{}{}
container, err := a.collectContainer(ctx, summary)
if err != nil {
a.logger.Warn().Str("container", strings.Join(summary.Names, ",")).Err(err).Msg("Failed to collect container stats")
continue
}
containers = append(containers, container)
}
a.pruneStaleCPUSamples(active)
return containers, nil
}
func (a *Agent) pruneStaleCPUSamples(active map[string]struct{}) {
if len(a.prevContainerCPU) == 0 {
return
}
for id := range a.prevContainerCPU {
if _, ok := active[id]; !ok {
delete(a.prevContainerCPU, id)
}
}
}
func (a *Agent) collectContainer(ctx context.Context, summary types.Container) (agentsdocker.Container, error) {
const perContainerTimeout = 15 * time.Second
containerCtx, cancel := context.WithTimeout(ctx, perContainerTimeout)
defer cancel()
requestSize := a.cfg.CollectDiskMetrics
inspect, _, err := a.docker.ContainerInspectWithRaw(containerCtx, summary.ID, requestSize)
if err != nil {
return agentsdocker.Container{}, fmt.Errorf("inspect: %w", err)
}
var (
cpuPercent float64
memUsage int64
memLimit int64
memPercent float64
blockIO *agentsdocker.ContainerBlockIO
)
if inspect.State.Running || inspect.State.Paused {
statsResp, err := a.docker.ContainerStatsOneShot(containerCtx, summary.ID)
if err != nil {
return agentsdocker.Container{}, fmt.Errorf("stats: %w", err)
}
defer statsResp.Body.Close()
var stats containertypes.StatsResponse
if err := json.NewDecoder(statsResp.Body).Decode(&stats); err != nil {
return agentsdocker.Container{}, fmt.Errorf("decode stats: %w", err)
}
cpuPercent = a.calculateContainerCPUPercent(summary.ID, stats)
memUsage, memLimit, memPercent = calculateMemoryUsage(stats)
blockIO = summarizeBlockIO(stats)
} else {
delete(a.prevContainerCPU, summary.ID)
}
createdAt := time.Unix(summary.Created, 0)
startedAt := parseTime(inspect.State.StartedAt)
finishedAt := parseTime(inspect.State.FinishedAt)
uptimeSeconds := int64(0)
if !startedAt.IsZero() && inspect.State.Running {
uptimeSeconds = int64(time.Since(startedAt).Seconds())
if uptimeSeconds < 0 {
uptimeSeconds = 0
}
}
health := ""
if inspect.State.Health != nil {
health = inspect.State.Health.Status
}
ports := make([]agentsdocker.ContainerPort, len(summary.Ports))
for i, port := range summary.Ports {
ports[i] = agentsdocker.ContainerPort{
PrivatePort: int(port.PrivatePort),
PublicPort: int(port.PublicPort),
Protocol: port.Type,
IP: port.IP,
}
}
labels := make(map[string]string, len(summary.Labels))
for k, v := range summary.Labels {
labels[k] = v
}
networks := make([]agentsdocker.ContainerNetwork, 0)
if inspect.NetworkSettings != nil {
for name, cfg := range inspect.NetworkSettings.Networks {
networks = append(networks, agentsdocker.ContainerNetwork{
Name: name,
IPv4: cfg.IPAddress,
IPv6: cfg.GlobalIPv6Address,
})
}
}
var startedPtr, finishedPtr *time.Time
if !startedAt.IsZero() {
started := startedAt
startedPtr = &started
}
if !finishedAt.IsZero() && !inspect.State.Running {
finished := finishedAt
finishedPtr = &finished
}
var writableLayerBytes int64
if inspect.SizeRw != nil {
writableLayerBytes = *inspect.SizeRw
}
var rootFsBytes int64
if inspect.SizeRootFs != nil {
rootFsBytes = *inspect.SizeRootFs
}
var mounts []agentsdocker.ContainerMount
if len(inspect.Mounts) > 0 {
mounts = make([]agentsdocker.ContainerMount, 0, len(inspect.Mounts))
for _, mount := range inspect.Mounts {
mounts = append(mounts, agentsdocker.ContainerMount{
Type: string(mount.Type),
Source: mount.Source,
Destination: mount.Destination,
Mode: mount.Mode,
RW: mount.RW,
Propagation: string(mount.Propagation),
Name: mount.Name,
Driver: mount.Driver,
})
}
}
container := agentsdocker.Container{
ID: summary.ID,
Name: trimLeadingSlash(summary.Names),
Image: summary.Image,
CreatedAt: createdAt,
State: summary.State,
Status: summary.Status,
Health: health,
CPUPercent: cpuPercent,
MemoryUsageBytes: memUsage,
MemoryLimitBytes: memLimit,
MemoryPercent: memPercent,
UptimeSeconds: uptimeSeconds,
RestartCount: inspect.RestartCount,
ExitCode: inspect.State.ExitCode,
StartedAt: startedPtr,
FinishedAt: finishedPtr,
Ports: ports,
Labels: labels,
Networks: networks,
WritableLayerBytes: writableLayerBytes,
RootFilesystemBytes: rootFsBytes,
BlockIO: blockIO,
Mounts: mounts,
}
if a.runtime == RuntimePodman {
if meta := extractPodmanMetadata(labels); meta != nil {
container.Podman = meta
}
}
if requestSize {
a.logger.Debug().
Str("container", container.Name).
Int64("writableLayerBytes", writableLayerBytes).
Int64("rootFilesystemBytes", rootFsBytes).
Int("mountCount", len(mounts)).
Msg("Collected container disk metrics")
}
return container, nil
}
func extractPodmanMetadata(labels map[string]string) *agentsdocker.PodmanContainer {
if len(labels) == 0 {
return nil
}
meta := &agentsdocker.PodmanContainer{}
if v := strings.TrimSpace(labels["io.podman.annotations.pod.name"]); v != "" {
meta.PodName = v
}
if v := strings.TrimSpace(labels["io.podman.annotations.pod.id"]); v != "" {
meta.PodID = v
}
if v := strings.TrimSpace(labels["io.podman.annotations.pod.infra"]); v != "" {
if parsed, err := strconv.ParseBool(v); err == nil {
meta.Infra = parsed
} else if strings.EqualFold(v, "yes") || strings.EqualFold(v, "true") {
meta.Infra = true
}
}
if v := strings.TrimSpace(labels["io.podman.compose.project"]); v != "" {
meta.ComposeProject = v
}
if v := strings.TrimSpace(labels["io.podman.compose.service"]); v != "" {
meta.ComposeService = v
}
if v := strings.TrimSpace(labels["io.podman.compose.working_dir"]); v != "" {
meta.ComposeWorkdir = v
}
if v := strings.TrimSpace(labels["io.podman.compose.config-hash"]); v != "" {
meta.ComposeConfig = v
}
if v := strings.TrimSpace(labels["io.containers.autoupdate"]); v != "" {
meta.AutoUpdatePolicy = v
}
if v := strings.TrimSpace(labels["io.containers.autoupdate.restart"]); v != "" {
meta.AutoUpdateRestart = v
}
if v := strings.TrimSpace(labels["io.podman.annotations.userns"]); v != "" {
meta.UserNS = v
} else if v := strings.TrimSpace(labels["io.containers.userns"]); v != "" {
meta.UserNS = v
}
if meta.PodName == "" && meta.PodID == "" && meta.ComposeProject == "" && meta.AutoUpdatePolicy == "" && meta.UserNS == "" && !meta.Infra {
return nil
}
return meta
}
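// For context: containers that belong to a Podman pod typically carry the
// io.podman.annotations.pod.* labels, while podman-compose sets the
// io.podman.compose.* labels read above.
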
func (a *Agent) sendReport(ctx context.Context, report agentsdocker.Report) error {
payload, err := json.Marshal(report)
if err != nil {
return fmt.Errorf("marshal report: %w", err)
}
var errs []error
containerCount := len(report.Containers)
for _, target := range a.targets {
err := a.sendReportToTarget(ctx, target, payload, containerCount)
if err == nil {
continue
}
if errors.Is(err, ErrStopRequested) {
return ErrStopRequested
}
errs = append(errs, err)
}
if len(errs) > 0 {
return errors.Join(errs...)
}
a.logger.Debug().
Int("containers", containerCount).
Int("targets", len(a.targets)).
Msg("Report sent to Pulse targets")
return nil
}
func (a *Agent) sendReportToTarget(ctx context.Context, target TargetConfig, payload []byte, containerCount int) error {
url := fmt.Sprintf("%s/api/agents/docker/report", target.URL)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if err != nil {
return fmt.Errorf("target %s: create request: %w", target.URL, err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-API-Token", target.Token)
req.Header.Set("Authorization", "Bearer "+target.Token)
req.Header.Set("User-Agent", "pulse-docker-agent/"+Version)
client := a.httpClientFor(target)
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("target %s: send report: %w", target.URL, err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
bodyBytes, _ := io.ReadAll(resp.Body)
if hostRemoved := detectHostRemovedError(bodyBytes); hostRemoved != "" {
a.logger.Warn().
Str("hostID", a.hostID).
Str("pulseURL", target.URL).
Str("detail", hostRemoved).
Msg("Pulse rejected docker report because this host was previously removed. Allow the host to re-enroll from the Pulse UI or rerun the installer with a docker:manage token.")
return ErrStopRequested
}
errMsg := strings.TrimSpace(string(bodyBytes))
if errMsg == "" {
errMsg = resp.Status
}
return fmt.Errorf("target %s: pulse responded %s: %s", target.URL, resp.Status, errMsg)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("target %s: read response: %w", target.URL, err)
}
if len(body) == 0 {
return nil
}
var reportResp agentsdocker.ReportResponse
if err := json.Unmarshal(body, &reportResp); err != nil {
a.logger.Warn().Err(err).Str("target", target.URL).Msg("Failed to decode Pulse response")
return nil
}
for _, command := range reportResp.Commands {
err := a.handleCommand(ctx, target, command)
if err == nil {
continue
}
if errors.Is(err, ErrStopRequested) {
return ErrStopRequested
}
return err
}
return nil
}
func (a *Agent) handleCommand(ctx context.Context, target TargetConfig, command agentsdocker.Command) error {
switch strings.ToLower(command.Type) {
case agentsdocker.CommandTypeStop:
return a.handleStopCommand(ctx, target, command)
default:
a.logger.Warn().Str("command", command.Type).Msg("Received unsupported control command")
return nil
}
}
func (a *Agent) handleStopCommand(ctx context.Context, target TargetConfig, command agentsdocker.Command) error {
a.logger.Info().Str("commandID", command.ID).Msg("Received stop command from Pulse")
if err := a.disableSelf(ctx); err != nil {
a.logger.Error().Err(err).Msg("Failed to disable pulse-docker-agent service")
if ackErr := a.sendCommandAck(ctx, target, command.ID, agentsdocker.CommandStatusFailed, err.Error()); ackErr != nil {
a.logger.Error().Err(ackErr).Msg("Failed to send failure acknowledgement to Pulse")
}
return nil
}
if err := a.sendCommandAck(ctx, target, command.ID, agentsdocker.CommandStatusCompleted, "Agent shutting down"); err != nil {
return fmt.Errorf("send stop acknowledgement: %w", err)
}
a.logger.Info().Msg("Stop command acknowledged; terminating agent")
// After sending the acknowledgement, stop the systemd service to prevent restart.
// This is done after the ack to ensure the acknowledgement is sent before the
// process is terminated by systemctl stop.
go func() {
// Small delay to ensure the ack response completes
time.Sleep(1 * time.Second)
stopServiceCtx := context.Background()
if err := stopSystemdService(stopServiceCtx, "pulse-docker-agent"); err != nil {
a.logger.Warn().Err(err).Msg("Failed to stop systemd service, agent will exit normally")
}
}()
return ErrStopRequested
}
func (a *Agent) disableSelf(ctx context.Context) error {
if err := disableSystemdService(ctx, "pulse-docker-agent"); err != nil {
return err
}
// Remove Unraid startup script if present to prevent restart on reboot.
if err := removeFileIfExists("/boot/config/go.d/pulse-docker-agent.sh"); err != nil {
a.logger.Warn().Err(err).Msg("Failed to remove Unraid startup script")
}
// Best-effort log cleanup (ignore errors).
_ = removeFileIfExists("/var/log/pulse-docker-agent.log")
return nil
}
func disableSystemdService(ctx context.Context, service string) error {
if _, err := exec.LookPath("systemctl"); err != nil {
// Not a systemd environment; nothing to do.
return nil
}
cmd := exec.CommandContext(ctx, "systemctl", "disable", service)
output, err := cmd.CombinedOutput()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode := exitErr.ExitCode()
trimmedOutput := strings.TrimSpace(string(output))
lowerOutput := strings.ToLower(trimmedOutput)
if exitCode == 5 || strings.Contains(lowerOutput, "could not be found") || strings.Contains(lowerOutput, "not-found") {
return nil
}
if strings.Contains(lowerOutput, "access denied") || strings.Contains(lowerOutput, "permission denied") {
return fmt.Errorf("systemctl disable %s: access denied. Run 'sudo systemctl disable --now %s' or rerun the installer with sudo so it can install the polkit rule (systemctl output: %s)", service, service, trimmedOutput)
}
}
return fmt.Errorf("systemctl disable %s: %w (%s)", service, err, strings.TrimSpace(string(output)))
}
return nil
}
func stopSystemdService(ctx context.Context, service string) error {
if _, err := exec.LookPath("systemctl"); err != nil {
// Not a systemd environment; nothing to do.
return nil
}
// Stop the service to terminate the current running instance.
// This prevents systemd from restarting the service (services stopped via
// systemctl stop are not restarted even with Restart=always).
cmd := exec.CommandContext(ctx, "systemctl", "stop", service)
output, err := cmd.CombinedOutput()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode := exitErr.ExitCode()
trimmedOutput := strings.TrimSpace(string(output))
lowerOutput := strings.ToLower(trimmedOutput)
// Ignore "not found" errors since the service might already be stopped
if exitCode == 5 || strings.Contains(lowerOutput, "could not be found") || strings.Contains(lowerOutput, "not-found") {
return nil
}
if strings.Contains(lowerOutput, "access denied") || strings.Contains(lowerOutput, "permission denied") {
return fmt.Errorf("systemctl stop %s: access denied. Run 'sudo systemctl stop %s' or rerun the installer with sudo so it can install the polkit rule (systemctl output: %s)", service, service, trimmedOutput)
}
}
return fmt.Errorf("systemctl stop %s: %w (%s)", service, err, strings.TrimSpace(string(output)))
}
return nil
}
func removeFileIfExists(path string) error {
if err := os.Remove(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
return nil
}
func (a *Agent) sendCommandAck(ctx context.Context, target TargetConfig, commandID, status, message string) error {
if a.hostID == "" {
return fmt.Errorf("host identifier unavailable; cannot acknowledge command")
}
ackPayload := agentsdocker.CommandAck{
HostID: a.hostID,
Status: status,
Message: message,
}
body, err := json.Marshal(ackPayload)
if err != nil {
return fmt.Errorf("marshal command acknowledgement: %w", err)
}
url := fmt.Sprintf("%s/api/agents/docker/commands/%s/ack", target.URL, commandID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("create acknowledgement request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-API-Token", target.Token)
req.Header.Set("Authorization", "Bearer "+target.Token)
req.Header.Set("User-Agent", "pulse-docker-agent/"+Version)
resp, err := a.httpClientFor(target).Do(req)
if err != nil {
return fmt.Errorf("send acknowledgement: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
bodyBytes, _ := io.ReadAll(resp.Body)
return fmt.Errorf("pulse responded %s: %s", resp.Status, strings.TrimSpace(string(bodyBytes)))
}
return nil
}
func (a *Agent) primaryTarget() TargetConfig {
if len(a.targets) == 0 {
return TargetConfig{}
}
return a.targets[0]
}
func (a *Agent) httpClientFor(target TargetConfig) *http.Client {
if client, ok := a.httpClients[target.InsecureSkipVerify]; ok {
return client
}
if client, ok := a.httpClients[false]; ok {
return client
}
if client, ok := a.httpClients[true]; ok {
return client
}
return newHTTPClient(target.InsecureSkipVerify)
}
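// HTTP clients are cached per TLS-verification mode when the agent is
// constructed; the final newHTTPClient call is only a safety net and
// allocates a fresh client on each use.
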
func newHTTPClient(insecure bool) *http.Client {
transport := &http.Transport{}
if insecure {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec
}
return &http.Client{
Timeout: 15 * time.Second,
Transport: transport,
}
}
func summarizeBlockIO(stats containertypes.StatsResponse) *agentsdocker.ContainerBlockIO {
var readBytes, writeBytes uint64
for _, entry := range stats.BlkioStats.IoServiceBytesRecursive {
switch strings.ToLower(entry.Op) {
case "read":
readBytes += entry.Value
case "write":
writeBytes += entry.Value
}
}
if readBytes == 0 && writeBytes == 0 {
return nil
}
return &agentsdocker.ContainerBlockIO{
ReadBytes: readBytes,
WriteBytes: writeBytes,
}
}
func (a *Agent) calculateContainerCPUPercent(id string, stats containertypes.StatsResponse) float64 {
current := cpuSample{
totalUsage: stats.CPUStats.CPUUsage.TotalUsage,
systemUsage: stats.CPUStats.SystemUsage,
onlineCPUs: stats.CPUStats.OnlineCPUs,
read: stats.Read,
}
// Try to use PreCPUStats if available
percent := calculateCPUPercent(stats, a.cpuCount)
if percent > 0 {
a.prevContainerCPU[id] = current
a.logger.Debug().
Str("container_id", id[:12]).
Float64("cpu_percent", percent).
Msg("CPU calculated from PreCPUStats")
return percent
}
// PreCPUStats not available or invalid, use manual tracking
a.preCPUStatsFailures++
if a.preCPUStatsFailures == 10 {
a.logger.Warn().
Str("runtime", string(a.runtime)).
Msg("PreCPUStats consistently unavailable from Docker API - using manual CPU tracking (this is normal for one-shot stats)")
}
prev, ok := a.prevContainerCPU[id]
if !ok {
// First time seeing this container - store current sample and return 0
// On next collection cycle we'll have a previous sample to compare against
a.prevContainerCPU[id] = current
a.logger.Debug().
Str("container_id", id[:12]).
Uint64("total_usage", current.totalUsage).
Uint64("system_usage", current.systemUsage).
Msg("First CPU sample collected, no previous data for delta calculation")
return 0
}
// We have a previous sample - update it after calculation
a.prevContainerCPU[id] = current
var totalDelta float64
if current.totalUsage >= prev.totalUsage {
totalDelta = float64(current.totalUsage - prev.totalUsage)
} else {
// Counter likely reset (container restart); fall back to current reading.
totalDelta = float64(current.totalUsage)
}
if totalDelta <= 0 {
return 0
}
onlineCPUs := current.onlineCPUs
if onlineCPUs == 0 {
onlineCPUs = prev.onlineCPUs
}
if onlineCPUs == 0 && a.cpuCount > 0 {
onlineCPUs = uint32(a.cpuCount)
}
if onlineCPUs == 0 {
return 0
}
var systemDelta float64
if current.systemUsage >= prev.systemUsage {
systemDelta = float64(current.systemUsage - prev.systemUsage)
} else if current.systemUsage > 0 {
systemDelta = float64(current.systemUsage)
}
if systemDelta > 0 {
cpuPercent := safeFloat((totalDelta / systemDelta) * float64(onlineCPUs) * 100.0)
a.logger.Debug().
Str("container_id", id[:12]).
Float64("cpu_percent", cpuPercent).
Float64("total_delta", totalDelta).
Float64("system_delta", systemDelta).
Uint32("online_cpus", onlineCPUs).
Msg("CPU calculated from system delta")
return cpuPercent
}
// Fall back to time-based calculation
if !prev.read.IsZero() && !current.read.IsZero() {
elapsed := current.read.Sub(prev.read).Seconds()
if elapsed > 0 {
denominator := elapsed * float64(onlineCPUs) * 1e9
if denominator > 0 {
cpuPercent := (totalDelta / denominator) * 100.0
result := safeFloat(cpuPercent)
a.logger.Debug().
Str("container_id", id[:12]).
Float64("cpu_percent", result).
Float64("total_delta", totalDelta).
Float64("elapsed_seconds", elapsed).
Uint32("online_cpus", onlineCPUs).
Msg("CPU calculated from time-based delta")
return result
}
}
}
a.logger.Debug().
Str("container_id", id[:12]).
Float64("total_delta", totalDelta).
Float64("system_delta", systemDelta).
Bool("prev_read_zero", prev.read.IsZero()).
Bool("current_read_zero", current.read.IsZero()).
Msg("CPU calculation failed: no valid delta method available")
return 0
}
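// Fallback order above: PreCPUStats when the runtime populates it, then the
// system-usage delta between our own samples, then a wall-clock estimate;
// each step degrades to 0 rather than reporting a garbage spike.
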
func calculateCPUPercent(stats containertypes.StatsResponse, hostCPUs int) float64 {
// Guard against counter resets: the unsigned subtractions below would wrap
// to a huge bogus delta if the current reading is lower than the previous.
if stats.CPUStats.CPUUsage.TotalUsage < stats.PreCPUStats.CPUUsage.TotalUsage ||
stats.CPUStats.SystemUsage < stats.PreCPUStats.SystemUsage {
return 0
}
totalDelta := float64(stats.CPUStats.CPUUsage.TotalUsage - stats.PreCPUStats.CPUUsage.TotalUsage)
systemDelta := float64(stats.CPUStats.SystemUsage - stats.PreCPUStats.SystemUsage)
if totalDelta <= 0 || systemDelta <= 0 {
return 0
}
onlineCPUs := stats.CPUStats.OnlineCPUs
if onlineCPUs == 0 {
onlineCPUs = uint32(len(stats.CPUStats.CPUUsage.PercpuUsage))
}
if onlineCPUs == 0 && hostCPUs > 0 {
onlineCPUs = uint32(hostCPUs)
}
if onlineCPUs == 0 {
return 0
}
return safeFloat((totalDelta / systemDelta) * float64(onlineCPUs) * 100.0)
}
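// Worked example with made-up numbers: totalDelta = 2e8 ns of container CPU
// time over systemDelta = 4e9 ns of host CPU time with 4 online CPUs gives
// (2e8/4e9) * 4 * 100 = 20, i.e. the container used 20% of one CPU.
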
func calculateMemoryUsage(stats containertypes.StatsResponse) (usage int64, limit int64, percent float64) {
usage = int64(stats.MemoryStats.Usage)
if cache, ok := stats.MemoryStats.Stats["cache"]; ok {
usage -= int64(cache)
}
if usage < 0 {
usage = int64(stats.MemoryStats.Usage)
}
limit = int64(stats.MemoryStats.Limit)
if limit > 0 {
percent = (float64(usage) / float64(limit)) * 100.0
}
return usage, limit, safeFloat(percent)
}
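// The "cache" key is a cgroup v1 stat: there, usage includes the page cache,
// so subtracting it gives a figure closer to the container's real footprint.
// Example with made-up numbers: usage 512 MiB, cache 128 MiB, limit 1 GiB
// reports 384 MiB and 37.5%. cgroup v2 hosts typically expose no "cache"
// key, so raw usage is reported unchanged.
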
func safeFloat(val float64) float64 {
if math.IsNaN(val) || math.IsInf(val, 0) {
return 0
}
return val
}
func parseTime(value string) time.Time {
if value == "" || value == "0001-01-01T00:00:00Z" {
return time.Time{}
}
// RFC3339Nano parses both fractional and whole-second timestamps, so a
// single layout covers both cases.
if t, err := time.Parse(time.RFC3339Nano, value); err == nil {
return t
}
return time.Time{}
}
func trimLeadingSlash(names []string) string {
if len(names) == 0 {
return ""
}
name := names[0]
return strings.TrimPrefix(name, "/")
}
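// Close releases the underlying container runtime client.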
func (a *Agent) Close() error {
return a.docker.Close()
}
func readMachineID() (string, error) {
paths := []string{
"/etc/machine-id",
"/var/lib/dbus/machine-id",
}
for _, path := range paths {
data, err := os.ReadFile(path)
if err == nil {
return strings.TrimSpace(string(data)), nil
}
}
return "", errors.New("machine-id not found")
}
func readSystemUptime() int64 {
seconds, err := readProcUptime()
if err != nil {
return 0
}
return int64(seconds)
}
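// randomDuration returns a uniformly distributed duration in [0, max), using
// crypto/rand so concurrently started agents do not share a jitter seed; it
// degrades to 0 on error or a non-positive max.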
func randomDuration(max time.Duration) time.Duration {
if max <= 0 {
return 0
}
n, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
if err != nil {
return 0
}
return time.Duration(n.Int64())
}
func detectHostRemovedError(body []byte) string {
if len(body) == 0 {
return ""
}
var payload struct {
Error string `json:"error"`
Code string `json:"code"`
}
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
if strings.ToLower(payload.Code) != "invalid_report" {
return ""
}
if !strings.Contains(strings.ToLower(payload.Error), "was removed") {
return ""
}
return payload.Error
}
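// A payload that matches looks like (hypothetical wording, shape per the
// checks above):
//
//	{"error": "host was removed", "code": "invalid_report"}
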
// checkForUpdates checks if a newer version is available and performs self-update if needed
func (a *Agent) checkForUpdates(ctx context.Context) {
// Skip updates if disabled via config
if a.cfg.DisableAutoUpdate {
a.logger.Info().Msg("Skipping update check - auto-update disabled")
return
}
// Skip updates in development mode to prevent update loops
if Version == "dev" {
a.logger.Debug().Msg("Skipping update check - running in development mode")
return
}
a.logger.Debug().Msg("Checking for agent updates")
target := a.primaryTarget()
if target.URL == "" {
a.logger.Debug().Msg("Skipping update check - no Pulse target configured")
return
}
// Get current version from server
url := fmt.Sprintf("%s/api/agent/version", target.URL)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
a.logger.Warn().Err(err).Msg("Failed to create version check request")
return
}
if target.Token != "" {
req.Header.Set("X-API-Token", target.Token)
req.Header.Set("Authorization", "Bearer "+target.Token)
}
client := a.httpClientFor(target)
resp, err := client.Do(req)
if err != nil {
a.logger.Warn().Err(err).Msg("Failed to check for updates")
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
a.logger.Warn().Int("status", resp.StatusCode).Msg("Version endpoint returned non-200 status")
return
}
var versionResp struct {
Version string `json:"version"`
}
if err := json.NewDecoder(resp.Body).Decode(&versionResp); err != nil {
a.logger.Warn().Err(err).Msg("Failed to decode version response")
return
}
// Skip updates if server is also in development mode
if versionResp.Version == "dev" {
a.logger.Debug().Msg("Skipping update - server is in development mode")
return
}
// Compare versions
if versionResp.Version == Version {
a.logger.Debug().Str("version", Version).Msg("Agent is up to date")
return
}
a.logger.Info().
Str("currentVersion", Version).
Str("availableVersion", versionResp.Version).
Msg("New agent version available, performing self-update")
// Perform self-update
if err := a.selfUpdate(ctx); err != nil {
a.logger.Error().Err(err).Msg("Failed to self-update agent")
return
}
a.logger.Info().Msg("Agent updated successfully, restarting...")
}
func determineSelfUpdateArch() string {
switch runtime.GOARCH {
case "amd64":
return "linux-amd64"
case "arm64":
return "linux-arm64"
case "arm":
return "linux-armv7"
}
out, err := exec.Command("uname", "-m").Output()
if err != nil {
return ""
}
normalized := strings.ToLower(strings.TrimSpace(string(out)))
switch normalized {
case "x86_64", "amd64":
return "linux-amd64"
case "aarch64", "arm64":
return "linux-arm64"
case "armv7l", "armhf", "armv7":
return "linux-armv7"
default:
return ""
}
}
// selfUpdate downloads the new agent binary and replaces the current one
func (a *Agent) selfUpdate(ctx context.Context) error {
target := a.primaryTarget()
if target.URL == "" {
return errors.New("no Pulse target configured for self-update")
}
// Get path to current executable
execPath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
downloadBase := strings.TrimRight(target.URL, "/") + "/download/pulse-docker-agent"
archParam := determineSelfUpdateArch()
type downloadCandidate struct {
url string
arch string
}
candidates := make([]downloadCandidate, 0, 2)
if archParam != "" {
candidates = append(candidates, downloadCandidate{
url: fmt.Sprintf("%s?arch=%s", downloadBase, archParam),
arch: archParam,
})
}
candidates = append(candidates, downloadCandidate{url: downloadBase})
client := a.httpClientFor(target)
var resp *http.Response
var lastErr error
for _, candidate := range candidates {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, candidate.url, nil)
if err != nil {
lastErr = fmt.Errorf("failed to create download request: %w", err)
continue
}
if target.Token != "" {
req.Header.Set("X-API-Token", target.Token)
req.Header.Set("Authorization", "Bearer "+target.Token)
}
response, err := client.Do(req)
if err != nil {
lastErr = fmt.Errorf("failed to download new binary: %w", err)
continue
}
if response.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("download failed with status: %s", response.Status)
response.Body.Close()
continue
}
resp = response
if candidate.arch != "" {
a.logger.Debug().
Str("arch", candidate.arch).
Msg("Self-update: downloaded architecture-specific agent binary")
} else if archParam != "" {
a.logger.Debug().Msg("Self-update: falling back to server default agent binary")
}
break
}
if resp == nil {
if lastErr == nil {
lastErr = errors.New("failed to download new binary")
}
return lastErr
}
defer resp.Body.Close()
checksumHeader := strings.TrimSpace(resp.Header.Get("X-Checksum-Sha256"))
// Create the temporary file next to the executable so the rename below stays
// on the same filesystem and cannot fail with a cross-device link error.
tmpFile, err := os.CreateTemp(filepath.Dir(execPath), "pulse-docker-agent-*.tmp")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath) // Clean up if something goes wrong
// Write downloaded binary to temp file
hasher := sha256.New()
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to write downloaded binary: %w", err)
}
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}
downloadChecksum := hex.EncodeToString(hasher.Sum(nil))
if checksumHeader != "" {
expected := strings.ToLower(strings.TrimSpace(checksumHeader))
actual := strings.ToLower(downloadChecksum)
if expected != actual {
return fmt.Errorf("checksum verification failed: expected %s, got %s", expected, actual)
}
a.logger.Debug().Str("checksum", downloadChecksum).Msg("Self-update: checksum verified")
} else {
a.logger.Warn().Msg("Self-update: checksum header missing; skipping verification")
}
// Make temp file executable
if err := os.Chmod(tmpPath, 0755); err != nil {
return fmt.Errorf("failed to make temp file executable: %w", err)
}
// Create backup of current binary
backupPath := execPath + ".backup"
if err := os.Rename(execPath, backupPath); err != nil {
return fmt.Errorf("failed to backup current binary: %w", err)
}
// Move new binary to current location
if err := os.Rename(tmpPath, execPath); err != nil {
// Restore backup on failure
os.Rename(backupPath, execPath)
return fmt.Errorf("failed to replace binary: %w", err)
}
// Remove backup on success
os.Remove(backupPath)
// Restart agent with same arguments
args := os.Args
env := os.Environ()
if err := syscall.Exec(execPath, args, env); err != nil {
return fmt.Errorf("failed to restart agent: %w", err)
}
return nil
}
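// The replacement sequence above is the usual self-update dance: download to
// a temp file, verify the checksum, chmod it executable, rename the running
// binary aside, rename the new one into place (restoring the backup if that
// fails), then re-exec with the original argv and environment so the process
// is replaced in place without changing its PID.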