Files
Pulse/internal/agentupdate/coverage_test.go
2025-12-29 17:25:21 +00:00

1034 lines
28 KiB
Go

package agentupdate
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
)
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
func testBinary() []byte {
return []byte{0x7f, 'E', 'L', 'F', 0x01, 0x02, 0x03, 0x04}
}
func checksum(data []byte) string {
sum := sha256.Sum256(data)
return hex.EncodeToString(sum[:])
}
func writeTempExec(t *testing.T) (string, string) {
t.Helper()
dir := t.TempDir()
execPath := filepath.Join(dir, "agent")
if err := os.WriteFile(execPath, []byte("old-binary"), 0755); err != nil {
t.Fatalf("write exec: %v", err)
}
return dir, execPath
}
func newUpdaterForTest(serverURL string) *Updater {
cfg := Config{
PulseURL: serverURL,
AgentName: "pulse-agent",
CurrentVersion: "1.0.0",
CheckInterval: 10 * time.Millisecond,
}
return New(cfg)
}
func TestRestartProcess(t *testing.T) {
orig := execFn
t.Cleanup(func() { execFn = orig })
execFn = func(string, []string, []string) error { return nil }
if err := restartProcess("/bin/true"); err != nil {
t.Fatalf("expected nil error, got %v", err)
}
execFn = func(string, []string, []string) error { return errors.New("boom") }
if err := restartProcess("/bin/false"); err == nil {
t.Fatalf("expected error")
}
}
func TestDetermineArchOverrides(t *testing.T) {
origOS, origArch, origUname := runtimeGOOS, runtimeGOARCH, unameCommand
t.Cleanup(func() {
runtimeGOOS = origOS
runtimeGOARCH = origArch
unameCommand = origUname
})
runtimeGOOS = "linux"
runtimeGOARCH = "arm"
if got := determineArch(); got != "linux-armv7" {
t.Fatalf("expected linux-armv7, got %q", got)
}
runtimeGOARCH = "386"
if got := determineArch(); got != "linux-386" {
t.Fatalf("expected linux-386, got %q", got)
}
runtimeGOOS = "solaris"
runtimeGOARCH = "amd64"
unameCommand = func() ([]byte, error) { return []byte("aarch64"), nil }
if got := determineArch(); got != "linux-arm64" {
t.Fatalf("expected linux-arm64, got %q", got)
}
unameCommand = func() ([]byte, error) { return []byte("x86_64"), nil }
if got := determineArch(); got != "linux-amd64" {
t.Fatalf("expected linux-amd64, got %q", got)
}
unameCommand = func() ([]byte, error) { return []byte("armv7l"), nil }
if got := determineArch(); got != "linux-armv7" {
t.Fatalf("expected linux-armv7, got %q", got)
}
unameCommand = func() ([]byte, error) { return []byte("mips"), nil }
if got := determineArch(); got != "" {
t.Fatalf("expected empty arch for unknown uname, got %q", got)
}
unameCommand = func() ([]byte, error) { return nil, errors.New("fail") }
if got := determineArch(); got != "" {
t.Fatalf("expected empty arch on uname error, got %q", got)
}
}
func TestUnameCommandDefault(t *testing.T) {
if _, err := exec.LookPath("uname"); err != nil {
t.Skip("uname not available")
}
orig := unameCommand
t.Cleanup(func() { unameCommand = orig })
unameCommand = orig
if _, err := unameCommand(); err != nil {
t.Fatalf("expected uname to run, got %v", err)
}
}
func TestVerifyBinaryMagicOverrides(t *testing.T) {
origOS := runtimeGOOS
t.Cleanup(func() { runtimeGOOS = origOS })
tmpDir := t.TempDir()
runtimeGOOS = "darwin"
machoPath := filepath.Join(tmpDir, "macho")
machoData := []byte{0xcf, 0xfa, 0xed, 0xfe, 0x00}
if err := os.WriteFile(machoPath, machoData, 0644); err != nil {
t.Fatalf("write macho: %v", err)
}
if err := verifyBinaryMagic(machoPath); err != nil {
t.Fatalf("expected macho to validate, got %v", err)
}
badPath := filepath.Join(tmpDir, "macho-bad")
if err := os.WriteFile(badPath, []byte{0x00, 0x00, 0x00, 0x00}, 0644); err != nil {
t.Fatalf("write bad macho: %v", err)
}
if err := verifyBinaryMagic(badPath); err == nil {
t.Fatalf("expected macho error")
}
runtimeGOOS = "windows"
pePath := filepath.Join(tmpDir, "pe.exe")
if err := os.WriteFile(pePath, []byte{'M', 'Z', 0x00, 0x00}, 0644); err != nil {
t.Fatalf("write pe: %v", err)
}
if err := verifyBinaryMagic(pePath); err != nil {
t.Fatalf("expected PE to validate, got %v", err)
}
badPEPath := filepath.Join(tmpDir, "bad-pe.exe")
if err := os.WriteFile(badPEPath, []byte{0x00, 0x00, 0x00, 0x00}, 0644); err != nil {
t.Fatalf("write bad pe: %v", err)
}
if err := verifyBinaryMagic(badPEPath); err == nil {
t.Fatalf("expected PE error")
}
runtimeGOOS = "plan9"
planPath := filepath.Join(tmpDir, "plan9")
if err := os.WriteFile(planPath, []byte{0x00, 0x01, 0x02, 0x03}, 0644); err != nil {
t.Fatalf("write plan9: %v", err)
}
if err := verifyBinaryMagic(planPath); err != nil {
t.Fatalf("expected unknown OS to skip verification, got %v", err)
}
}
func TestIsUnraidOverride(t *testing.T) {
orig := unraidVersionPath
t.Cleanup(func() { unraidVersionPath = orig })
tmpDir := t.TempDir()
unraidVersionPath = filepath.Join(tmpDir, "unraid-version")
if isUnraid() {
t.Fatalf("expected false when file missing")
}
if err := os.WriteFile(unraidVersionPath, []byte("6.12"), 0644); err != nil {
t.Fatalf("write unraid: %v", err)
}
if !isUnraid() {
t.Fatalf("expected true when file exists")
}
}
func TestGetServerVersion(t *testing.T) {
t.Run("Success", func(t *testing.T) {
var sawToken bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-API-Token") == "token" && strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
sawToken = true
}
_, _ = w.Write([]byte(`{"version":"1.2.3"}`))
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.cfg.APIToken = "token"
u.client = server.Client()
version, err := u.getServerVersion(context.Background())
if err != nil {
t.Fatalf("getServerVersion error: %v", err)
}
if version != "1.2.3" {
t.Fatalf("expected version 1.2.3, got %q", version)
}
if !sawToken {
t.Fatalf("expected token headers to be set")
}
})
t.Run("InvalidURL", func(t *testing.T) {
u := newUpdaterForTest("http://[::1")
if _, err := u.getServerVersion(context.Background()); err == nil {
t.Fatalf("expected error for invalid URL")
}
})
t.Run("RequestError", func(t *testing.T) {
u := newUpdaterForTest("http://example")
u.client = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("boom")
}),
}
if _, err := u.getServerVersion(context.Background()); err == nil {
t.Fatalf("expected request error")
}
})
t.Run("StatusError", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
if _, err := u.getServerVersion(context.Background()); err == nil {
t.Fatalf("expected status error")
}
})
t.Run("DecodeError", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("{"))
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
if _, err := u.getServerVersion(context.Background()); err == nil {
t.Fatalf("expected decode error")
}
})
}
func TestCheckAndUpdateBranches(t *testing.T) {
t.Run("Disabled", func(t *testing.T) {
u := newUpdaterForTest("http://example")
u.cfg.Disabled = true
u.performUpdateFn = func(context.Context) error {
t.Fatalf("should not update when disabled")
return nil
}
u.CheckAndUpdate(context.Background())
})
t.Run("DevCurrent", func(t *testing.T) {
u := newUpdaterForTest("http://example")
u.cfg.CurrentVersion = "dev"
u.performUpdateFn = func(context.Context) error {
t.Fatalf("should not update in dev mode")
return nil
}
u.CheckAndUpdate(context.Background())
})
t.Run("NoPulseURL", func(t *testing.T) {
u := newUpdaterForTest("")
u.performUpdateFn = func(context.Context) error {
t.Fatalf("should not update without URL")
return nil
}
u.CheckAndUpdate(context.Background())
})
t.Run("ServerError", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
u.performUpdateFn = func(context.Context) error {
t.Fatalf("should not update on server error")
return nil
}
u.CheckAndUpdate(context.Background())
})
t.Run("ServerDev", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"version":"dev"}`))
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
u.performUpdateFn = func(context.Context) error {
t.Fatalf("should not update when server dev")
return nil
}
u.CheckAndUpdate(context.Background())
})
t.Run("UpToDate", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"version":"1.0.0"}`))
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
u.performUpdateFn = func(context.Context) error {
t.Fatalf("should not update when up to date")
return nil
}
u.CheckAndUpdate(context.Background())
})
t.Run("ServerOlder", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"version":"0.9.0"}`))
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
u.performUpdateFn = func(context.Context) error {
t.Fatalf("should not downgrade")
return nil
}
u.CheckAndUpdate(context.Background())
})
t.Run("ServerNewer", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"version":"1.1.0"}`))
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
var called int32
u.performUpdateFn = func(context.Context) error {
atomic.AddInt32(&called, 1)
return nil
}
u.CheckAndUpdate(context.Background())
if called != 1 {
t.Fatalf("expected update to be called")
}
})
t.Run("UpdateError", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"version":"1.1.0"}`))
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
var called int32
u.performUpdateFn = func(context.Context) error {
atomic.AddInt32(&called, 1)
return errors.New("fail")
}
u.CheckAndUpdate(context.Background())
if called != 1 {
t.Fatalf("expected update to be called")
}
})
}
func TestRunLoop(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"version":"1.1.0"}`))
}))
t.Cleanup(server.Close)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
u.initialDelay = 0
u.newTicker = func(d time.Duration) *time.Ticker {
return time.NewTicker(5 * time.Millisecond)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var called int32
u.performUpdateFn = func(context.Context) error {
if atomic.AddInt32(&called, 1) >= 2 {
cancel()
}
return nil
}
done := make(chan struct{})
go func() {
u.RunLoop(ctx)
close(done)
}()
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatalf("RunLoop did not exit")
}
if atomic.LoadInt32(&called) < 1 {
t.Fatalf("expected RunLoop to invoke update")
}
}
func TestRunLoopEarlyExit(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
u := newUpdaterForTest("http://example")
u.cfg.Disabled = true
done := make(chan struct{})
go func() {
u.RunLoop(ctx)
close(done)
}()
select {
case <-done:
case <-time.After(50 * time.Millisecond):
t.Fatalf("expected RunLoop to exit quickly")
}
u2 := newUpdaterForTest("http://example")
u2.cfg.CurrentVersion = "dev"
done2 := make(chan struct{})
go func() {
u2.RunLoop(context.Background())
close(done2)
}()
select {
case <-done2:
case <-time.After(50 * time.Millisecond):
t.Fatalf("expected RunLoop to exit for dev")
}
}
func TestRunLoopInitialCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
u := newUpdaterForTest("http://example")
u.initialDelay = 5 * time.Second
done := make(chan struct{})
go func() {
u.RunLoop(ctx)
close(done)
}()
select {
case <-done:
case <-time.After(50 * time.Millisecond):
t.Fatalf("expected RunLoop to exit on cancel")
}
}
func TestGetUpdatedFromVersion(t *testing.T) {
origExec := osExecutableFn
origEval := evalSymlinksFn
t.Cleanup(func() {
osExecutableFn = origExec
evalSymlinksFn = origEval
})
tmpDir := t.TempDir()
execPath := filepath.Join(tmpDir, "agent")
infoPath := filepath.Join(tmpDir, ".pulse-update-info")
if err := os.WriteFile(infoPath, []byte("1.0.0"), 0644); err != nil {
t.Fatalf("write update info: %v", err)
}
osExecutableFn = func() (string, error) { return execPath, nil }
evalSymlinksFn = func(string) (string, error) { return execPath, nil }
version := GetUpdatedFromVersion()
if version != "1.0.0" {
t.Fatalf("expected version 1.0.0, got %q", version)
}
if _, err := os.Stat(infoPath); err == nil {
t.Fatalf("expected update info file to be removed")
}
missingPath := filepath.Join(tmpDir, "agent-missing")
osExecutableFn = func() (string, error) { return missingPath, nil }
evalSymlinksFn = func(string) (string, error) { return "", errors.New("fail") }
if GetUpdatedFromVersion() != "" {
t.Fatalf("expected empty version on missing file")
}
osExecutableFn = func() (string, error) { return "", errors.New("fail") }
if GetUpdatedFromVersion() != "" {
t.Fatalf("expected empty version on error")
}
}
func TestPerformUpdateWrapper(t *testing.T) {
t.Run("ExecPathError", func(t *testing.T) {
origExec := osExecutableFn
t.Cleanup(func() { osExecutableFn = origExec })
osExecutableFn = func() (string, error) { return "", errors.New("fail") }
u := newUpdaterForTest("http://example")
if err := u.performUpdate(context.Background()); err == nil {
t.Fatalf("expected exec path error")
}
})
t.Run("Success", func(t *testing.T) {
data := testBinary()
check := checksum(data)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Checksum-Sha256", check)
_, _ = w.Write(data)
}))
t.Cleanup(server.Close)
_, execPath := writeTempExec(t)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
origExec := osExecutableFn
origRestart := restartProcessFn
t.Cleanup(func() {
osExecutableFn = origExec
restartProcessFn = origRestart
})
osExecutableFn = func() (string, error) { return execPath, nil }
restartProcessFn = func(string) error { return nil }
if err := u.performUpdate(context.Background()); err != nil {
t.Fatalf("expected update success, got %v", err)
}
})
}
func TestPerformUpdateInvalidRequest(t *testing.T) {
u := newUpdaterForTest("http://[::1")
if err := u.performUpdateWithExecPath(context.Background(), ""); err == nil {
t.Fatalf("expected error for invalid URL")
}
}
func TestPerformUpdateDownloadErrorAndHeaders(t *testing.T) {
_, execPath := writeTempExec(t)
u := newUpdaterForTest("http://example")
u.cfg.APIToken = "token"
var sawToken bool
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Header.Get("X-API-Token") == "token" && strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
sawToken = true
}
return nil, errors.New("download fail")
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected download error")
}
if !sawToken {
t.Fatalf("expected auth headers to be set")
}
}
func TestPerformUpdateStatusFallbackAndSuccess(t *testing.T) {
data := testBinary()
check := checksum(data)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.RawQuery, "arch=") {
w.WriteHeader(http.StatusNotFound)
return
}
w.Header().Set("X-Checksum-Sha256", check)
_, _ = w.Write(data)
}))
t.Cleanup(server.Close)
_, execPath := writeTempExec(t)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
origRestart := restartProcessFn
t.Cleanup(func() { restartProcessFn = origRestart })
restartProcessFn = func(string) error { return nil }
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
t.Fatalf("expected update success, got %v", err)
}
updated, err := os.ReadFile(execPath)
if err != nil {
t.Fatalf("read updated exec: %v", err)
}
if !bytes.HasPrefix(updated, []byte{0x7f, 'E', 'L', 'F'}) {
t.Fatalf("expected updated binary content")
}
}
func TestPerformUpdateSymlinkFallback(t *testing.T) {
data := testBinary()
check := checksum(data)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Checksum-Sha256", check)
_, _ = w.Write(data)
}))
t.Cleanup(server.Close)
_, execPath := writeTempExec(t)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
origEval := evalSymlinksFn
origRestart := restartProcessFn
t.Cleanup(func() {
evalSymlinksFn = origEval
restartProcessFn = origRestart
})
evalSymlinksFn = func(string) (string, error) { return "", errors.New("fail") }
restartProcessFn = func(string) error { return nil }
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
t.Fatalf("expected update success with symlink fallback, got %v", err)
}
}
func TestPerformUpdateErrors(t *testing.T) {
t.Run("CreateTempError", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Checksum-Sha256", checksum(testBinary()))
_, _ = w.Write(testBinary())
}))
t.Cleanup(server.Close)
_, execPath := writeTempExec(t)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
origCreate := createTempFn
t.Cleanup(func() { createTempFn = origCreate })
createTempFn = func(string, string) (*os.File, error) {
return nil, errors.New("temp fail")
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected create temp error")
}
})
t.Run("CopyError", func(t *testing.T) {
_, execPath := writeTempExec(t)
u := newUpdaterForTest("http://example")
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
body := io.NopCloser(&errorReader{})
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: body,
Header: http.Header{"X-Checksum-Sha256": []string{checksum(testBinary())}},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected copy error")
}
})
t.Run("TooLarge", func(t *testing.T) {
_, execPath := writeTempExec(t)
u := newUpdaterForTest("http://example")
origMax := maxBinarySizeBytes
t.Cleanup(func() { maxBinarySizeBytes = origMax })
maxBinarySizeBytes = 4
data := []byte{0x7f, 'E', 'L', 'F', 0x00, 0x01}
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(bytes.NewReader(data)),
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected size error")
}
})
t.Run("CloseError", func(t *testing.T) {
data := testBinary()
u := newUpdaterForTest("http://example")
_, execPath := writeTempExec(t)
origClose := closeFileFn
origRestart := restartProcessFn
t.Cleanup(func() {
closeFileFn = origClose
restartProcessFn = origRestart
})
closeFileFn = func(*os.File) error { return errors.New("close fail") }
restartProcessFn = func(string) error { return nil }
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(bytes.NewReader(data)),
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected close error")
}
})
t.Run("InvalidBinary", func(t *testing.T) {
u := newUpdaterForTest("http://example")
_, execPath := writeTempExec(t)
data := []byte{0x00, 0x00, 0x00, 0x00}
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(bytes.NewReader(data)),
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected invalid binary error")
}
})
t.Run("MissingChecksum", func(t *testing.T) {
u := newUpdaterForTest("http://example")
_, execPath := writeTempExec(t)
data := testBinary()
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(bytes.NewReader(data)),
Header: http.Header{},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected checksum missing error")
}
})
t.Run("ChecksumMismatch", func(t *testing.T) {
u := newUpdaterForTest("http://example")
_, execPath := writeTempExec(t)
data := testBinary()
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(bytes.NewReader(data)),
Header: http.Header{"X-Checksum-Sha256": []string{"deadbeef"}},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected checksum mismatch error")
}
})
t.Run("ChmodError", func(t *testing.T) {
u := newUpdaterForTest("http://example")
_, execPath := writeTempExec(t)
data := testBinary()
origChmod := chmodFn
origRestart := restartProcessFn
t.Cleanup(func() {
chmodFn = origChmod
restartProcessFn = origRestart
})
chmodFn = func(string, os.FileMode) error { return errors.New("chmod fail") }
restartProcessFn = func(string) error { return nil }
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(bytes.NewReader(data)),
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected chmod error")
}
})
t.Run("BackupRenameError", func(t *testing.T) {
u := newUpdaterForTest("http://example")
_, execPath := writeTempExec(t)
data := testBinary()
origRename := renameFn
origRestart := restartProcessFn
t.Cleanup(func() {
renameFn = origRename
restartProcessFn = origRestart
})
renameFn = func(oldPath, newPath string) error {
if strings.HasSuffix(newPath, ".backup") {
return errors.New("backup fail")
}
return origRename(oldPath, newPath)
}
restartProcessFn = func(string) error { return nil }
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(bytes.NewReader(data)),
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected backup rename error")
}
})
t.Run("ReplaceRenameError", func(t *testing.T) {
u := newUpdaterForTest("http://example")
_, execPath := writeTempExec(t)
data := testBinary()
origRename := renameFn
origRestart := restartProcessFn
t.Cleanup(func() {
renameFn = origRename
restartProcessFn = origRestart
})
var calls int32
renameFn = func(oldPath, newPath string) error {
if atomic.AddInt32(&calls, 1) == 2 {
return errors.New("replace fail")
}
return origRename(oldPath, newPath)
}
restartProcessFn = func(string) error { return nil }
u.client = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(bytes.NewReader(data)),
Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}},
}, nil
}),
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil {
t.Fatalf("expected replace rename error")
}
})
}
func TestPerformUpdateUnraidPaths(t *testing.T) {
data := testBinary()
check := checksum(data)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Checksum-Sha256", check)
_, _ = w.Write(data)
}))
t.Cleanup(server.Close)
_, execPath := writeTempExec(t)
u := newUpdaterForTest(server.URL)
u.client = server.Client()
origUnraid := unraidVersionPath
origPersist := unraidPersistentPathFn
origRead := readFileFn
origWrite := writeFileFn
origRename := renameFn
origRestart := restartProcessFn
t.Cleanup(func() {
unraidVersionPath = origUnraid
unraidPersistentPathFn = origPersist
readFileFn = origRead
writeFileFn = origWrite
renameFn = origRename
restartProcessFn = origRestart
})
tmpDir := t.TempDir()
unraidVersionPath = filepath.Join(tmpDir, "unraid-version")
if err := os.WriteFile(unraidVersionPath, []byte("6.12"), 0644); err != nil {
t.Fatalf("write unraid version: %v", err)
}
persistDir := filepath.Join(tmpDir, "persist")
if err := os.MkdirAll(persistDir, 0755); err != nil {
t.Fatalf("mkdir persist: %v", err)
}
persistPath := filepath.Join(persistDir, "pulse-agent")
if err := os.WriteFile(persistPath, []byte("old"), 0644); err != nil {
t.Fatalf("write persist: %v", err)
}
unraidPersistentPathFn = func(string) string { return persistPath }
restartProcessFn = func(string) error { return nil }
t.Run("ReadError", func(t *testing.T) {
readFileFn = func(string) ([]byte, error) { return nil, errors.New("read fail") }
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
t.Fatalf("expected update success, got %v", err)
}
})
t.Run("WriteError", func(t *testing.T) {
readFileFn = func(string) ([]byte, error) { return []byte("new"), nil }
writeFileFn = func(string, []byte, os.FileMode) error { return errors.New("write fail") }
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
t.Fatalf("expected update success, got %v", err)
}
})
t.Run("RenameError", func(t *testing.T) {
readFileFn = func(string) ([]byte, error) { return []byte("new"), nil }
writeFileFn = os.WriteFile
renameFn = func(oldPath, newPath string) error {
if newPath == persistPath {
return errors.New("rename fail")
}
return os.Rename(oldPath, newPath)
}
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
t.Fatalf("expected update success, got %v", err)
}
})
t.Run("Success", func(t *testing.T) {
readFileFn = func(string) ([]byte, error) { return []byte("new"), nil }
writeFileFn = os.WriteFile
renameFn = os.Rename
if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil {
t.Fatalf("expected update success, got %v", err)
}
})
}
type errorReader struct {
sent bool
}
func (e *errorReader) Read(p []byte) (int, error) {
if e.sent {
return 0, errors.New("read fail")
}
e.sent = true
copy(p, testBinary())
return len(testBinary()), errors.New("read fail")
}