mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-18 00:17:39 +01:00
1034 lines
28 KiB
Go
1034 lines
28 KiB
Go
package agentupdate
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
|
|
|
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
|
return f(r)
|
|
}
|
|
|
|
func testBinary() []byte {
|
|
return []byte{0x7f, 'E', 'L', 'F', 0x01, 0x02, 0x03, 0x04}
|
|
}
|
|
|
|
func checksum(data []byte) string {
|
|
sum := sha256.Sum256(data)
|
|
return hex.EncodeToString(sum[:])
|
|
}
|
|
|
|
func writeTempExec(t *testing.T) (string, string) {
|
|
t.Helper()
|
|
|
|
dir := t.TempDir()
|
|
execPath := filepath.Join(dir, "agent")
|
|
if err := os.WriteFile(execPath, []byte("old-binary"), 0755); err != nil {
|
|
t.Fatalf("write exec: %v", err)
|
|
}
|
|
return dir, execPath
|
|
}
|
|
|
|
func newUpdaterForTest(serverURL string) *Updater {
|
|
cfg := Config{
|
|
PulseURL: serverURL,
|
|
AgentName: "pulse-agent",
|
|
CurrentVersion: "1.0.0",
|
|
CheckInterval: 10 * time.Millisecond,
|
|
}
|
|
return New(cfg)
|
|
}
|
|
|
|
func TestRestartProcess(t *testing.T) {
|
|
orig := execFn
|
|
t.Cleanup(func() { execFn = orig })
|
|
|
|
execFn = func(string, []string, []string) error { return nil }
|
|
if err := restartProcess("/bin/true"); err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
|
|
execFn = func(string, []string, []string) error { return errors.New("boom") }
|
|
if err := restartProcess("/bin/false"); err == nil {
|
|
t.Fatalf("expected error")
|
|
}
|
|
}
|
|
|
|
func TestDetermineArchOverrides(t *testing.T) {
|
|
origOS, origArch, origUname := runtimeGOOS, runtimeGOARCH, unameCommand
|
|
t.Cleanup(func() {
|
|
runtimeGOOS = origOS
|
|
runtimeGOARCH = origArch
|
|
unameCommand = origUname
|
|
})
|
|
|
|
runtimeGOOS = "linux"
|
|
runtimeGOARCH = "arm"
|
|
if got := determineArch(); got != "linux-armv7" {
|
|
t.Fatalf("expected linux-armv7, got %q", got)
|
|
}
|
|
|
|
runtimeGOARCH = "386"
|
|
if got := determineArch(); got != "linux-386" {
|
|
t.Fatalf("expected linux-386, got %q", got)
|
|
}
|
|
|
|
runtimeGOOS = "solaris"
|
|
runtimeGOARCH = "amd64"
|
|
unameCommand = func() ([]byte, error) { return []byte("aarch64"), nil }
|
|
if got := determineArch(); got != "linux-arm64" {
|
|
t.Fatalf("expected linux-arm64, got %q", got)
|
|
}
|
|
|
|
unameCommand = func() ([]byte, error) { return []byte("x86_64"), nil }
|
|
if got := determineArch(); got != "linux-amd64" {
|
|
t.Fatalf("expected linux-amd64, got %q", got)
|
|
}
|
|
|
|
unameCommand = func() ([]byte, error) { return []byte("armv7l"), nil }
|
|
if got := determineArch(); got != "linux-armv7" {
|
|
t.Fatalf("expected linux-armv7, got %q", got)
|
|
}
|
|
|
|
unameCommand = func() ([]byte, error) { return []byte("mips"), nil }
|
|
if got := determineArch(); got != "" {
|
|
t.Fatalf("expected empty arch for unknown uname, got %q", got)
|
|
}
|
|
|
|
unameCommand = func() ([]byte, error) { return nil, errors.New("fail") }
|
|
if got := determineArch(); got != "" {
|
|
t.Fatalf("expected empty arch on uname error, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestUnameCommandDefault(t *testing.T) {
|
|
if _, err := exec.LookPath("uname"); err != nil {
|
|
t.Skip("uname not available")
|
|
}
|
|
|
|
orig := unameCommand
|
|
t.Cleanup(func() { unameCommand = orig })
|
|
unameCommand = orig
|
|
|
|
if _, err := unameCommand(); err != nil {
|
|
t.Fatalf("expected uname to run, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestVerifyBinaryMagicOverrides(t *testing.T) {
|
|
origOS := runtimeGOOS
|
|
t.Cleanup(func() { runtimeGOOS = origOS })
|
|
|
|
tmpDir := t.TempDir()
|
|
|
|
runtimeGOOS = "darwin"
|
|
machoPath := filepath.Join(tmpDir, "macho")
|
|
machoData := []byte{0xcf, 0xfa, 0xed, 0xfe, 0x00}
|
|
if err := os.WriteFile(machoPath, machoData, 0644); err != nil {
|
|
t.Fatalf("write macho: %v", err)
|
|
}
|
|
if err := verifyBinaryMagic(machoPath); err != nil {
|
|
t.Fatalf("expected macho to validate, got %v", err)
|
|
}
|
|
|
|
badPath := filepath.Join(tmpDir, "macho-bad")
|
|
if err := os.WriteFile(badPath, []byte{0x00, 0x00, 0x00, 0x00}, 0644); err != nil {
|
|
t.Fatalf("write bad macho: %v", err)
|
|
}
|
|
if err := verifyBinaryMagic(badPath); err == nil {
|
|
t.Fatalf("expected macho error")
|
|
}
|
|
|
|
runtimeGOOS = "windows"
|
|
pePath := filepath.Join(tmpDir, "pe.exe")
|
|
if err := os.WriteFile(pePath, []byte{'M', 'Z', 0x00, 0x00}, 0644); err != nil {
|
|
t.Fatalf("write pe: %v", err)
|
|
}
|
|
if err := verifyBinaryMagic(pePath); err != nil {
|
|
t.Fatalf("expected PE to validate, got %v", err)
|
|
}
|
|
|
|
badPEPath := filepath.Join(tmpDir, "bad-pe.exe")
|
|
if err := os.WriteFile(badPEPath, []byte{0x00, 0x00, 0x00, 0x00}, 0644); err != nil {
|
|
t.Fatalf("write bad pe: %v", err)
|
|
}
|
|
if err := verifyBinaryMagic(badPEPath); err == nil {
|
|
t.Fatalf("expected PE error")
|
|
}
|
|
|
|
runtimeGOOS = "plan9"
|
|
planPath := filepath.Join(tmpDir, "plan9")
|
|
if err := os.WriteFile(planPath, []byte{0x00, 0x01, 0x02, 0x03}, 0644); err != nil {
|
|
t.Fatalf("write plan9: %v", err)
|
|
}
|
|
if err := verifyBinaryMagic(planPath); err != nil {
|
|
t.Fatalf("expected unknown OS to skip verification, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestIsUnraidOverride(t *testing.T) {
|
|
orig := unraidVersionPath
|
|
t.Cleanup(func() { unraidVersionPath = orig })
|
|
|
|
tmpDir := t.TempDir()
|
|
unraidVersionPath = filepath.Join(tmpDir, "unraid-version")
|
|
if isUnraid() {
|
|
t.Fatalf("expected false when file missing")
|
|
}
|
|
if err := os.WriteFile(unraidVersionPath, []byte("6.12"), 0644); err != nil {
|
|
t.Fatalf("write unraid: %v", err)
|
|
}
|
|
if !isUnraid() {
|
|
t.Fatalf("expected true when file exists")
|
|
}
|
|
}
|
|
|
|
func TestGetServerVersion(t *testing.T) {
|
|
t.Run("Success", func(t *testing.T) {
|
|
var sawToken bool
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("X-API-Token") == "token" && strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
|
sawToken = true
|
|
}
|
|
_, _ = w.Write([]byte(`{"version":"1.2.3"}`))
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.cfg.APIToken = "token"
|
|
u.client = server.Client()
|
|
|
|
version, err := u.getServerVersion(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("getServerVersion error: %v", err)
|
|
}
|
|
if version != "1.2.3" {
|
|
t.Fatalf("expected version 1.2.3, got %q", version)
|
|
}
|
|
if !sawToken {
|
|
t.Fatalf("expected token headers to be set")
|
|
}
|
|
})
|
|
|
|
t.Run("InvalidURL", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://[::1")
|
|
if _, err := u.getServerVersion(context.Background()); err == nil {
|
|
t.Fatalf("expected error for invalid URL")
|
|
}
|
|
})
|
|
|
|
t.Run("RequestError", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
|
return nil, errors.New("boom")
|
|
}),
|
|
}
|
|
if _, err := u.getServerVersion(context.Background()); err == nil {
|
|
t.Fatalf("expected request error")
|
|
}
|
|
})
|
|
|
|
t.Run("StatusError", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
if _, err := u.getServerVersion(context.Background()); err == nil {
|
|
t.Fatalf("expected status error")
|
|
}
|
|
})
|
|
|
|
t.Run("DecodeError", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte("{"))
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
if _, err := u.getServerVersion(context.Background()); err == nil {
|
|
t.Fatalf("expected decode error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCheckAndUpdateBranches(t *testing.T) {
|
|
t.Run("Disabled", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
u.cfg.Disabled = true
|
|
u.performUpdateFn = func(context.Context) error {
|
|
t.Fatalf("should not update when disabled")
|
|
return nil
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
})
|
|
|
|
t.Run("DevCurrent", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
u.cfg.CurrentVersion = "dev"
|
|
u.performUpdateFn = func(context.Context) error {
|
|
t.Fatalf("should not update in dev mode")
|
|
return nil
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
})
|
|
|
|
t.Run("NoPulseURL", func(t *testing.T) {
|
|
u := newUpdaterForTest("")
|
|
u.performUpdateFn = func(context.Context) error {
|
|
t.Fatalf("should not update without URL")
|
|
return nil
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
})
|
|
|
|
t.Run("ServerError", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
u.performUpdateFn = func(context.Context) error {
|
|
t.Fatalf("should not update on server error")
|
|
return nil
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
})
|
|
|
|
t.Run("ServerDev", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"version":"dev"}`))
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
u.performUpdateFn = func(context.Context) error {
|
|
t.Fatalf("should not update when server dev")
|
|
return nil
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
})
|
|
|
|
t.Run("UpToDate", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"version":"1.0.0"}`))
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
u.performUpdateFn = func(context.Context) error {
|
|
t.Fatalf("should not update when up to date")
|
|
return nil
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
})
|
|
|
|
t.Run("ServerOlder", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"version":"0.9.0"}`))
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
u.performUpdateFn = func(context.Context) error {
|
|
t.Fatalf("should not downgrade")
|
|
return nil
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
})
|
|
|
|
t.Run("ServerNewer", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"version":"1.1.0"}`))
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
var called int32
|
|
u.performUpdateFn = func(context.Context) error {
|
|
atomic.AddInt32(&called, 1)
|
|
return nil
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
if called != 1 {
|
|
t.Fatalf("expected update to be called")
|
|
}
|
|
})
|
|
|
|
t.Run("UpdateError", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"version":"1.1.0"}`))
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
var called int32
|
|
u.performUpdateFn = func(context.Context) error {
|
|
atomic.AddInt32(&called, 1)
|
|
return errors.New("fail")
|
|
}
|
|
u.CheckAndUpdate(context.Background())
|
|
if called != 1 {
|
|
t.Fatalf("expected update to be called")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestRunLoop(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"version":"1.1.0"}`))
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
u.initialDelay = 0
|
|
u.newTicker = func(d time.Duration) *time.Ticker {
|
|
return time.NewTicker(5 * time.Millisecond)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
var called int32
|
|
u.performUpdateFn = func(context.Context) error {
|
|
if atomic.AddInt32(&called, 1) >= 2 {
|
|
cancel()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
u.RunLoop(ctx)
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(200 * time.Millisecond):
|
|
t.Fatalf("RunLoop did not exit")
|
|
}
|
|
|
|
if atomic.LoadInt32(&called) < 1 {
|
|
t.Fatalf("expected RunLoop to invoke update")
|
|
}
|
|
}
|
|
|
|
func TestRunLoopEarlyExit(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
u := newUpdaterForTest("http://example")
|
|
u.cfg.Disabled = true
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
u.RunLoop(ctx)
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(50 * time.Millisecond):
|
|
t.Fatalf("expected RunLoop to exit quickly")
|
|
}
|
|
|
|
u2 := newUpdaterForTest("http://example")
|
|
u2.cfg.CurrentVersion = "dev"
|
|
|
|
done2 := make(chan struct{})
|
|
go func() {
|
|
u2.RunLoop(context.Background())
|
|
close(done2)
|
|
}()
|
|
|
|
select {
|
|
case <-done2:
|
|
case <-time.After(50 * time.Millisecond):
|
|
t.Fatalf("expected RunLoop to exit for dev")
|
|
}
|
|
}
|
|
|
|
func TestRunLoopInitialCancel(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
u := newUpdaterForTest("http://example")
|
|
u.initialDelay = 5 * time.Second
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
u.RunLoop(ctx)
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(50 * time.Millisecond):
|
|
t.Fatalf("expected RunLoop to exit on cancel")
|
|
}
|
|
}
|
|
|
|
func TestGetUpdatedFromVersion(t *testing.T) {
|
|
origExec := osExecutableFn
|
|
origEval := evalSymlinksFn
|
|
t.Cleanup(func() {
|
|
osExecutableFn = origExec
|
|
evalSymlinksFn = origEval
|
|
})
|
|
|
|
tmpDir := t.TempDir()
|
|
execPath := filepath.Join(tmpDir, "agent")
|
|
infoPath := filepath.Join(tmpDir, ".pulse-update-info")
|
|
if err := os.WriteFile(infoPath, []byte("1.0.0"), 0644); err != nil {
|
|
t.Fatalf("write update info: %v", err)
|
|
}
|
|
|
|
osExecutableFn = func() (string, error) { return execPath, nil }
|
|
evalSymlinksFn = func(string) (string, error) { return execPath, nil }
|
|
|
|
version := GetUpdatedFromVersion()
|
|
if version != "1.0.0" {
|
|
t.Fatalf("expected version 1.0.0, got %q", version)
|
|
}
|
|
if _, err := os.Stat(infoPath); err == nil {
|
|
t.Fatalf("expected update info file to be removed")
|
|
}
|
|
|
|
missingPath := filepath.Join(tmpDir, "agent-missing")
|
|
osExecutableFn = func() (string, error) { return missingPath, nil }
|
|
evalSymlinksFn = func(string) (string, error) { return "", errors.New("fail") }
|
|
if GetUpdatedFromVersion() != "" {
|
|
t.Fatalf("expected empty version on missing file")
|
|
}
|
|
|
|
osExecutableFn = func() (string, error) { return "", errors.New("fail") }
|
|
if GetUpdatedFromVersion() != "" {
|
|
t.Fatalf("expected empty version on error")
|
|
}
|
|
}
|
|
|
|
func TestPerformUpdateWrapper(t *testing.T) {
|
|
t.Run("ExecPathError", func(t *testing.T) {
|
|
origExec := osExecutableFn
|
|
t.Cleanup(func() { osExecutableFn = origExec })
|
|
osExecutableFn = func() (string, error) { return "", errors.New("fail") }
|
|
|
|
u := newUpdaterForTest("http://example")
|
|
if err := u.performUpdate(context.Background()); err == nil {
|
|
t.Fatalf("expected exec path error")
|
|
}
|
|
})
|
|
|
|
t.Run("Success", func(t *testing.T) {
|
|
data := testBinary()
|
|
check := checksum(data)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Checksum-Sha256", check)
|
|
_, _ = w.Write(data)
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
_, execPath := writeTempExec(t)
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
|
|
origExec := osExecutableFn
|
|
origRestart := restartProcessFn
|
|
t.Cleanup(func() {
|
|
osExecutableFn = origExec
|
|
restartProcessFn = origRestart
|
|
})
|
|
osExecutableFn = func() (string, error) { return execPath, nil }
|
|
restartProcessFn = func(string) error { return nil }
|
|
|
|
if err := u.performUpdate(context.Background()); err != nil {
|
|
t.Fatalf("expected update success, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestPerformUpdateInvalidRequest(t *testing.T) {
|
|
u := newUpdaterForTest("http://[::1")
|
|
if err := u.performUpdateWithExecPath(context.Background(), ""); err == nil {
|
|
t.Fatalf("expected error for invalid URL")
|
|
}
|
|
}
|
|
|
|
func TestPerformUpdateDownloadErrorAndHeaders(t *testing.T) {
|
|
_, execPath := writeTempExec(t)
|
|
u := newUpdaterForTest("http://example")
|
|
u.cfg.APIToken = "token"
|
|
|
|
var sawToken bool
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
if r.Header.Get("X-API-Token") == "token" && strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
|
sawToken = true
|
|
}
|
|
return nil, errors.New("download fail")
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected download error")
|
|
}
|
|
if !sawToken {
|
|
t.Fatalf("expected auth headers to be set")
|
|
}
|
|
}
|
|
|
|
func TestPerformUpdateStatusFallbackAndSuccess(t *testing.T) {
|
|
data := testBinary()
|
|
check := checksum(data)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.RawQuery, "arch=") {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
w.Header().Set("X-Checksum-Sha256", check)
|
|
_, _ = w.Write(data)
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
_, execPath := writeTempExec(t)
|
|
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
|
|
origRestart := restartProcessFn
|
|
t.Cleanup(func() { restartProcessFn = origRestart })
|
|
restartProcessFn = func(string) error { return nil }
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
|
|
t.Fatalf("expected update success, got %v", err)
|
|
}
|
|
|
|
updated, err := os.ReadFile(execPath)
|
|
if err != nil {
|
|
t.Fatalf("read updated exec: %v", err)
|
|
}
|
|
if !bytes.HasPrefix(updated, []byte{0x7f, 'E', 'L', 'F'}) {
|
|
t.Fatalf("expected updated binary content")
|
|
}
|
|
}
|
|
|
|
func TestPerformUpdateSymlinkFallback(t *testing.T) {
|
|
data := testBinary()
|
|
check := checksum(data)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Checksum-Sha256", check)
|
|
_, _ = w.Write(data)
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
_, execPath := writeTempExec(t)
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
|
|
origEval := evalSymlinksFn
|
|
origRestart := restartProcessFn
|
|
t.Cleanup(func() {
|
|
evalSymlinksFn = origEval
|
|
restartProcessFn = origRestart
|
|
})
|
|
evalSymlinksFn = func(string) (string, error) { return "", errors.New("fail") }
|
|
restartProcessFn = func(string) error { return nil }
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
|
|
t.Fatalf("expected update success with symlink fallback, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestPerformUpdateErrors(t *testing.T) {
|
|
t.Run("CreateTempError", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Checksum-Sha256", checksum(testBinary()))
|
|
_, _ = w.Write(testBinary())
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
_, execPath := writeTempExec(t)
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
|
|
origCreate := createTempFn
|
|
t.Cleanup(func() { createTempFn = origCreate })
|
|
createTempFn = func(string, string) (*os.File, error) {
|
|
return nil, errors.New("temp fail")
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected create temp error")
|
|
}
|
|
})
|
|
|
|
t.Run("CopyError", func(t *testing.T) {
|
|
_, execPath := writeTempExec(t)
|
|
u := newUpdaterForTest("http://example")
|
|
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
body := io.NopCloser(&errorReader{})
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: body,
|
|
Header: http.Header{"X-Checksum-Sha256": []string{checksum(testBinary())}},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected copy error")
|
|
}
|
|
})
|
|
|
|
t.Run("TooLarge", func(t *testing.T) {
|
|
_, execPath := writeTempExec(t)
|
|
u := newUpdaterForTest("http://example")
|
|
|
|
origMax := maxBinarySizeBytes
|
|
t.Cleanup(func() { maxBinarySizeBytes = origMax })
|
|
maxBinarySizeBytes = 4
|
|
|
|
data := []byte{0x7f, 'E', 'L', 'F', 0x00, 0x01}
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: io.NopCloser(bytes.NewReader(data)),
|
|
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected size error")
|
|
}
|
|
})
|
|
|
|
t.Run("CloseError", func(t *testing.T) {
|
|
data := testBinary()
|
|
u := newUpdaterForTest("http://example")
|
|
_, execPath := writeTempExec(t)
|
|
|
|
origClose := closeFileFn
|
|
origRestart := restartProcessFn
|
|
t.Cleanup(func() {
|
|
closeFileFn = origClose
|
|
restartProcessFn = origRestart
|
|
})
|
|
closeFileFn = func(*os.File) error { return errors.New("close fail") }
|
|
restartProcessFn = func(string) error { return nil }
|
|
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: io.NopCloser(bytes.NewReader(data)),
|
|
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected close error")
|
|
}
|
|
})
|
|
|
|
t.Run("InvalidBinary", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
_, execPath := writeTempExec(t)
|
|
|
|
data := []byte{0x00, 0x00, 0x00, 0x00}
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: io.NopCloser(bytes.NewReader(data)),
|
|
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected invalid binary error")
|
|
}
|
|
})
|
|
|
|
t.Run("MissingChecksum", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
_, execPath := writeTempExec(t)
|
|
|
|
data := testBinary()
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: io.NopCloser(bytes.NewReader(data)),
|
|
Header: http.Header{},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected checksum missing error")
|
|
}
|
|
})
|
|
|
|
t.Run("ChecksumMismatch", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
_, execPath := writeTempExec(t)
|
|
|
|
data := testBinary()
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: io.NopCloser(bytes.NewReader(data)),
|
|
Header: http.Header{"X-Checksum-Sha256": []string{"deadbeef"}},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected checksum mismatch error")
|
|
}
|
|
})
|
|
|
|
t.Run("ChmodError", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
_, execPath := writeTempExec(t)
|
|
data := testBinary()
|
|
|
|
origChmod := chmodFn
|
|
origRestart := restartProcessFn
|
|
t.Cleanup(func() {
|
|
chmodFn = origChmod
|
|
restartProcessFn = origRestart
|
|
})
|
|
chmodFn = func(string, os.FileMode) error { return errors.New("chmod fail") }
|
|
restartProcessFn = func(string) error { return nil }
|
|
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: io.NopCloser(bytes.NewReader(data)),
|
|
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected chmod error")
|
|
}
|
|
})
|
|
|
|
t.Run("BackupRenameError", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
_, execPath := writeTempExec(t)
|
|
data := testBinary()
|
|
|
|
origRename := renameFn
|
|
origRestart := restartProcessFn
|
|
t.Cleanup(func() {
|
|
renameFn = origRename
|
|
restartProcessFn = origRestart
|
|
})
|
|
renameFn = func(oldPath, newPath string) error {
|
|
if strings.HasSuffix(newPath, ".backup") {
|
|
return errors.New("backup fail")
|
|
}
|
|
return origRename(oldPath, newPath)
|
|
}
|
|
restartProcessFn = func(string) error { return nil }
|
|
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: io.NopCloser(bytes.NewReader(data)),
|
|
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected backup rename error")
|
|
}
|
|
})
|
|
|
|
t.Run("ReplaceRenameError", func(t *testing.T) {
|
|
u := newUpdaterForTest("http://example")
|
|
_, execPath := writeTempExec(t)
|
|
data := testBinary()
|
|
|
|
origRename := renameFn
|
|
origRestart := restartProcessFn
|
|
t.Cleanup(func() {
|
|
renameFn = origRename
|
|
restartProcessFn = origRestart
|
|
})
|
|
var calls int32
|
|
renameFn = func(oldPath, newPath string) error {
|
|
if atomic.AddInt32(&calls, 1) == 2 {
|
|
return errors.New("replace fail")
|
|
}
|
|
return origRename(oldPath, newPath)
|
|
}
|
|
restartProcessFn = func(string) error { return nil }
|
|
|
|
u.client = &http.Client{
|
|
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Status: "200 OK",
|
|
Body: io.NopCloser(bytes.NewReader(data)),
|
|
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
|
|
t.Fatalf("expected replace rename error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestPerformUpdateUnraidPaths(t *testing.T) {
|
|
data := testBinary()
|
|
check := checksum(data)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Checksum-Sha256", check)
|
|
_, _ = w.Write(data)
|
|
}))
|
|
t.Cleanup(server.Close)
|
|
|
|
_, execPath := writeTempExec(t)
|
|
u := newUpdaterForTest(server.URL)
|
|
u.client = server.Client()
|
|
|
|
origUnraid := unraidVersionPath
|
|
origPersist := unraidPersistentPathFn
|
|
origRead := readFileFn
|
|
origWrite := writeFileFn
|
|
origRename := renameFn
|
|
origRestart := restartProcessFn
|
|
t.Cleanup(func() {
|
|
unraidVersionPath = origUnraid
|
|
unraidPersistentPathFn = origPersist
|
|
readFileFn = origRead
|
|
writeFileFn = origWrite
|
|
renameFn = origRename
|
|
restartProcessFn = origRestart
|
|
})
|
|
|
|
tmpDir := t.TempDir()
|
|
unraidVersionPath = filepath.Join(tmpDir, "unraid-version")
|
|
if err := os.WriteFile(unraidVersionPath, []byte("6.12"), 0644); err != nil {
|
|
t.Fatalf("write unraid version: %v", err)
|
|
}
|
|
|
|
persistDir := filepath.Join(tmpDir, "persist")
|
|
if err := os.MkdirAll(persistDir, 0755); err != nil {
|
|
t.Fatalf("mkdir persist: %v", err)
|
|
}
|
|
persistPath := filepath.Join(persistDir, "pulse-agent")
|
|
if err := os.WriteFile(persistPath, []byte("old"), 0644); err != nil {
|
|
t.Fatalf("write persist: %v", err)
|
|
}
|
|
unraidPersistentPathFn = func(string) string { return persistPath }
|
|
restartProcessFn = func(string) error { return nil }
|
|
|
|
t.Run("ReadError", func(t *testing.T) {
|
|
readFileFn = func(string) ([]byte, error) { return nil, errors.New("read fail") }
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
|
|
t.Fatalf("expected update success, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("WriteError", func(t *testing.T) {
|
|
readFileFn = func(string) ([]byte, error) { return []byte("new"), nil }
|
|
writeFileFn = func(string, []byte, os.FileMode) error { return errors.New("write fail") }
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
|
|
t.Fatalf("expected update success, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("RenameError", func(t *testing.T) {
|
|
readFileFn = func(string) ([]byte, error) { return []byte("new"), nil }
|
|
writeFileFn = os.WriteFile
|
|
renameFn = func(oldPath, newPath string) error {
|
|
if newPath == persistPath {
|
|
return errors.New("rename fail")
|
|
}
|
|
return os.Rename(oldPath, newPath)
|
|
}
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
|
|
t.Fatalf("expected update success, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Success", func(t *testing.T) {
|
|
readFileFn = func(string) ([]byte, error) { return []byte("new"), nil }
|
|
writeFileFn = os.WriteFile
|
|
renameFn = os.Rename
|
|
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
|
|
t.Fatalf("expected update success, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
type errorReader struct {
|
|
sent bool
|
|
}
|
|
|
|
func (e *errorReader) Read(p []byte) (int, error) {
|
|
if e.sent {
|
|
return 0, errors.New("read fail")
|
|
}
|
|
e.sent = true
|
|
copy(p, testBinary())
|
|
return len(testBinary()), errors.New("read fail")
|
|
}
|