mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-19 07:50:43 +01:00
syscall.Exec is not supported on Windows, causing self-update to fail with "failed to restart: not supported by windows". Split the restart logic into platform-specific files: restart_unix.go uses syscall.Exec for in-place process replacement; restart_windows.go uses os.Exit(0) so that the Windows Service Control Manager (SCM) restarts the service. Related to #735
473 lines
13 KiB
Go
473 lines
13 KiB
Go
// Package agentupdate provides self-update functionality for Pulse agents.
|
|
// It handles checking for new versions, downloading binaries, and performing
|
|
// atomic binary replacement with rollback support.
|
|
package agentupdate
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
// Download safety limits used by the updater.
const (
	// maxBinarySize caps how many bytes of a downloaded agent binary are
	// accepted: 100 MiB. Anything larger is rejected.
	maxBinarySize = 100 << 20

	// downloadTimeout is the overall timeout of the updater's HTTP client,
	// so it bounds both version checks and binary downloads.
	downloadTimeout = 5 * time.Minute
)
|
|
|
|
// Config holds the configuration for the updater.
//
// Pass it to New, which fills in defaults (a one-hour CheckInterval when the
// field is zero, a no-op logger when Logger is nil).
type Config struct {
	// PulseURL is the base URL of the Pulse server (e.g., "https://pulse.example.com:7655").
	// Trailing slashes are trimmed before paths are appended. When empty,
	// CheckAndUpdate skips the update check.
	PulseURL string

	// APIToken is the authentication token for the Pulse server. When
	// non-empty it is sent both as the X-API-Token header and as an
	// "Authorization: Bearer" header on every updater request.
	APIToken string

	// AgentName is the name of the agent binary to download (e.g., "pulse-agent", "pulse-docker-agent").
	// It forms the download path (/download/<AgentName>), the temp-file
	// prefix and the Unraid persistence path.
	AgentName string

	// CurrentVersion is the version currently running. The value "dev"
	// disables updates. It is compared with the server's version after
	// utils.NormalizeVersion, so a leading "v" does not cause a mismatch.
	CurrentVersion string

	// CheckInterval is how often RunLoop checks for updates (default: 1 hour
	// when left at zero). It should be positive: RunLoop passes it to
	// time.NewTicker.
	CheckInterval time.Duration

	// InsecureSkipVerify skips TLS certificate verification for all updater
	// requests. TLS 1.2 is still required as the minimum version.
	InsecureSkipVerify bool

	// Logger is the zerolog logger instance. A nil Logger discards all
	// updater log output.
	Logger *zerolog.Logger

	// Disabled skips all update checks when true: RunLoop and CheckAndUpdate
	// return immediately.
	Disabled bool
}
|
|
|
|
// Updater handles automatic updates for Pulse agents.
//
// It is intended to be constructed with New. The zero value has a nil HTTP
// client and cannot perform requests.
type Updater struct {
	cfg    Config         // configuration, with defaults applied by New
	client *http.Client   // shared by version checks and binary downloads
	logger zerolog.Logger // zerolog.Nop() when Config.Logger is nil
}
|
|
|
|
// New creates a new Updater with the given configuration.
|
|
func New(cfg Config) *Updater {
|
|
if cfg.CheckInterval == 0 {
|
|
cfg.CheckInterval = 1 * time.Hour
|
|
}
|
|
|
|
logger := zerolog.Nop()
|
|
if cfg.Logger != nil {
|
|
logger = *cfg.Logger
|
|
}
|
|
|
|
transport := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
InsecureSkipVerify: cfg.InsecureSkipVerify, //nolint:gosec
|
|
},
|
|
}
|
|
|
|
return &Updater{
|
|
cfg: cfg,
|
|
client: &http.Client{
|
|
Transport: transport,
|
|
Timeout: downloadTimeout,
|
|
},
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// RunLoop starts the update check loop. It blocks until the context is cancelled.
|
|
func (u *Updater) RunLoop(ctx context.Context) {
|
|
if u.cfg.Disabled {
|
|
u.logger.Info().Msg("Auto-update disabled")
|
|
return
|
|
}
|
|
|
|
if u.cfg.CurrentVersion == "dev" {
|
|
u.logger.Debug().Msg("Auto-update disabled in development mode")
|
|
return
|
|
}
|
|
|
|
// Initial check after a short delay
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(30 * time.Second):
|
|
u.CheckAndUpdate(ctx)
|
|
}
|
|
|
|
ticker := time.NewTicker(u.cfg.CheckInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
u.CheckAndUpdate(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
// CheckAndUpdate checks for a new version and performs the update if available.
|
|
func (u *Updater) CheckAndUpdate(ctx context.Context) {
|
|
if u.cfg.Disabled {
|
|
return
|
|
}
|
|
|
|
if u.cfg.CurrentVersion == "dev" {
|
|
u.logger.Debug().Msg("Skipping update check - running in development mode")
|
|
return
|
|
}
|
|
|
|
if u.cfg.PulseURL == "" {
|
|
u.logger.Debug().Msg("Skipping update check - no Pulse URL configured")
|
|
return
|
|
}
|
|
|
|
u.logger.Debug().Msg("Checking for agent updates")
|
|
|
|
serverVersion, err := u.getServerVersion(ctx)
|
|
if err != nil {
|
|
u.logger.Warn().Err(err).Msg("Failed to check for updates")
|
|
return
|
|
}
|
|
|
|
if serverVersion == "dev" {
|
|
u.logger.Debug().Msg("Skipping update - server is in development mode")
|
|
return
|
|
}
|
|
|
|
// Normalize both versions by stripping "v" prefix for comparison.
|
|
// Server returns version without prefix (e.g., "4.33.1"), but agent's
|
|
// CurrentVersion may include it (e.g., "v4.33.1") depending on build.
|
|
if utils.NormalizeVersion(serverVersion) == utils.NormalizeVersion(u.cfg.CurrentVersion) {
|
|
u.logger.Debug().Str("version", u.cfg.CurrentVersion).Msg("Agent is up to date")
|
|
return
|
|
}
|
|
|
|
u.logger.Info().
|
|
Str("currentVersion", u.cfg.CurrentVersion).
|
|
Str("availableVersion", serverVersion).
|
|
Msg("New agent version available, performing self-update")
|
|
|
|
if err := u.performUpdate(ctx); err != nil {
|
|
u.logger.Error().Err(err).Msg("Failed to self-update agent")
|
|
return
|
|
}
|
|
|
|
u.logger.Info().Msg("Agent updated successfully, restarting...")
|
|
}
|
|
|
|
// getServerVersion fetches the current version from the Pulse server.
|
|
func (u *Updater) getServerVersion(ctx context.Context) (string, error) {
|
|
url := fmt.Sprintf("%s/api/agent/version", strings.TrimRight(u.cfg.PulseURL, "/"))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
if u.cfg.APIToken != "" {
|
|
req.Header.Set("X-API-Token", u.cfg.APIToken)
|
|
req.Header.Set("Authorization", "Bearer "+u.cfg.APIToken)
|
|
}
|
|
|
|
resp, err := u.client.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("server returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
var versionResp struct {
|
|
Version string `json:"version"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&versionResp); err != nil {
|
|
return "", fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
return versionResp.Version, nil
|
|
}
|
|
|
|
// isUnraid reports whether the host appears to be Unraid, i.e. whether
// /etc/unraid-version can be stat'ed. Any Stat error, including a permission
// error, is treated as "not Unraid".
func isUnraid() bool {
	if _, statErr := os.Stat("/etc/unraid-version"); statErr != nil {
		return false
	}
	return true
}
|
|
|
|
// verifyBinaryMagic is a cheap sanity check that the file at path begins with
// the executable magic number expected for runtime.GOOS:
//
//   - linux: ELF, bytes 7f 45 4c 46 ("\x7fELF").
//   - darwin: one of the following:
//     Mach-O 64-bit (cf fa ed fe) or 32-bit (ce fa ed fe), which are the
//     little-endian encodings of 0xfeedfacf / 0xfeedface;
//     universal/fat binary (ca fe ba be), which is stored big-endian.
//   - windows: PE/DOS header starting with "MZ".
//
// Other operating systems are accepted without inspection. A file shorter
// than four bytes is always rejected because the header read fails.
func verifyBinaryMagic(path string) error {
	file, openErr := os.Open(path)
	if openErr != nil {
		return openErr
	}
	defer file.Close()

	var header [4]byte
	if _, readErr := io.ReadFull(file, header[:]); readErr != nil {
		return fmt.Errorf("failed to read magic bytes: %w", readErr)
	}
	magic := string(header[:])

	switch runtime.GOOS {
	case "linux":
		if magic == "\x7fELF" {
			return nil
		}
		return errors.New("not a valid ELF binary")
	case "darwin":
		switch magic {
		case "\xcf\xfa\xed\xfe", // 64-bit
			"\xce\xfa\xed\xfe", // 32-bit
			"\xca\xfe\xba\xbe": // universal
			return nil
		}
		return errors.New("not a valid Mach-O binary")
	case "windows":
		if magic[:2] == "MZ" {
			return nil
		}
		return errors.New("not a valid PE binary")
	default:
		return nil
	}
}
|
|
|
|
// unraidPersistentPath returns where the persistent copy of an agent binary
// lives on Unraid's flash drive:
//
//	/boot/config/plugins/<agentName>/<agentName>
//
// agentName is inserted verbatim; the path is not cleaned.
func unraidPersistentPath(agentName string) string {
	const pluginsRoot = "/boot/config/plugins/"
	return pluginsRoot + agentName + "/" + agentName
}
|
|
|
|
// performUpdate downloads and installs the new agent binary.
|
|
func (u *Updater) performUpdate(ctx context.Context) error {
|
|
execPath, err := os.Executable()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get executable path: %w", err)
|
|
}
|
|
|
|
// Build download URL
|
|
downloadBase := fmt.Sprintf("%s/download/%s", strings.TrimRight(u.cfg.PulseURL, "/"), u.cfg.AgentName)
|
|
archParam := determineArch()
|
|
|
|
// Try architecture-specific binary first, then fall back to default
|
|
candidates := []string{}
|
|
if archParam != "" {
|
|
candidates = append(candidates, fmt.Sprintf("%s?arch=%s", downloadBase, archParam))
|
|
}
|
|
candidates = append(candidates, downloadBase)
|
|
|
|
var resp *http.Response
|
|
var lastErr error
|
|
|
|
for _, url := range candidates {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("failed to create download request: %w", err)
|
|
continue
|
|
}
|
|
|
|
if u.cfg.APIToken != "" {
|
|
req.Header.Set("X-API-Token", u.cfg.APIToken)
|
|
req.Header.Set("Authorization", "Bearer "+u.cfg.APIToken)
|
|
}
|
|
|
|
response, err := u.client.Do(req)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("failed to download binary: %w", err)
|
|
continue
|
|
}
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
lastErr = fmt.Errorf("download failed with status: %s", response.Status)
|
|
response.Body.Close()
|
|
continue
|
|
}
|
|
|
|
resp = response
|
|
u.logger.Debug().Str("url", url).Msg("Downloaded agent binary")
|
|
break
|
|
}
|
|
|
|
if resp == nil {
|
|
if lastErr == nil {
|
|
lastErr = errors.New("failed to download binary")
|
|
}
|
|
return lastErr
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Verify checksum if provided
|
|
checksumHeader := strings.TrimSpace(resp.Header.Get("X-Checksum-Sha256"))
|
|
|
|
// Resolve symlinks to get the real path for atomic rename
|
|
realExecPath, err := filepath.EvalSymlinks(execPath)
|
|
if err != nil {
|
|
// Fall back to original path if symlink resolution fails
|
|
realExecPath = execPath
|
|
}
|
|
|
|
// Create temporary file in the same directory as the target binary
|
|
// to ensure atomic rename works (os.Rename fails across filesystems)
|
|
targetDir := filepath.Dir(realExecPath)
|
|
tmpFile, err := os.CreateTemp(targetDir, u.cfg.AgentName+"-*.tmp")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create temp file: %w", err)
|
|
}
|
|
tmpPath := tmpFile.Name()
|
|
defer os.Remove(tmpPath) // Clean up on failure
|
|
|
|
// Write downloaded binary with checksum calculation and size limit
|
|
hasher := sha256.New()
|
|
limitedReader := io.LimitReader(resp.Body, maxBinarySize+1) // +1 to detect overflow
|
|
written, err := io.Copy(tmpFile, io.TeeReader(limitedReader, hasher))
|
|
if err != nil {
|
|
tmpFile.Close()
|
|
return fmt.Errorf("failed to write binary: %w", err)
|
|
}
|
|
if written > maxBinarySize {
|
|
tmpFile.Close()
|
|
return fmt.Errorf("downloaded binary exceeds maximum size (%d bytes)", maxBinarySize)
|
|
}
|
|
if err := tmpFile.Close(); err != nil {
|
|
return fmt.Errorf("failed to close temp file: %w", err)
|
|
}
|
|
|
|
// Verify it's a valid executable (basic sanity check)
|
|
if err := verifyBinaryMagic(tmpPath); err != nil {
|
|
return fmt.Errorf("downloaded file is not a valid executable: %w", err)
|
|
}
|
|
|
|
// Verify checksum (mandatory for security)
|
|
downloadChecksum := hex.EncodeToString(hasher.Sum(nil))
|
|
if checksumHeader == "" {
|
|
return fmt.Errorf("server did not provide checksum header (X-Checksum-Sha256); refusing update for security")
|
|
}
|
|
|
|
expected := strings.ToLower(strings.TrimSpace(checksumHeader))
|
|
actual := strings.ToLower(downloadChecksum)
|
|
if expected != actual {
|
|
return fmt.Errorf("checksum mismatch: expected %s, got %s", expected, actual)
|
|
}
|
|
u.logger.Debug().Str("checksum", downloadChecksum).Msg("Checksum verified")
|
|
|
|
// Make executable
|
|
if err := os.Chmod(tmpPath, 0755); err != nil {
|
|
return fmt.Errorf("failed to chmod: %w", err)
|
|
}
|
|
|
|
// Atomic replacement with backup (use realExecPath for rename operations)
|
|
backupPath := realExecPath + ".backup"
|
|
if err := os.Rename(realExecPath, backupPath); err != nil {
|
|
return fmt.Errorf("failed to backup current binary: %w", err)
|
|
}
|
|
|
|
if err := os.Rename(tmpPath, realExecPath); err != nil {
|
|
// Restore backup on failure
|
|
os.Rename(backupPath, realExecPath)
|
|
return fmt.Errorf("failed to replace binary: %w", err)
|
|
}
|
|
|
|
// Remove backup on success
|
|
os.Remove(backupPath)
|
|
|
|
// On Unraid, also update the persistent copy on the flash drive
|
|
// This ensures the update survives reboots
|
|
if isUnraid() {
|
|
persistPath := unraidPersistentPath(u.cfg.AgentName)
|
|
if _, err := os.Stat(persistPath); err == nil {
|
|
// Persistent path exists, update it
|
|
u.logger.Debug().Str("path", persistPath).Msg("Updating Unraid persistent binary")
|
|
|
|
// Read the newly installed binary
|
|
newBinary, err := os.ReadFile(execPath)
|
|
if err != nil {
|
|
u.logger.Warn().Err(err).Msg("Failed to read new binary for Unraid persistence")
|
|
} else {
|
|
// Write to persistent storage (atomic via temp file)
|
|
tmpPersist := persistPath + ".tmp"
|
|
if err := os.WriteFile(tmpPersist, newBinary, 0644); err != nil {
|
|
u.logger.Warn().Err(err).Msg("Failed to write Unraid persistent binary")
|
|
} else if err := os.Rename(tmpPersist, persistPath); err != nil {
|
|
u.logger.Warn().Err(err).Msg("Failed to rename Unraid persistent binary")
|
|
os.Remove(tmpPersist)
|
|
} else {
|
|
u.logger.Info().Str("path", persistPath).Msg("Updated Unraid persistent binary")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Restart the process using platform-specific implementation
|
|
return restartProcess(execPath)
|
|
}
|
|
|
|
// determineArch returns the "<os>-<arch>" string sent as the ?arch= query
// parameter on download URLs (e.g. "linux-amd64", "darwin-arm64",
// "windows-386"). GOARCH "arm" is reported as "armv7".
//
// For operating systems other than linux, darwin and windows, it falls back
// to `uname -m` and maps the machine name to a linux-* string. An empty
// result means the architecture is unknown; performUpdate then tries only
// the generic download URL.
//
// NOTE(review): the uname fallback always yields a linux-* value, even on
// other operating systems (e.g. FreeBSD). Confirm this is intended.
func determineArch() string {
	// Named goos/goarch so the os package is not shadowed.
	goos, goarch := runtime.GOOS, runtime.GOARCH
	if goarch == "arm" {
		goarch = "armv7"
	}

	switch goos {
	case "linux", "darwin", "windows":
		return goos + "-" + goarch
	}

	// Fall back to uname for edge cases on unknown OS
	out, err := exec.Command("uname", "-m").Output()
	if err != nil {
		return ""
	}

	switch strings.ToLower(strings.TrimSpace(string(out))) {
	case "x86_64", "amd64":
		return "linux-amd64"
	case "aarch64", "arm64":
		return "linux-arm64"
	case "armv7l", "armhf", "armv7":
		return "linux-armv7"
	default:
		return ""
	}
}
|