Improve internal package test coverage

rcourtman
2025-12-29 17:25:21 +00:00
parent d07b471e40
commit c6bd8cb74c
40 changed files with 10280 additions and 469 deletions

View File

@@ -51,6 +51,25 @@ var requiredHostAgentBinaries = []HostAgentBinary{
var downloadMu sync.Mutex
var (
httpClient = &http.Client{Timeout: 2 * time.Minute}
downloadURLForVersion = func(version string) string {
return fmt.Sprintf("https://github.com/rcourtman/Pulse/releases/download/%[1]s/pulse-%[1]s.tar.gz", version)
}
downloadAndInstallHostAgentBinariesFn = DownloadAndInstallHostAgentBinaries
findMissingHostAgentBinariesFn = findMissingHostAgentBinaries
mkdirAllFn = os.MkdirAll
createTempFn = os.CreateTemp
removeFn = os.Remove
openFileFn = os.Open
openFileModeFn = os.OpenFile
renameFn = os.Rename
symlinkFn = os.Symlink
copyFn = io.Copy
chmodFileFn = func(f *os.File, mode os.FileMode) error { return f.Chmod(mode) }
closeFileFn = func(f *os.File) error { return f.Close() }
)
// HostAgentSearchPaths returns the directories to search for host agent binaries.
func HostAgentSearchPaths() []string {
primary := strings.TrimSpace(os.Getenv("PULSE_BIN_DIR"))
@@ -77,7 +96,7 @@ func HostAgentSearchPaths() []string {
// The returned map contains any binaries that remain missing after the attempt.
func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary {
binDirs := HostAgentSearchPaths()
missing := findMissingHostAgentBinaries(binDirs)
missing := findMissingHostAgentBinariesFn(binDirs)
if len(missing) == 0 {
return nil
}
@@ -86,7 +105,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary {
defer downloadMu.Unlock()
// Re-check after acquiring the lock in case another goroutine restored them.
missing = findMissingHostAgentBinaries(binDirs)
missing = findMissingHostAgentBinariesFn(binDirs)
if len(missing) == 0 {
return nil
}
@@ -101,7 +120,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary {
Strs("missing_platforms", missingPlatforms).
Msg("Host agent binaries missing - attempting to download bundle from GitHub release")
if err := DownloadAndInstallHostAgentBinaries(version, binDirs[0]); err != nil {
if err := downloadAndInstallHostAgentBinariesFn(version, binDirs[0]); err != nil {
log.Error().
Err(err).
Str("target_dir", binDirs[0]).
@@ -110,7 +129,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary {
return missing
}
if remaining := findMissingHostAgentBinaries(binDirs); len(remaining) > 0 {
if remaining := findMissingHostAgentBinariesFn(binDirs); len(remaining) > 0 {
stillMissing := make([]string, 0, len(remaining))
for key := range remaining {
stillMissing = append(stillMissing, key)
@@ -133,19 +152,18 @@ func DownloadAndInstallHostAgentBinaries(version string, targetDir string) error
return fmt.Errorf("cannot download host agent bundle for non-release version %q", version)
}
if err := os.MkdirAll(targetDir, 0o755); err != nil {
if err := mkdirAllFn(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to ensure bin directory %s: %w", targetDir, err)
}
url := fmt.Sprintf("https://github.com/rcourtman/Pulse/releases/download/%[1]s/pulse-%[1]s.tar.gz", normalizedVersion)
tempFile, err := os.CreateTemp("", "pulse-host-agent-*.tar.gz")
url := downloadURLForVersion(normalizedVersion)
tempFile, err := createTempFn("", "pulse-host-agent-*.tar.gz")
if err != nil {
return fmt.Errorf("failed to create temporary archive file: %w", err)
}
defer os.Remove(tempFile.Name())
defer removeFn(tempFile.Name())
client := &http.Client{Timeout: 2 * time.Minute}
resp, err := client.Get(url)
resp, err := httpClient.Get(url)
if err != nil {
return fmt.Errorf("failed to download host agent bundle from %s: %w", url, err)
}
@@ -160,7 +178,7 @@ func DownloadAndInstallHostAgentBinaries(version string, targetDir string) error
return fmt.Errorf("failed to save host agent bundle: %w", err)
}
if err := tempFile.Close(); err != nil {
if err := closeFileFn(tempFile); err != nil {
return fmt.Errorf("failed to close temporary bundle file: %w", err)
}
@@ -204,7 +222,7 @@ func normalizeVersionTag(version string) string {
}
func extractHostAgentBinaries(archivePath, targetDir string) error {
file, err := os.Open(archivePath)
file, err := openFileFn(archivePath)
if err != nil {
return fmt.Errorf("failed to open host agent bundle: %w", err)
}
@@ -232,10 +250,6 @@ func extractHostAgentBinaries(archivePath, targetDir string) error {
return fmt.Errorf("failed to read host agent bundle: %w", err)
}
if header == nil {
continue
}
if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeSymlink {
continue
}
@@ -265,10 +279,10 @@ func extractHostAgentBinaries(archivePath, targetDir string) error {
}
for _, link := range symlinks {
if err := os.Remove(link.path); err != nil && !os.IsNotExist(err) {
if err := removeFn(link.path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to replace existing symlink %s: %w", link.path, err)
}
if err := os.Symlink(link.target, link.path); err != nil {
if err := symlinkFn(link.target, link.path); err != nil {
// Fallback: copy the referenced file if symlinks are not permitted
source := filepath.Join(targetDir, link.target)
if err := copyHostAgentFile(source, link.path); err != nil {
@@ -281,31 +295,31 @@ func extractHostAgentBinaries(archivePath, targetDir string) error {
}
func writeHostAgentFile(destination string, reader io.Reader, mode os.FileMode) error {
if err := os.MkdirAll(filepath.Dir(destination), 0o755); err != nil {
if err := mkdirAllFn(filepath.Dir(destination), 0o755); err != nil {
return fmt.Errorf("failed to create directory for %s: %w", destination, err)
}
tmpFile, err := os.CreateTemp(filepath.Dir(destination), "pulse-host-agent-*")
tmpFile, err := createTempFn(filepath.Dir(destination), "pulse-host-agent-*")
if err != nil {
return fmt.Errorf("failed to create temporary file for %s: %w", destination, err)
}
defer os.Remove(tmpFile.Name())
defer removeFn(tmpFile.Name())
if _, err := io.Copy(tmpFile, reader); err != nil {
tmpFile.Close()
if _, err := copyFn(tmpFile, reader); err != nil {
closeFileFn(tmpFile)
return fmt.Errorf("failed to extract %s: %w", destination, err)
}
if err := tmpFile.Chmod(normalizeExecutableMode(mode)); err != nil {
tmpFile.Close()
if err := chmodFileFn(tmpFile, normalizeExecutableMode(mode)); err != nil {
closeFileFn(tmpFile)
return fmt.Errorf("failed to set permissions on %s: %w", destination, err)
}
if err := tmpFile.Close(); err != nil {
if err := closeFileFn(tmpFile); err != nil {
return fmt.Errorf("failed to finalize %s: %w", destination, err)
}
if err := os.Rename(tmpFile.Name(), destination); err != nil {
if err := renameFn(tmpFile.Name(), destination); err != nil {
return fmt.Errorf("failed to install %s: %w", destination, err)
}
@@ -313,23 +327,23 @@ func writeHostAgentFile(destination string, reader io.Reader, mode os.FileMode)
}
func copyHostAgentFile(source, destination string) error {
src, err := os.Open(source)
src, err := openFileFn(source)
if err != nil {
return fmt.Errorf("failed to open %s for fallback copy: %w", source, err)
}
defer src.Close()
if err := os.MkdirAll(filepath.Dir(destination), 0o755); err != nil {
if err := mkdirAllFn(filepath.Dir(destination), 0o755); err != nil {
return fmt.Errorf("failed to prepare directory for %s: %w", destination, err)
}
dst, err := os.OpenFile(destination, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
dst, err := openFileModeFn(destination, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
if err != nil {
return fmt.Errorf("failed to create fallback copy %s: %w", destination, err)
}
defer dst.Close()
if _, err := io.Copy(dst, src); err != nil {
if _, err := copyFn(dst, src); err != nil {
return fmt.Errorf("failed to copy %s to %s: %w", source, destination, err)
}
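
The central technique in this file is the "seam" pattern: every direct os/io call is hoisted into a package-level function variable (mkdirAllFn, createTempFn, and so on) that production code calls and tests overwrite. A minimal self-contained sketch of the same idea, using hypothetical names (writeFileFn, saveConfig) rather than the identifiers above:

package example

import (
    "errors"
    "fmt"
    "os"
    "testing"
)

// Production code calls the hook, never os.WriteFile directly.
var writeFileFn = os.WriteFile // hypothetical seam, mirrors mkdirAllFn et al.

func saveConfig(path string, data []byte) error {
    if err := writeFileFn(path, data, 0o644); err != nil {
        return fmt.Errorf("failed to save %s: %w", path, err)
    }
    return nil
}

// The test swaps the hook in and restores it even if the test fails early.
func TestSaveConfigError(t *testing.T) {
    orig := writeFileFn
    t.Cleanup(func() { writeFileFn = orig })
    writeFileFn = func(string, []byte, os.FileMode) error { return errors.New("disk full") }
    if err := saveConfig("ignored", nil); err == nil {
        t.Fatal("expected write error")
    }
}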

View File

@@ -0,0 +1,698 @@
package agentbinaries
import (
"archive/tar"
"bytes"
"compress/gzip"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
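// roundTripperFunc is the standard function-adapter idiom (compare
// http.HandlerFunc): any function with this signature satisfies
// http.RoundTripper, so a test can stub transport behaviour inline
// instead of standing up a server, e.g.
//   httpClient = &http.Client{Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
//       return nil, errors.New("network down")
//   })}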
type tarEntry struct {
name string
body []byte
mode int64
typeflag byte
linkname string
}
type errorReader struct{}
func (e *errorReader) Read([]byte) (int, error) {
return 0, errors.New("read fail")
}
func buildTarGz(t *testing.T, entries []tarEntry) []byte {
t.Helper()
var buf bytes.Buffer
gzw := gzip.NewWriter(&buf)
tw := tar.NewWriter(gzw)
for _, entry := range entries {
typeflag := entry.typeflag
if typeflag == 0 {
typeflag = tar.TypeReg
}
size := int64(len(entry.body))
if typeflag != tar.TypeReg {
size = 0
}
hdr := &tar.Header{
Name: entry.name,
Mode: entry.mode,
Size: size,
Typeflag: typeflag,
Linkname: entry.linkname,
}
if err := tw.WriteHeader(hdr); err != nil {
t.Fatalf("WriteHeader: %v", err)
}
if typeflag == tar.TypeReg {
if _, err := tw.Write(entry.body); err != nil {
t.Fatalf("Write: %v", err)
}
}
}
if err := tw.Close(); err != nil {
t.Fatalf("tar close: %v", err)
}
if err := gzw.Close(); err != nil {
t.Fatalf("gzip close: %v", err)
}
return buf.Bytes()
}
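// buildTarGz assembles a gzip-compressed tar archive entirely in memory, so
// the download/extract tests need no fixture files on disk. A zero typeflag
// is normalized to tar.TypeReg, and non-regular entries (directories,
// symlinks) get Size 0 because they carry no data body in a tar stream;
// a symlink's target travels in Linkname instead.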
func saveHostAgentHooks() func() {
origRequired := requiredHostAgentBinaries
origDownloadFn := downloadAndInstallHostAgentBinariesFn
origFindMissing := findMissingHostAgentBinariesFn
origURL := downloadURLForVersion
origClient := httpClient
origMkdirAll := mkdirAllFn
origCreateTemp := createTempFn
origRemove := removeFn
origOpen := openFileFn
origOpenMode := openFileModeFn
origRename := renameFn
origSymlink := symlinkFn
origCopy := copyFn
origChmod := chmodFileFn
origClose := closeFileFn
return func() {
requiredHostAgentBinaries = origRequired
downloadAndInstallHostAgentBinariesFn = origDownloadFn
findMissingHostAgentBinariesFn = origFindMissing
downloadURLForVersion = origURL
httpClient = origClient
mkdirAllFn = origMkdirAll
createTempFn = origCreateTemp
removeFn = origRemove
openFileFn = origOpen
openFileModeFn = origOpenMode
renameFn = origRename
symlinkFn = origSymlink
copyFn = origCopy
chmodFileFn = origChmod
closeFileFn = origClose
}
}
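// saveHostAgentHooks snapshots every injectable hook and returns a closure
// that restores them. Each test registers that closure via t.Cleanup first,
// so hook mutations cannot leak into later tests even when a test fails
// early. Because the hooks are shared package globals, these tests are
// order-independent but should not call t.Parallel().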
func TestEnsureHostAgentBinaries_NoMissing(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmpDir := t.TempDir()
requiredHostAgentBinaries = []HostAgentBinary{
{Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
}
if err := os.WriteFile(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64"), []byte("bin"), 0o755); err != nil {
t.Fatalf("write file: %v", err)
}
origEnv := os.Getenv("PULSE_BIN_DIR")
t.Cleanup(func() {
if origEnv == "" {
os.Unsetenv("PULSE_BIN_DIR")
} else {
os.Setenv("PULSE_BIN_DIR", origEnv)
}
})
os.Setenv("PULSE_BIN_DIR", tmpDir)
if missing := EnsureHostAgentBinaries("v1.0.0"); missing != nil {
t.Fatalf("expected no missing binaries, got %v", missing)
}
}
func TestEnsureHostAgentBinaries_DownloadError(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmpDir := t.TempDir()
requiredHostAgentBinaries = []HostAgentBinary{
{Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
}
downloadAndInstallHostAgentBinariesFn = func(string, string) error {
return errors.New("download failed")
}
origEnv := os.Getenv("PULSE_BIN_DIR")
t.Cleanup(func() {
if origEnv == "" {
os.Unsetenv("PULSE_BIN_DIR")
} else {
os.Setenv("PULSE_BIN_DIR", origEnv)
}
})
os.Setenv("PULSE_BIN_DIR", tmpDir)
missing := EnsureHostAgentBinaries("v1.0.0")
if len(missing) != 1 {
t.Fatalf("expected missing map, got %v", missing)
}
}
func TestEnsureHostAgentBinaries_StillMissing(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmpDir := t.TempDir()
requiredHostAgentBinaries = []HostAgentBinary{
{Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
}
downloadAndInstallHostAgentBinariesFn = func(string, string) error { return nil }
origEnv := os.Getenv("PULSE_BIN_DIR")
t.Cleanup(func() {
if origEnv == "" {
os.Unsetenv("PULSE_BIN_DIR")
} else {
os.Setenv("PULSE_BIN_DIR", origEnv)
}
})
os.Setenv("PULSE_BIN_DIR", tmpDir)
missing := EnsureHostAgentBinaries("v1.0.0")
if len(missing) != 1 {
t.Fatalf("expected still missing")
}
}
func TestEnsureHostAgentBinaries_RecheckAfterLock(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
requiredHostAgentBinaries = []HostAgentBinary{
{Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
}
calls := 0
findMissingHostAgentBinariesFn = func([]string) map[string]HostAgentBinary {
calls++
if calls == 1 {
return map[string]HostAgentBinary{
"linux-amd64": requiredHostAgentBinaries[0],
}
}
return nil
}
origEnv := os.Getenv("PULSE_BIN_DIR")
t.Cleanup(func() {
if origEnv == "" {
os.Unsetenv("PULSE_BIN_DIR")
} else {
os.Setenv("PULSE_BIN_DIR", origEnv)
}
})
os.Setenv("PULSE_BIN_DIR", t.TempDir())
if result := EnsureHostAgentBinaries("v1.0.0"); result != nil {
t.Fatalf("expected nil after recheck, got %v", result)
}
}
func TestEnsureHostAgentBinaries_RestoreSuccess(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmpDir := t.TempDir()
requiredHostAgentBinaries = []HostAgentBinary{
{Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
}
downloadAndInstallHostAgentBinariesFn = func(string, string) error {
return os.WriteFile(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64"), []byte("bin"), 0o755)
}
origEnv := os.Getenv("PULSE_BIN_DIR")
t.Cleanup(func() {
if origEnv == "" {
os.Unsetenv("PULSE_BIN_DIR")
} else {
os.Setenv("PULSE_BIN_DIR", origEnv)
}
})
os.Setenv("PULSE_BIN_DIR", tmpDir)
if missing := EnsureHostAgentBinaries("v1.0.0"); missing != nil {
t.Fatalf("expected restore success")
}
}
func TestDownloadAndInstallHostAgentBinariesErrors(t *testing.T) {
t.Run("MkdirAllError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected mkdir error")
}
})
t.Run("CreateTempError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
createTempFn = func(string, string) (*os.File, error) { return nil, errors.New("temp fail") }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected temp error")
}
})
t.Run("DownloadError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("network")
}),
}
downloadURLForVersion = func(string) string { return "http://example/bundle.tar.gz" }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected download error")
}
})
t.Run("StatusError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad"))
}))
t.Cleanup(server.Close)
httpClient = server.Client()
downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected status error")
}
})
t.Run("CopyError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("data"))
}))
t.Cleanup(server.Close)
httpClient = server.Client()
downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected copy error")
}
})
t.Run("CopyReadError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
body := io.NopCloser(&errorReader{})
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: body,
}, nil
}),
}
downloadURLForVersion = func(string) string { return "http://example/bundle.tar.gz" }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected copy read error")
}
})
t.Run("CloseError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("data"))
}))
t.Cleanup(server.Close)
httpClient = server.Client()
downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
closeFileFn = func(*os.File) error { return errors.New("close fail") }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected close error")
}
})
t.Run("ExtractError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("not a tarball"))
}))
t.Cleanup(server.Close)
httpClient = server.Client()
downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected extract error")
}
})
}
func TestDownloadAndInstallHostAgentBinariesSuccess(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmpDir := t.TempDir()
payload := buildTarGz(t, []tarEntry{
{name: "README.md", body: []byte("skip"), mode: 0o644},
{name: "bin/", typeflag: tar.TypeDir, mode: 0o755},
{name: "bin/not-agent.txt", body: []byte("skip"), mode: 0o644},
{name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644},
{name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"},
})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(payload)
}))
t.Cleanup(server.Close)
httpClient = server.Client()
downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
symlinkFn = func(string, string) error { return errors.New("no symlink") }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", tmpDir); err != nil {
t.Fatalf("expected success, got %v", err)
}
if _, err := os.Stat(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64")); err != nil {
t.Fatalf("expected binary installed: %v", err)
}
if _, err := os.Stat(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64.exe")); err != nil {
t.Fatalf("expected symlink fallback copy: %v", err)
}
}
func TestExtractHostAgentBinariesErrors(t *testing.T) {
t.Run("OpenError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
openFileFn = func(string) (*os.File, error) { return nil, errors.New("open fail") }
if err := extractHostAgentBinaries("missing.tar.gz", t.TempDir()); err == nil {
t.Fatalf("expected open error")
}
})
t.Run("GzipError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmp := filepath.Join(t.TempDir(), "bad.gz")
if err := os.WriteFile(tmp, []byte("bad"), 0o644); err != nil {
t.Fatalf("write: %v", err)
}
if err := extractHostAgentBinaries(tmp, t.TempDir()); err == nil {
t.Fatalf("expected gzip error")
}
})
t.Run("TarReadError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
var buf bytes.Buffer
gzw := gzip.NewWriter(&buf)
_, _ = gzw.Write([]byte("not a tar"))
_ = gzw.Close()
tmp := filepath.Join(t.TempDir(), "bad.tar.gz")
if err := os.WriteFile(tmp, buf.Bytes(), 0o644); err != nil {
t.Fatalf("write: %v", err)
}
if err := extractHostAgentBinaries(tmp, t.TempDir()); err == nil {
t.Fatalf("expected tar read error")
}
})
}
func TestExtractHostAgentBinariesRemoveError(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmpDir := t.TempDir()
payload := buildTarGz(t, []tarEntry{
{name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644},
{name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"},
})
archive := filepath.Join(t.TempDir(), "bundle.tar.gz")
if err := os.WriteFile(archive, payload, 0o644); err != nil {
t.Fatalf("write: %v", err)
}
removeFn = func(string) error { return errors.New("remove fail") }
if err := extractHostAgentBinaries(archive, tmpDir); err == nil {
t.Fatalf("expected remove error")
}
}
func TestExtractHostAgentBinariesSymlinkCopyError(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmpDir := t.TempDir()
payload := buildTarGz(t, []tarEntry{
{name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644},
{name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"},
})
archive := filepath.Join(t.TempDir(), "bundle.tar.gz")
if err := os.WriteFile(archive, payload, 0o644); err != nil {
t.Fatalf("write: %v", err)
}
symlinkFn = func(string, string) error { return errors.New("no symlink") }
openFileFn = func(path string) (*os.File, error) {
if path == archive {
return os.Open(path)
}
return nil, errors.New("open fail")
}
if err := extractHostAgentBinaries(archive, tmpDir); err == nil {
t.Fatalf("expected symlink fallback error")
}
}
func TestWriteHostAgentFileErrors(t *testing.T) {
t.Run("MkdirAllError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") }
if err := writeHostAgentFile("dest", strings.NewReader("data"), 0o644); err == nil {
t.Fatalf("expected mkdir error")
}
})
t.Run("CreateTempError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
createTempFn = func(string, string) (*os.File, error) { return nil, errors.New("temp fail") }
if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
t.Fatalf("expected temp error")
}
})
t.Run("CopyError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") }
if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
t.Fatalf("expected copy error")
}
})
t.Run("ChmodError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
chmodFileFn = func(*os.File, os.FileMode) error { return errors.New("chmod fail") }
if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
t.Fatalf("expected chmod error")
}
})
t.Run("CloseError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
closeFileFn = func(*os.File) error { return errors.New("close fail") }
if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
t.Fatalf("expected close error")
}
})
t.Run("RenameError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
renameFn = func(string, string) error { return errors.New("rename fail") }
if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
t.Fatalf("expected rename error")
}
})
}
func TestWriteHostAgentFileSuccess(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
dest := filepath.Join(t.TempDir(), "pulse-host-agent-linux-amd64")
if err := writeHostAgentFile(dest, strings.NewReader("data"), 0o644); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err := os.Stat(dest); err != nil {
t.Fatalf("expected file written: %v", err)
}
}
func TestCopyHostAgentFileErrors(t *testing.T) {
t.Run("OpenError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
openFileFn = func(string) (*os.File, error) { return nil, errors.New("open fail") }
if err := copyHostAgentFile("missing", filepath.Join(t.TempDir(), "dest")); err == nil {
t.Fatalf("expected open error")
}
})
t.Run("MkdirAllError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
src := filepath.Join(t.TempDir(), "src")
if err := os.WriteFile(src, []byte("data"), 0o644); err != nil {
t.Fatalf("write: %v", err)
}
mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") }
if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil {
t.Fatalf("expected mkdir error")
}
})
t.Run("OpenFileError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
src := filepath.Join(t.TempDir(), "src")
if err := os.WriteFile(src, []byte("data"), 0o644); err != nil {
t.Fatalf("write: %v", err)
}
openFileModeFn = func(string, int, os.FileMode) (*os.File, error) {
return nil, errors.New("create fail")
}
if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil {
t.Fatalf("expected open file error")
}
})
t.Run("CopyError", func(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
src := filepath.Join(t.TempDir(), "src")
if err := os.WriteFile(src, []byte("data"), 0o644); err != nil {
t.Fatalf("write: %v", err)
}
copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") }
if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil {
t.Fatalf("expected copy error")
}
})
}
func TestCopyHostAgentFileSuccess(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
src := filepath.Join(t.TempDir(), "src")
if err := os.WriteFile(src, []byte("data"), 0o644); err != nil {
t.Fatalf("write: %v", err)
}
dest := filepath.Join(t.TempDir(), "dest")
if err := copyHostAgentFile(src, dest); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err := os.Stat(dest); err != nil {
t.Fatalf("expected dest file: %v", err)
}
}
func TestExtractHostAgentBinariesWriteError(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
tmpDir := t.TempDir()
payload := buildTarGz(t, []tarEntry{
{name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644},
})
archive := filepath.Join(t.TempDir(), "bundle.tar.gz")
if err := os.WriteFile(archive, payload, 0o644); err != nil {
t.Fatalf("write: %v", err)
}
mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") }
if err := extractHostAgentBinaries(archive, tmpDir); err == nil {
t.Fatalf("expected write error")
}
}
func TestDownloadAndInstallHostAgentBinaries_Context(t *testing.T) {
restore := saveHostAgentHooks()
t.Cleanup(restore)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Context().Err() != nil {
t.Fatalf("unexpected context error")
}
_, _ = w.Write([]byte("data"))
}))
t.Cleanup(server.Close)
httpClient = server.Client()
downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") }
if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
t.Fatalf("expected copy error")
}
}
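
The success-path tests above assert that installed files exist but not that they are executable, even though the production code routes modes through normalizeExecutableMode. A stricter assertion one could bolt on (a sketch, assuming Unix permission semantics):

info, err := os.Stat(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64"))
if err != nil {
    t.Fatalf("stat installed binary: %v", err)
}
if info.Mode().Perm()&0o111 == 0 {
    t.Fatalf("expected at least one executable bit, got %v", info.Mode().Perm())
}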

View File

@@ -20,19 +20,26 @@ var upgrader = websocket.Upgrader{
},
}
var (
jsonMarshal = json.Marshal
pingInterval = 5 * time.Second
pingWriteWait = 5 * time.Second
readFileTimeout = 30 * time.Second
)
// Server manages WebSocket connections from agents
type Server struct {
mu sync.RWMutex
agents map[string]*agentConn // agentID -> connection
agents map[string]*agentConn // agentID -> connection
pendingReqs map[string]chan CommandResultPayload // requestID -> response channel
validateToken func(token string) bool
}
type agentConn struct {
conn *websocket.Conn
agent ConnectedAgent
writeMu sync.Mutex
done chan struct{}
conn *websocket.Conn
agent ConnectedAgent
writeMu sync.Mutex
done chan struct{}
}
// NewServer creates a new agent execution server
@@ -94,7 +101,7 @@ func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
}
// Parse registration payload
payloadBytes, err := json.Marshal(msg.Payload)
payloadBytes, err := jsonMarshal(msg.Payload)
if err != nil {
log.Error().Err(err).Msg("Failed to marshal registration payload")
conn.Close()
@@ -258,7 +265,7 @@ func (s *Server) readLoop(ac *agentConn) {
}
func (s *Server) pingLoop(ac *agentConn, done chan struct{}) {
ticker := time.NewTicker(5 * time.Second)
ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
// Track consecutive ping failures to detect dead connections faster
@@ -273,7 +280,7 @@ func (s *Server) pingLoop(ac *agentConn, done chan struct{}) {
return
case <-ticker.C:
ac.writeMu.Lock()
err := ac.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
err := ac.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(pingWriteWait))
ac.writeMu.Unlock()
if err != nil {
consecutiveFailures++
@@ -283,14 +290,14 @@ func (s *Server) pingLoop(ac *agentConn, done chan struct{}) {
Str("hostname", ac.agent.Hostname).
Int("consecutive_failures", consecutiveFailures).
Msg("Failed to send ping to agent")
if consecutiveFailures >= maxConsecutiveFailures {
log.Error().
Str("agent_id", ac.agent.AgentID).
Str("hostname", ac.agent.Hostname).
Int("failures", consecutiveFailures).
Msg("Agent connection appears dead after multiple ping failures, closing connection")
// Close the connection - this will cause readLoop to exit and clean up
ac.conn.Close()
return
@@ -404,7 +411,7 @@ func (s *Server) ReadFile(ctx context.Context, agentID string, req ReadFilePaylo
}
// Wait for result
timeout := 30 * time.Second
timeout := readFileTimeout
select {
case result := <-respCh:
return &result, nil
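
Underlying ExecuteCommand and ReadFile is a request/response broker: the caller registers a channel in pendingReqs keyed by RequestID, and readLoop delivers any matching command_result payload to it. The pattern in isolation, with hypothetical names, the usual sync/time/fmt imports, and a buffered channel so delivery never blocks the read loop:

type broker struct {
    mu      sync.Mutex
    pending map[string]chan string
}

func (b *broker) await(id string, timeout time.Duration) (string, error) {
    ch := make(chan string, 1) // buffered: deliver() never blocks
    b.mu.Lock()
    b.pending[id] = ch
    b.mu.Unlock()
    defer func() {
        b.mu.Lock()
        delete(b.pending, id)
        b.mu.Unlock()
    }()
    select {
    case v := <-ch:
        return v, nil
    case <-time.After(timeout):
        return "", fmt.Errorf("request %s timed out", id)
    }
}

func (b *broker) deliver(id, v string) {
    b.mu.Lock()
    ch, ok := b.pending[id]
    b.mu.Unlock()
    if ok {
        ch <- v
    }
}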

View File

@@ -0,0 +1,537 @@
package agentexec
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
)
type noHijackResponseWriter struct {
header http.Header
}
func (w *noHijackResponseWriter) Header() http.Header {
return w.header
}
func (w *noHijackResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (w *noHijackResponseWriter) WriteHeader(int) {}
func newConnPair(t *testing.T) (*websocket.Conn, *websocket.Conn, func()) {
t.Helper()
serverConnCh := make(chan *websocket.Conn, 1)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade: %v", err)
return
}
serverConnCh <- conn
}))
clientConn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
ts.Close()
t.Fatalf("Dial: %v", err)
}
var serverConn *websocket.Conn
select {
case serverConn = <-serverConnCh:
case <-time.After(2 * time.Second):
clientConn.Close()
ts.Close()
t.Fatal("timed out waiting for server connection")
}
cleanup := func() {
clientConn.Close()
serverConn.Close()
ts.Close()
}
return serverConn, clientConn, cleanup
}
func TestHandleWebSocket_UpgradeFailureAndDeadlineErrors(t *testing.T) {
s := NewServer(nil)
req := httptest.NewRequest(http.MethodGet, "http://example/ws", nil)
s.HandleWebSocket(&noHijackResponseWriter{header: make(http.Header)}, req)
}
func TestHandleWebSocket_RegistrationReadError(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
conn.Close()
}
func TestHandleWebSocket_RegistrationMessageJSONError(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
if err := conn.WriteMessage(websocket.TextMessage, []byte("{")); err != nil {
t.Fatalf("WriteMessage: %v", err)
}
conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
if _, _, err := conn.ReadMessage(); err == nil {
t.Fatalf("expected server to close on invalid JSON")
}
}
func TestHandleWebSocket_RegistrationPayloadMarshalError(t *testing.T) {
orig := jsonMarshal
t.Cleanup(func() { jsonMarshal = orig })
jsonMarshal = func(any) ([]byte, error) {
return nil, errors.New("boom")
}
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Token: "any",
},
})
conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
if _, _, err := conn.ReadMessage(); err == nil {
t.Fatalf("expected server to close on marshal error")
}
}
func TestHandleWebSocket_RegistrationPayloadUnmarshalError(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
if err := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"agent_register","payload":"oops"}`)); err != nil {
t.Fatalf("WriteMessage: %v", err)
}
conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
if _, _, err := conn.ReadMessage(); err == nil {
t.Fatalf("expected server to close on invalid payload")
}
}
func TestHandleWebSocket_PongHandler(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Token: "any",
},
})
_ = wsReadRegisteredPayload(t, conn)
if err := conn.WriteControl(websocket.PongMessage, []byte("pong"), time.Now().Add(time.Second)); err != nil {
t.Fatalf("WriteControl pong: %v", err)
}
conn.Close()
waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") })
}
func TestReadLoopDone(t *testing.T) {
s := NewServer(nil)
serverConn, clientConn, cleanup := newConnPair(t)
defer cleanup()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
close(ac.done)
s.mu.Lock()
s.agents["a1"] = ac
s.mu.Unlock()
s.readLoop(ac)
if s.IsAgentConnected("a1") {
t.Fatalf("expected agent to be removed")
}
clientConn.Close()
}
func TestReadLoopUnexpectedCloseError(t *testing.T) {
s := NewServer(nil)
serverConn, clientConn, cleanup := newConnPair(t)
defer cleanup()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
s.mu.Lock()
s.agents["a1"] = ac
s.mu.Unlock()
done := make(chan struct{})
go func() {
s.readLoop(ac)
close(done)
}()
_ = clientConn.WriteControl(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseProtocolError, "bye"),
time.Now().Add(time.Second),
)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatalf("readLoop did not exit")
}
}
func TestReadLoopCommandResultBranches(t *testing.T) {
s := NewServer(nil)
serverConn, clientConn, cleanup := newConnPair(t)
defer cleanup()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
s.mu.Lock()
s.agents["a1"] = ac
s.pendingReqs["req-full"] = make(chan CommandResultPayload)
s.mu.Unlock()
done := make(chan struct{})
go func() {
s.readLoop(ac)
close(done)
}()
_ = clientConn.WriteMessage(websocket.TextMessage, []byte("{"))
_ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":123}}`))
_ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":"req-full","success":true}}`))
_ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":"req-missing","success":true}}`))
clientConn.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatalf("readLoop did not exit")
}
s.mu.Lock()
delete(s.pendingReqs, "req-full")
s.mu.Unlock()
}
func TestPingLoopSuccessAndStop(t *testing.T) {
origInterval := pingInterval
t.Cleanup(func() { pingInterval = origInterval })
pingInterval = 5 * time.Millisecond
s := NewServer(nil)
serverConn, _, cleanup := newConnPair(t)
defer cleanup()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
stop := make(chan struct{})
exited := make(chan struct{})
go func() {
s.pingLoop(ac, stop)
close(exited)
}()
time.Sleep(2 * pingInterval)
close(stop)
select {
case <-exited:
case <-time.After(2 * time.Second):
t.Fatalf("pingLoop did not exit")
}
}
func TestPingLoopFailuresClose(t *testing.T) {
origInterval := pingInterval
t.Cleanup(func() { pingInterval = origInterval })
pingInterval = 5 * time.Millisecond
s := NewServer(nil)
serverConn, _, cleanup := newConnPair(t)
defer cleanup()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
serverConn.Close()
stop := make(chan struct{})
exited := make(chan struct{})
go func() {
s.pingLoop(ac, stop)
close(exited)
}()
select {
case <-exited:
case <-time.After(2 * time.Second):
t.Fatalf("pingLoop did not exit after failures")
}
}
func TestSendMessageMarshalError(t *testing.T) {
s := NewServer(nil)
if err := s.sendMessage(nil, Message{Payload: make(chan int)}); err == nil {
t.Fatalf("expected marshal error")
}
}
func TestExecuteCommandSendError(t *testing.T) {
s := NewServer(nil)
serverConn, _, cleanup := newConnPair(t)
defer cleanup()
serverConn.Close()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
s.mu.Lock()
s.agents["a1"] = ac
s.mu.Unlock()
_, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r1", Timeout: 1})
if err == nil {
t.Fatalf("expected send error")
}
}
func TestExecuteCommandTimeoutAndCancel(t *testing.T) {
s := NewServer(nil)
serverConn, _, cleanup := newConnPair(t)
defer cleanup()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
s.mu.Lock()
s.agents["a1"] = ac
s.mu.Unlock()
_, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-timeout", Timeout: 1})
if err == nil || !strings.Contains(err.Error(), "timed out") {
t.Fatalf("expected timeout error, got %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = s.ExecuteCommand(ctx, "a1", ExecuteCommandPayload{RequestID: "r-cancel", Timeout: 1})
if err == nil {
t.Fatalf("expected cancel error")
}
}
func TestExecuteCommandDefaultTimeout(t *testing.T) {
s := NewServer(nil)
serverConn, _, cleanup := newConnPair(t)
defer cleanup()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
s.mu.Lock()
s.agents["a1"] = ac
s.mu.Unlock()
go func() {
for {
s.mu.RLock()
ch := s.pendingReqs["r-default"]
s.mu.RUnlock()
if ch != nil {
ch <- CommandResultPayload{RequestID: "r-default", Success: true}
return
}
time.Sleep(2 * time.Millisecond)
}
}()
result, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-default"})
if err != nil || result == nil || !result.Success {
t.Fatalf("expected success, got result=%v err=%v", result, err)
}
}
func TestReadFileRoundTrip(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Token: "any",
},
})
_ = wsReadRegisteredPayload(t, conn)
agentDone := make(chan error, 1)
go func() {
for {
msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second)
if err != nil {
agentDone <- err
return
}
if msg.Type != MsgTypeReadFile || msg.Payload == nil {
continue
}
var payload ReadFilePayload
if err := json.Unmarshal(*msg.Payload, &payload); err != nil {
agentDone <- err
return
}
agentDone <- conn.WriteJSON(Message{
Type: MsgTypeCommandResult,
Timestamp: time.Now(),
Payload: CommandResultPayload{
RequestID: payload.RequestID,
Success: true,
Stdout: "data",
ExitCode: 0,
},
})
return
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, err := s.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-1", Path: "/etc/hosts"})
if err != nil || result == nil || result.Stdout != "data" {
t.Fatalf("unexpected read file result=%v err=%v", result, err)
}
if err := <-agentDone; err != nil {
t.Fatalf("agent error: %v", err)
}
}
func TestReadFileTimeoutCancelAndSendError(t *testing.T) {
origTimeout := readFileTimeout
t.Cleanup(func() { readFileTimeout = origTimeout })
readFileTimeout = 10 * time.Millisecond
s := NewServer(nil)
serverConn, _, cleanup := newConnPair(t)
defer cleanup()
ac := &agentConn{
conn: serverConn,
agent: ConnectedAgent{AgentID: "a1"},
done: make(chan struct{}),
}
s.mu.Lock()
s.agents["a1"] = ac
s.mu.Unlock()
if _, err := s.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-timeout"}); err == nil {
t.Fatalf("expected timeout error")
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := s.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-cancel"}); err == nil {
t.Fatalf("expected cancel error")
}
serverConn.Close()
if _, err := s.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-send"}); err == nil {
t.Fatalf("expected send error")
}
}
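
These tests lean on helpers (newWSServer, wsURLForHTTP, wsWriteMessage, wsReadRegisteredPayload, wsReadRawMessageWithTimeout, waitFor) defined in a sibling test file that is not part of this excerpt. Their plausible shapes, sketched here as assumptions only:

func newWSServer(t *testing.T, s *Server) *httptest.Server {
    t.Helper()
    return httptest.NewServer(http.HandlerFunc(s.HandleWebSocket))
}

// wsURLForHTTP rewrites http(s):// to ws(s):// for httptest URLs.
func wsURLForHTTP(httpURL string) string {
    return "ws" + strings.TrimPrefix(httpURL, "http")
}

func waitFor(t *testing.T, timeout time.Duration, cond func() bool) {
    t.Helper()
    deadline := time.Now().Add(timeout)
    for time.Now().Before(deadline) {
        if cond() {
            return
        }
        time.Sleep(5 * time.Millisecond)
    }
    t.Fatal("condition not met before timeout")
}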

View File

@@ -42,6 +42,9 @@ func TestConnectedAgentLookups(t *testing.T) {
if !ok || agentID != "a2" {
t.Fatalf("expected GetAgentForHost(host2) = (a2, true), got (%q, %v)", agentID, ok)
}
if _, ok := s.GetAgentForHost("missing"); ok {
t.Fatalf("expected missing host to return false")
}
agents := s.GetConnectedAgents()
if len(agents) != 2 {

File diff suppressed because it is too large

View File

@@ -8,13 +8,15 @@ import (
"syscall"
)
var execFn = syscall.Exec
// restartProcess replaces the current process with a new instance.
// On Unix-like systems, this uses syscall.Exec for an in-place restart.
func restartProcess(execPath string) error {
args := os.Args
env := os.Environ()
if err := syscall.Exec(execPath, args, env); err != nil {
if err := execFn(execPath, args, env); err != nil {
return fmt.Errorf("failed to restart: %w", err)
}
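
With syscall.Exec hoisted behind execFn, the restart path becomes testable without the test binary replacing itself. A sketch of the matching test (assumed to live in the same package; the path is illustrative):

func TestRestartProcessError(t *testing.T) {
    orig := execFn
    t.Cleanup(func() { execFn = orig })
    execFn = func(argv0 string, argv []string, envv []string) error {
        return errors.New("exec blocked under test")
    }
    if err := restartProcess("/usr/bin/pulse-agent"); err == nil {
        t.Fatal("expected wrapped exec error")
    }
}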

View File

@@ -32,6 +32,24 @@ const (
downloadTimeout = 5 * time.Minute
)
var (
maxBinarySizeBytes int64 = maxBinarySize
runtimeGOOS = runtime.GOOS
runtimeGOARCH = runtime.GOARCH
unameCommand = func() ([]byte, error) { return exec.Command("uname", "-m").Output() }
unraidVersionPath = "/etc/unraid-version"
unraidPersistentPathFn = unraidPersistentPath
restartProcessFn = restartProcess
osExecutableFn = os.Executable
evalSymlinksFn = filepath.EvalSymlinks
createTempFn = os.CreateTemp
chmodFn = os.Chmod
renameFn = os.Rename
closeFileFn = func(f *os.File) error { return f.Close() }
readFileFn = os.ReadFile
writeFileFn = os.WriteFile
)
// Config holds the configuration for the updater.
type Config struct {
// PulseURL is the base URL of the Pulse server (e.g., "https://pulse.example.com:7655")
@@ -66,6 +84,8 @@ type Updater struct {
logger zerolog.Logger
performUpdateFn func(context.Context) error
initialDelay time.Duration
newTicker func(time.Duration) *time.Ticker
}
// New creates a new Updater with the given configuration.
@@ -95,6 +115,8 @@ func New(cfg Config) *Updater {
logger: logger,
}
u.performUpdateFn = u.performUpdate
u.initialDelay = 5 * time.Second
u.newTicker = time.NewTicker
return u
}
@@ -114,11 +136,11 @@ func (u *Updater) RunLoop(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(5 * time.Second):
case <-time.After(u.initialDelay):
u.CheckAndUpdate(ctx)
}
ticker := time.NewTicker(u.cfg.CheckInterval)
ticker := u.newTicker(u.cfg.CheckInterval)
defer ticker.Stop()
for {
@@ -228,7 +250,7 @@ func (u *Updater) getServerVersion(ctx context.Context) (string, error) {
// isUnraid checks if we're running on Unraid by looking for /etc/unraid-version
func isUnraid() bool {
_, err := os.Stat("/etc/unraid-version")
_, err := os.Stat(unraidVersionPath)
return err == nil
}
@@ -245,7 +267,7 @@ func verifyBinaryMagic(path string) error {
return fmt.Errorf("failed to read magic bytes: %w", err)
}
switch runtime.GOOS {
switch runtimeGOOS {
case "linux":
// ELF magic: 0x7f 'E' 'L' 'F'
if magic[0] == 0x7f && magic[1] == 'E' && magic[2] == 'L' && magic[3] == 'F' {
@@ -286,10 +308,14 @@ func unraidPersistentPath(agentName string) string {
// performUpdate downloads and installs the new agent binary.
func (u *Updater) performUpdate(ctx context.Context) error {
execPath, err := os.Executable()
execPath, err := osExecutableFn()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
return u.performUpdateWithExecPath(ctx, execPath)
}
func (u *Updater) performUpdateWithExecPath(ctx context.Context, execPath string) error {
// Build download URL
downloadBase := fmt.Sprintf("%s/download/%s", strings.TrimRight(u.cfg.PulseURL, "/"), u.cfg.AgentName)
@@ -303,7 +329,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
candidates = append(candidates, downloadBase)
var resp *http.Response
var lastErr error
lastErr := errors.New("failed to download binary")
for _, url := range candidates {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
@@ -335,9 +361,6 @@ func (u *Updater) performUpdate(ctx context.Context) error {
}
if resp == nil {
if lastErr == nil {
lastErr = errors.New("failed to download binary")
}
return lastErr
}
defer resp.Body.Close()
@@ -346,7 +369,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
checksumHeader := strings.TrimSpace(resp.Header.Get("X-Checksum-Sha256"))
// Resolve symlinks to get the real path for atomic rename
realExecPath, err := filepath.EvalSymlinks(execPath)
realExecPath, err := evalSymlinksFn(execPath)
if err != nil {
// Fall back to original path if symlink resolution fails
realExecPath = execPath
@@ -355,7 +378,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
// Create temporary file in the same directory as the target binary
// to ensure atomic rename works (os.Rename fails across filesystems)
targetDir := filepath.Dir(realExecPath)
tmpFile, err := os.CreateTemp(targetDir, u.cfg.AgentName+"-*.tmp")
tmpFile, err := createTempFn(targetDir, u.cfg.AgentName+"-*.tmp")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
@@ -364,17 +387,17 @@ func (u *Updater) performUpdate(ctx context.Context) error {
// Write downloaded binary with checksum calculation and size limit
hasher := sha256.New()
limitedReader := io.LimitReader(resp.Body, maxBinarySize+1) // +1 to detect overflow
limitedReader := io.LimitReader(resp.Body, maxBinarySizeBytes+1) // +1 to detect overflow
written, err := io.Copy(tmpFile, io.TeeReader(limitedReader, hasher))
if err != nil {
tmpFile.Close()
closeFileFn(tmpFile)
return fmt.Errorf("failed to write binary: %w", err)
}
if written > maxBinarySize {
tmpFile.Close()
return fmt.Errorf("downloaded binary exceeds maximum size (%d bytes)", maxBinarySize)
if written > maxBinarySizeBytes {
closeFileFn(tmpFile)
return fmt.Errorf("downloaded binary exceeds maximum size (%d bytes)", maxBinarySizeBytes)
}
if err := tmpFile.Close(); err != nil {
if err := closeFileFn(tmpFile); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}
@@ -397,19 +420,19 @@ func (u *Updater) performUpdate(ctx context.Context) error {
u.logger.Debug().Str("checksum", downloadChecksum).Msg("Checksum verified")
// Make executable
if err := os.Chmod(tmpPath, 0755); err != nil {
if err := chmodFn(tmpPath, 0755); err != nil {
return fmt.Errorf("failed to chmod: %w", err)
}
// Atomic replacement with backup (use realExecPath for rename operations)
backupPath := realExecPath + ".backup"
if err := os.Rename(realExecPath, backupPath); err != nil {
if err := renameFn(realExecPath, backupPath); err != nil {
return fmt.Errorf("failed to backup current binary: %w", err)
}
if err := os.Rename(tmpPath, realExecPath); err != nil {
if err := renameFn(tmpPath, realExecPath); err != nil {
// Restore backup on failure
os.Rename(backupPath, realExecPath)
renameFn(backupPath, realExecPath)
return fmt.Errorf("failed to replace binary: %w", err)
}
@@ -418,26 +441,26 @@ func (u *Updater) performUpdate(ctx context.Context) error {
// Write previous version to a file so the agent can report "updated from X" on next start
updateInfoPath := filepath.Join(targetDir, ".pulse-update-info")
_ = os.WriteFile(updateInfoPath, []byte(u.cfg.CurrentVersion), 0644)
_ = writeFileFn(updateInfoPath, []byte(u.cfg.CurrentVersion), 0644)
// On Unraid, also update the persistent copy on the flash drive
// This ensures the update survives reboots
if isUnraid() {
persistPath := unraidPersistentPath(u.cfg.AgentName)
persistPath := unraidPersistentPathFn(u.cfg.AgentName)
if _, err := os.Stat(persistPath); err == nil {
// Persistent path exists, update it
u.logger.Debug().Str("path", persistPath).Msg("Updating Unraid persistent binary")
// Read the newly installed binary
newBinary, err := os.ReadFile(execPath)
newBinary, err := readFileFn(execPath)
if err != nil {
u.logger.Warn().Err(err).Msg("Failed to read new binary for Unraid persistence")
} else {
// Write to persistent storage (atomic via temp file)
tmpPersist := persistPath + ".tmp"
if err := os.WriteFile(tmpPersist, newBinary, 0644); err != nil {
if err := writeFileFn(tmpPersist, newBinary, 0644); err != nil {
u.logger.Warn().Err(err).Msg("Failed to write Unraid persistent binary")
} else if err := os.Rename(tmpPersist, persistPath); err != nil {
} else if err := renameFn(tmpPersist, persistPath); err != nil {
u.logger.Warn().Err(err).Msg("Failed to rename Unraid persistent binary")
os.Remove(tmpPersist)
} else {
@@ -448,19 +471,19 @@ func (u *Updater) performUpdate(ctx context.Context) error {
}
// Restart the process using platform-specific implementation
return restartProcess(execPath)
return restartProcessFn(execPath)
}
// GetUpdatedFromVersion checks if the agent was recently updated and returns the previous version.
// Returns empty string if no update info exists. Clears the info file after reading.
func GetUpdatedFromVersion() string {
execPath, err := os.Executable()
execPath, err := osExecutableFn()
if err != nil {
return ""
}
// Resolve symlinks to get the real path
realExecPath, err := filepath.EvalSymlinks(execPath)
realExecPath, err := evalSymlinksFn(execPath)
if err != nil {
realExecPath = execPath
}
@@ -479,8 +502,8 @@ func GetUpdatedFromVersion() string {
// determineArch returns the architecture string for download URLs (e.g., "linux-amd64", "darwin-arm64").
func determineArch() string {
os := runtime.GOOS
arch := runtime.GOARCH
os := runtimeGOOS
arch := runtimeGOARCH
// Normalize architecture
switch arch {
@@ -497,7 +520,7 @@ func determineArch() string {
}
// Fall back to uname for edge cases on unknown OS
out, err := exec.Command("uname", "-m").Output()
out, err := unameCommand()
if err != nil {
return ""
}
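
The download path combines io.LimitReader with io.TeeReader into a SHA-256 hasher, so a single pass both enforces the size cap and computes the checksum compared against the X-Checksum-Sha256 header. The pattern in isolation (a sketch; the function name and the skip-when-empty checksum behaviour are assumptions, and it needs crypto/sha256, encoding/hex, fmt, io, and strings):

func downloadVerified(dst io.Writer, body io.Reader, maxSize int64, wantSum string) error {
    hasher := sha256.New()
    limited := io.LimitReader(body, maxSize+1) // +1 makes overflow detectable
    n, err := io.Copy(dst, io.TeeReader(limited, hasher))
    if err != nil {
        return fmt.Errorf("failed to write binary: %w", err)
    }
    if n > maxSize {
        return fmt.Errorf("binary exceeds maximum size (%d bytes)", maxSize)
    }
    got := hex.EncodeToString(hasher.Sum(nil))
    if wantSum != "" && !strings.EqualFold(got, wantSum) {
        return fmt.Errorf("checksum mismatch: got %s want %s", got, wantSum)
    }
    return nil
}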

View File

@@ -40,68 +40,68 @@ const (
// HealthFactor represents a single component affecting health
type HealthFactor struct {
Name string `json:"name"`
Impact float64 `json:"impact"` // -1 to 1, negative is bad
Impact float64 `json:"impact"` // -1 to 1, negative is bad
Description string `json:"description"`
Category string `json:"category"` // finding, prediction, baseline, incident
}
// HealthScore represents the overall health of a resource or system
type HealthScore struct {
Score float64 `json:"score"` // 0-100
Grade HealthGrade `json:"grade"` // A, B, C, D, F
Trend HealthTrend `json:"trend"` // improving, stable, declining
Score float64 `json:"score"` // 0-100
Grade HealthGrade `json:"grade"` // A, B, C, D, F
Trend HealthTrend `json:"trend"` // improving, stable, declining
Factors []HealthFactor `json:"factors"` // What's affecting the score
Prediction string `json:"prediction"` // Human-readable outlook
Prediction string `json:"prediction"` // Human-readable outlook
}
// ResourceIntelligence aggregates all AI knowledge about a single resource
type ResourceIntelligence struct {
ResourceID string `json:"resource_id"`
ResourceName string `json:"resource_name,omitempty"`
ResourceType string `json:"resource_type,omitempty"`
Health HealthScore `json:"health"`
ActiveFindings []*Finding `json:"active_findings,omitempty"`
Predictions []patterns.FailurePrediction `json:"predictions,omitempty"`
Dependencies []string `json:"dependencies,omitempty"` // Resources this depends on
Dependents []string `json:"dependents,omitempty"` // Resources that depend on this
Correlations []*correlation.Correlation `json:"correlations,omitempty"`
ResourceID string `json:"resource_id"`
ResourceName string `json:"resource_name,omitempty"`
ResourceType string `json:"resource_type,omitempty"`
Health HealthScore `json:"health"`
ActiveFindings []*Finding `json:"active_findings,omitempty"`
Predictions []patterns.FailurePrediction `json:"predictions,omitempty"`
Dependencies []string `json:"dependencies,omitempty"` // Resources this depends on
Dependents []string `json:"dependents,omitempty"` // Resources that depend on this
Correlations []*correlation.Correlation `json:"correlations,omitempty"`
Baselines map[string]*baseline.FlatBaseline `json:"baselines,omitempty"`
Anomalies []AnomalyReport `json:"anomalies,omitempty"`
RecentIncidents []*memory.Incident `json:"recent_incidents,omitempty"`
Knowledge *knowledge.GuestKnowledge `json:"knowledge,omitempty"`
NoteCount int `json:"note_count"`
Anomalies []AnomalyReport `json:"anomalies,omitempty"`
RecentIncidents []*memory.Incident `json:"recent_incidents,omitempty"`
Knowledge *knowledge.GuestKnowledge `json:"knowledge,omitempty"`
NoteCount int `json:"note_count"`
}
// AnomalyReport describes a metric that's deviating from baseline
type AnomalyReport struct {
Metric string `json:"metric"`
CurrentValue float64 `json:"current_value"`
BaselineMean float64 `json:"baseline_mean"`
ZScore float64 `json:"z_score"`
Metric string `json:"metric"`
CurrentValue float64 `json:"current_value"`
BaselineMean float64 `json:"baseline_mean"`
ZScore float64 `json:"z_score"`
Severity baseline.AnomalySeverity `json:"severity"`
Description string `json:"description"`
Description string `json:"description"`
}
// IntelligenceSummary provides a system-wide intelligence overview
type IntelligenceSummary struct {
Timestamp time.Time `json:"timestamp"`
OverallHealth HealthScore `json:"overall_health"`
Timestamp time.Time `json:"timestamp"`
OverallHealth HealthScore `json:"overall_health"`
// Findings summary
FindingsCount FindingsCounts `json:"findings_count"`
TopFindings []*Finding `json:"top_findings,omitempty"` // Most critical
FindingsCount FindingsCounts `json:"findings_count"`
TopFindings []*Finding `json:"top_findings,omitempty"` // Most critical
// Predictions
PredictionsCount int `json:"predictions_count"`
UpcomingRisks []patterns.FailurePrediction `json:"upcoming_risks,omitempty"`
PredictionsCount int `json:"predictions_count"`
UpcomingRisks []patterns.FailurePrediction `json:"upcoming_risks,omitempty"`
// Recent activity
RecentChangesCount int `json:"recent_changes_count"`
RecentRemediations []memory.RemediationRecord `json:"recent_remediations,omitempty"`
RecentChangesCount int `json:"recent_changes_count"`
RecentRemediations []memory.RemediationRecord `json:"recent_remediations,omitempty"`
// Learning progress
Learning LearningStats `json:"learning"`
// Resources needing attention
ResourcesAtRisk []ResourceRiskSummary `json:"resources_at_risk,omitempty"`
}
@@ -137,7 +137,7 @@ type ResourceRiskSummary struct {
// Intelligence orchestrates all AI subsystems into a unified system
type Intelligence struct {
mu sync.RWMutex
// Core subsystems
findings *FindingsStore
patterns *patterns.Detector
@@ -147,10 +147,13 @@ type Intelligence struct {
knowledge *knowledge.Store
changes *memory.ChangeDetector
remediations *memory.RemediationLog
// State access
stateProvider StateProvider
// Optional hook for anomaly detection (used by patrol integration/tests)
anomalyDetector func(resourceID string) []AnomalyReport
// Configuration
dataDir string
}
@@ -180,7 +183,7 @@ func (i *Intelligence) SetSubsystems(
) {
i.mu.Lock()
defer i.mu.Unlock()
i.findings = findings
i.patterns = patternsDetector
i.correlations = correlationsDetector
@@ -202,45 +205,45 @@ func (i *Intelligence) SetStateProvider(sp StateProvider) {
func (i *Intelligence) GetSummary() *IntelligenceSummary {
i.mu.RLock()
defer i.mu.RUnlock()
summary := &IntelligenceSummary{
Timestamp: time.Now(),
}
// Aggregate findings
if i.findings != nil {
all := i.findings.GetActive(FindingSeverityInfo) // Get all active findings
summary.FindingsCount = i.countFindings(all)
summary.TopFindings = i.getTopFindings(all, 5)
}
// Aggregate predictions
if i.patterns != nil {
predictions := i.patterns.GetPredictions()
summary.PredictionsCount = len(predictions)
summary.UpcomingRisks = i.getUpcomingRisks(predictions, 5)
}
// Aggregate recent activity
if i.changes != nil {
recent := i.changes.GetRecentChanges(100, time.Now().Add(-24*time.Hour))
summary.RecentChangesCount = len(recent)
}
if i.remediations != nil {
recent := i.remediations.GetRecentRemediations(5, time.Now().Add(-24*time.Hour))
summary.RecentRemediations = recent
}
// Learning stats
summary.Learning = i.getLearningStats()
// Calculate overall health
summary.OverallHealth = i.calculateOverallHealth(summary)
// Resources at risk
summary.ResourcesAtRisk = i.getResourcesAtRisk(5)
return summary
}
@@ -248,11 +251,11 @@ func (i *Intelligence) GetSummary() *IntelligenceSummary {
func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntelligence {
i.mu.RLock()
defer i.mu.RUnlock()
intel := &ResourceIntelligence{
ResourceID: resourceID,
}
// Active findings
if i.findings != nil {
intel.ActiveFindings = i.findings.GetByResource(resourceID)
@@ -261,19 +264,19 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
intel.ResourceType = intel.ActiveFindings[0].ResourceType
}
}
// Predictions
if i.patterns != nil {
intel.Predictions = i.patterns.GetPredictionsForResource(resourceID)
}
// Correlations and dependencies
if i.correlations != nil {
intel.Correlations = i.correlations.GetCorrelationsForResource(resourceID)
intel.Dependencies = i.correlations.GetDependsOn(resourceID)
intel.Dependents = i.correlations.GetDependencies(resourceID)
}
// Baselines
if i.baselines != nil {
if rb, ok := i.baselines.GetResourceBaseline(resourceID); ok {
@@ -290,12 +293,12 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
}
}
}
// Recent incidents
if i.incidents != nil {
intel.RecentIncidents = i.incidents.ListIncidentsByResource(resourceID, 5)
}
// Knowledge
if i.knowledge != nil {
if k, err := i.knowledge.GetKnowledge(resourceID); err == nil && k != nil {
@@ -309,10 +312,10 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
}
}
}
// Calculate health score
intel.Health = i.calculateResourceHealth(intel)
return intel
}
@@ -320,49 +323,49 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
func (i *Intelligence) FormatContext(resourceID string) string {
i.mu.RLock()
defer i.mu.RUnlock()
var sections []string
// Knowledge (most important - what we've learned)
if i.knowledge != nil {
if ctx := i.knowledge.FormatForContext(resourceID); ctx != "" {
sections = append(sections, ctx)
}
}
// Baselines (what's normal for this resource)
if i.baselines != nil {
if ctx := i.formatBaselinesForContext(resourceID); ctx != "" {
sections = append(sections, ctx)
}
}
// Current anomalies
if anomalies := i.detectCurrentAnomalies(resourceID); len(anomalies) > 0 {
sections = append(sections, i.formatAnomaliesForContext(anomalies))
}
// Patterns/Predictions
if i.patterns != nil {
if ctx := i.patterns.FormatForContext(resourceID); ctx != "" {
sections = append(sections, ctx)
}
}
// Correlations
if i.correlations != nil {
if ctx := i.correlations.FormatForContext(resourceID); ctx != "" {
sections = append(sections, ctx)
}
}
// Incidents
if i.incidents != nil {
if ctx := i.incidents.FormatForResource(resourceID, 5); ctx != "" {
sections = append(sections, ctx)
}
}
return strings.Join(sections, "\n")
}
@@ -370,37 +373,37 @@ func (i *Intelligence) FormatContext(resourceID string) string {
func (i *Intelligence) FormatGlobalContext() string {
i.mu.RLock()
defer i.mu.RUnlock()
var sections []string
// All saved knowledge (limited)
if i.knowledge != nil {
if ctx := i.knowledge.FormatAllForContext(); ctx != "" {
sections = append(sections, ctx)
}
}
// Recent incidents across infrastructure
if i.incidents != nil {
if ctx := i.incidents.FormatForPatrol(8); ctx != "" {
sections = append(sections, ctx)
}
}
// Top correlations
if i.correlations != nil {
if ctx := i.correlations.FormatForContext(""); ctx != "" {
sections = append(sections, ctx)
}
}
// Top predictions
if i.patterns != nil {
if ctx := i.patterns.FormatForContext(""); ctx != "" {
sections = append(sections, ctx)
}
}
return strings.Join(sections, "\n")
}
@@ -408,11 +411,11 @@ func (i *Intelligence) FormatGlobalContext() string {
func (i *Intelligence) RecordLearning(resourceID, resourceName, resourceType, title, content string) error {
i.mu.RLock()
defer i.mu.RUnlock()
if i.knowledge == nil {
return nil
}
return i.knowledge.SaveNote(resourceID, resourceName, resourceType, "learning", title, content)
}
@@ -420,11 +423,11 @@ func (i *Intelligence) RecordLearning(resourceID, resourceName, resourceType, ti
func (i *Intelligence) CheckBaselinesForResource(resourceID string, metrics map[string]float64) []AnomalyReport {
i.mu.RLock()
defer i.mu.RUnlock()
if i.baselines == nil {
return nil
}
var anomalies []AnomalyReport
for metric, value := range metrics {
severity, zScore, bl := i.baselines.CheckAnomaly(resourceID, metric, value)
@@ -439,7 +442,7 @@ func (i *Intelligence) CheckBaselinesForResource(resourceID string, metrics map[
})
}
}
return anomalies
}
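// CheckAnomaly itself is not shown in this hunk. A minimal sketch of the
// usual z-score formulation it presumably follows (the real baseline.Store
// implementation may differ):
func zScoreSketch(value, mean, stdDev float64) float64 {
	if stdDev == 0 {
		return 0 // no spread learned yet; nothing to flag
	}
	return (value - mean) / stdDev
}

// A |z| beyond some threshold would then map to a baseline.AnomalySeverity
// and populate the AnomalyReport fields used above.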
@@ -452,7 +455,7 @@ func (i *Intelligence) CreatePredictionFinding(pred patterns.FailurePrediction)
if pred.Confidence > 0.8 && pred.DaysUntil < 1 {
severity = FindingSeverityCritical
}
return &Finding{
ID: fmt.Sprintf("pred-%s-%s", pred.ResourceID, pred.EventType),
Key: fmt.Sprintf("prediction:%s:%s", pred.ResourceID, pred.EventType),
@@ -493,7 +496,7 @@ func (i *Intelligence) getTopFindings(findings []*Finding, limit int) []*Finding
if len(findings) == 0 {
return nil
}
// Sort by severity (critical first) then by detection time (newest first)
sorted := make([]*Finding, len(findings))
copy(sorted, findings)
@@ -505,7 +508,7 @@ func (i *Intelligence) getTopFindings(findings []*Finding, limit int) []*Finding
}
return sorted[a].DetectedAt.After(sorted[b].DetectedAt)
})
if len(sorted) > limit {
sorted = sorted[:limit]
}
@@ -531,7 +534,7 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction
if len(predictions) == 0 {
return nil
}
// Filter to next 7 days and sort by days until
var upcoming []patterns.FailurePrediction
for _, p := range predictions {
@@ -539,11 +542,11 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction
upcoming = append(upcoming, p)
}
}
sort.Slice(upcoming, func(a, b int) bool {
return upcoming[a].DaysUntil < upcoming[b].DaysUntil
})
if len(upcoming) > limit {
upcoming = upcoming[:limit]
}
@@ -552,7 +555,7 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction
func (i *Intelligence) getLearningStats() LearningStats {
stats := LearningStats{}
if i.knowledge != nil {
guests, _ := i.knowledge.ListGuests()
for _, guestID := range guests {
@@ -562,27 +565,27 @@ func (i *Intelligence) getLearningStats() LearningStats {
}
}
}
if i.baselines != nil {
stats.ResourcesWithBaselines = i.baselines.ResourceCount()
}
if i.patterns != nil {
p := i.patterns.GetPatterns()
stats.PatternsDetected = len(p)
}
if i.correlations != nil {
c := i.correlations.GetCorrelations()
stats.CorrelationsLearned = len(c)
}
if i.incidents != nil {
// Count is not available, so we skip this stat for now
// Could be added to IncidentStore if needed
stats.IncidentsTracked = 0
}
return stats
}
@@ -593,7 +596,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
Trend: HealthTrendStable,
Factors: []HealthFactor{},
}
// Deduct for findings
if summary.FindingsCount.Critical > 0 {
impact := float64(summary.FindingsCount.Critical) * 20
@@ -608,7 +611,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
Category: "finding",
})
}
if summary.FindingsCount.Warning > 0 {
impact := float64(summary.FindingsCount.Warning) * 10
if impact > 20 {
@@ -622,7 +625,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
Category: "finding",
})
}
// Deduct for imminent predictions
for _, pred := range summary.UpcomingRisks {
if pred.DaysUntil < 3 && pred.Confidence > 0.7 {
@@ -636,7 +639,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
})
}
}
// Bonus for learning progress
if summary.Learning.ResourcesWithKnowledge > 5 {
bonus := 5.0
@@ -648,7 +651,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
Category: "learning",
})
}
// Clamp score
if health.Score < 0 {
health.Score = 0
@@ -656,13 +659,13 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
if health.Score > 100 {
health.Score = 100
}
// Assign grade
health.Grade = scoreToGrade(health.Score)
// Generate prediction text
health.Prediction = i.generateHealthPrediction(health, summary)
return health
}
@@ -673,7 +676,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
Trend: HealthTrendStable,
Factors: []HealthFactor{},
}
// Deduct for active findings
for _, f := range intel.ActiveFindings {
if f == nil {
@@ -698,7 +701,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
Category: "finding",
})
}
// Deduct for predictions
for _, p := range intel.Predictions {
if p.DaysUntil < 7 && p.Confidence > 0.5 {
@@ -712,7 +715,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
})
}
}
// Deduct for anomalies
for _, a := range intel.Anomalies {
var impact float64
@@ -734,7 +737,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
Category: "baseline",
})
}
// Bonus for having knowledge
if intel.NoteCount > 0 {
bonus := 2.0
@@ -746,7 +749,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
Category: "learning",
})
}
// Clamp
if health.Score < 0 {
health.Score = 0
@@ -754,9 +757,9 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
if health.Score > 100 {
health.Score = 100
}
health.Grade = scoreToGrade(health.Score)
return health
}
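// scoreToGrade is not visible in this diff. A plausible sketch with
// hypothetical thresholds, using only the grades the tests reference (the
// real mapping and grade set may differ):
func scoreToGradeSketch(score float64) HealthGrade {
	switch {
	case score >= 90:
		return HealthGradeA
	case score >= 75:
		return HealthGradeB
	default:
		return HealthGradeC
	}
}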
@@ -779,21 +782,21 @@ func (i *Intelligence) generateHealthPrediction(health HealthScore, summary *Int
if health.Grade == HealthGradeA {
return "Infrastructure is healthy with no significant issues detected."
}
if summary.FindingsCount.Critical > 0 {
return fmt.Sprintf("Immediate attention required: %d critical issues.", summary.FindingsCount.Critical)
}
if len(summary.UpcomingRisks) > 0 {
risk := summary.UpcomingRisks[0]
return fmt.Sprintf("Predicted %s event on resource within %.1f days (%.0f%% confidence).",
risk.EventType, risk.DaysUntil, risk.Confidence*100)
}
if summary.FindingsCount.Warning > 0 {
return fmt.Sprintf("%d warnings should be addressed soon to maintain stability.", summary.FindingsCount.Warning)
}
return "Infrastructure is stable with minor issues to monitor."
}
@@ -801,16 +804,13 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
if i.findings == nil {
return nil
}
// Group findings by resource
byResource := make(map[string][]*Finding)
for _, f := range i.findings.GetActive(FindingSeverityInfo) {
if f == nil {
continue
}
byResource[f.ResourceID] = append(byResource[f.ResourceID], f)
}
// Calculate risk for each resource
type resourceRisk struct {
id string
@@ -819,13 +819,9 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
score float64
top string
}
var risks []resourceRisk
for id, findings := range byResource {
if len(findings) == 0 {
continue
}
score := 0.0
var topFinding *Finding
for _, f := range findings {
@@ -843,7 +839,7 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
topFinding = f
}
}
if score > 0 && topFinding != nil {
risks = append(risks, resourceRisk{
id: id,
@@ -854,16 +850,16 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
})
}
}
// Sort by risk score descending
sort.Slice(risks, func(a, b int) bool {
return risks[a].score > risks[b].score
})
if len(risks) > limit {
risks = risks[:limit]
}
// Convert to summaries
var summaries []ResourceRiskSummary
for _, r := range risks {
@@ -879,11 +875,14 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
TopIssue: r.top,
})
}
return summaries
}
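// The per-severity weights in the scoring loop are elided by the hunk above.
// A hypothetical weighting consistent with the getResourcesAtRisk test later
// in this commit, where one critical finding outranks a warning, a watch, and
// an info finding combined (the FindingSeverity type name is assumed from the
// severity constants):
func severityWeightSketch(s FindingSeverity) float64 {
	switch s {
	case FindingSeverityCritical:
		return 10
	case FindingSeverityWarning:
		return 5
	case FindingSeverityWatch:
		return 2
	default: // info
		return 1
	}
}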
func (i *Intelligence) detectCurrentAnomalies(resourceID string) []AnomalyReport {
if i.anomalyDetector != nil {
return i.anomalyDetector(resourceID)
}
// This would be called with current metrics from state
// For now, return empty - will be integrated with patrol
return nil
@@ -893,21 +892,21 @@ func (i *Intelligence) formatBaselinesForContext(resourceID string) string {
if i.baselines == nil {
return ""
}
rb, ok := i.baselines.GetResourceBaseline(resourceID)
if !ok || len(rb.Metrics) == 0 {
return ""
}
var lines []string
lines = append(lines, "\n## Learned Baselines")
lines = append(lines, "Normal operating ranges for this resource:")
for metric, mb := range rb.Metrics {
lines = append(lines, fmt.Sprintf("- %s: mean %.1f, stddev %.1f (samples: %d)",
metric, mb.Mean, mb.StdDev, mb.SampleCount))
}
return strings.Join(lines, "\n")
}
@@ -915,15 +914,15 @@ func (i *Intelligence) formatAnomaliesForContext(anomalies []AnomalyReport) stri
if len(anomalies) == 0 {
return ""
}
var lines []string
lines = append(lines, "\n## Current Anomalies")
lines = append(lines, "Metrics deviating from normal:")
for _, a := range anomalies {
lines = append(lines, fmt.Sprintf("- %s: %s", a.Metric, a.Description))
}
return strings.Join(lines, "\n")
}
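// Grounded in the format strings above, the baseline and anomaly sections of
// FormatContext for a resource with one learned metric and one anomaly would
// read roughly (values illustrative):
//
//	## Learned Baselines
//	Normal operating ranges for this resource:
//	- cpu: mean 40.0, stddev 5.0 (samples: 120)
//
//	## Current Anomalies
//	Metrics deviating from normal:
//	- cpu: CPU high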

View File

@@ -0,0 +1,552 @@
package ai
import (
"strings"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/ai/baseline"
"github.com/rcourtman/pulse-go-rewrite/internal/ai/correlation"
"github.com/rcourtman/pulse-go-rewrite/internal/ai/knowledge"
"github.com/rcourtman/pulse-go-rewrite/internal/ai/memory"
"github.com/rcourtman/pulse-go-rewrite/internal/ai/patterns"
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
)
func TestIntelligence_formatBaselinesForContext(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
store := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
if err := store.Learn("res-1", "vm", "cpu", []baseline.MetricPoint{{Value: 12.5}}); err != nil {
t.Fatalf("Learn: %v", err)
}
intel.SetSubsystems(nil, nil, nil, store, nil, nil, nil, nil)
ctx := intel.formatBaselinesForContext("res-1")
if !strings.Contains(ctx, "Learned Baselines") {
t.Fatalf("expected baseline header, got %q", ctx)
}
if !strings.Contains(ctx, "cpu: mean") {
t.Fatalf("expected cpu baseline line, got %q", ctx)
}
if got := intel.formatBaselinesForContext("missing"); got != "" {
t.Errorf("expected empty context for missing baseline, got %q", got)
}
empty := NewIntelligence(IntelligenceConfig{})
if got := empty.formatBaselinesForContext("res-1"); got != "" {
t.Errorf("expected empty context with no baseline store, got %q", got)
}
}
func TestIntelligence_formatAnomaliesForContext(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
if got := intel.formatAnomaliesForContext(nil); got != "" {
t.Errorf("expected empty anomalies context, got %q", got)
}
anomalies := []AnomalyReport{{Metric: "cpu", Description: "CPU high"}}
ctx := intel.formatAnomaliesForContext(anomalies)
if !strings.Contains(ctx, "Current Anomalies") {
t.Fatalf("expected anomalies header, got %q", ctx)
}
if !strings.Contains(ctx, "CPU high") {
t.Fatalf("expected anomaly description, got %q", ctx)
}
}
func TestIntelligence_formatAnomalyDescription(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
bl := &baseline.MetricBaseline{Mean: 10}
above := intel.formatAnomalyDescription("cpu", 20, bl, 2.5)
if !strings.Contains(above, "above baseline") {
t.Errorf("expected above-baseline description, got %q", above)
}
below := intel.formatAnomalyDescription("cpu", 5, bl, -1.5)
if !strings.Contains(below, "below baseline") {
t.Errorf("expected below-baseline description, got %q", below)
}
}
func TestIntelligence_FormatContext_Anomalies(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
intel.anomalyDetector = func(resourceID string) []AnomalyReport {
return []AnomalyReport{{Metric: "cpu", Description: "CPU high"}}
}
ctx := intel.FormatContext("res-1")
if !strings.Contains(ctx, "Current Anomalies") {
t.Fatalf("expected anomalies context, got %q", ctx)
}
if !strings.Contains(ctx, "CPU high") {
t.Fatalf("expected anomaly description, got %q", ctx)
}
}
func TestIntelligence_getUpcomingRisks(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
predictions := []patterns.FailurePrediction{
{ResourceID: "r1", EventType: patterns.EventHighCPU, DaysUntil: 2, Confidence: 0.6},
{ResourceID: "r2", EventType: patterns.EventHighMemory, DaysUntil: 9, Confidence: 0.9},
{ResourceID: "r3", EventType: patterns.EventDiskFull, DaysUntil: 4, Confidence: 0.4},
{ResourceID: "r4", EventType: patterns.EventOOM, DaysUntil: 1, Confidence: 0.8},
{ResourceID: "r5", EventType: patterns.EventRestart, DaysUntil: 5, Confidence: 0.7},
}
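// Expectation: r2 (9 days out) falls outside the 7-day window that
// getUpcomingRisks filters to; the remainder are sorted ascending by
// DaysUntil and capped at the limit of 2, yielding r4 (1 day) then r1 (2 days).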
risk := intel.getUpcomingRisks(predictions, 2)
if len(risk) != 2 {
t.Fatalf("expected 2 risks, got %d", len(risk))
}
if risk[0].ResourceID != "r4" || risk[1].ResourceID != "r1" {
t.Errorf("unexpected ordering: %v", []string{risk[0].ResourceID, risk[1].ResourceID})
}
}
func TestIntelligence_countFindings(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
findings := []*Finding{
nil,
{Severity: FindingSeverityCritical},
{Severity: FindingSeverityWarning},
{Severity: FindingSeverityWatch},
{Severity: FindingSeverityInfo},
}
counts := intel.countFindings(findings)
if counts.Total != 4 {
t.Errorf("expected total 4, got %d", counts.Total)
}
if counts.Critical != 1 || counts.Warning != 1 || counts.Watch != 1 || counts.Info != 1 {
t.Errorf("unexpected counts: %+v", counts)
}
}
func TestIntelligence_getTopFindings(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
base := time.Now()
findings := []*Finding{
{Severity: FindingSeverityWarning, DetectedAt: base.Add(-2 * time.Hour)},
{Severity: FindingSeverityCritical, DetectedAt: base.Add(-3 * time.Hour), Title: "older critical"},
{Severity: FindingSeverityCritical, DetectedAt: base.Add(-1 * time.Hour), Title: "newer critical"},
{Severity: FindingSeverityWatch, DetectedAt: base.Add(-30 * time.Minute)},
}
top := intel.getTopFindings(findings, 3)
if len(top) != 3 {
t.Fatalf("expected 3 findings, got %d", len(top))
}
if top[0].Title != "newer critical" || top[1].Title != "older critical" {
t.Errorf("unexpected ordering: %q, %q", top[0].Title, top[1].Title)
}
if top[2].Severity != FindingSeverityWarning {
t.Errorf("expected warning in third position, got %s", top[2].Severity)
}
}
func TestIntelligence_getTopFindings_Empty(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
if got := intel.getTopFindings(nil, 5); got != nil {
t.Errorf("expected nil for empty findings, got %v", got)
}
}
func TestIntelligence_getLearningStats(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
knowledgeStore, err := knowledge.NewStore(t.TempDir())
if err != nil {
t.Fatalf("NewStore: %v", err)
}
knowledgeStore.SaveNote("vm-1", "vm-1", "vm", "general", "Note", "Content")
knowledgeStore.SaveNote("vm-2", "vm-2", "vm", "general", "Note", "Content")
baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
baselineStore.Learn("vm-1", "vm", "cpu", []baseline.MetricPoint{{Value: 10}})
patternDetector := patterns.NewDetector(patterns.DetectorConfig{
MinOccurrences: 2,
PatternWindow: 48 * time.Hour,
PredictionLimit: 30 * 24 * time.Hour,
})
patternStart := time.Now().Add(-90 * time.Minute)
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-1", EventType: patterns.EventHighCPU, Timestamp: patternStart})
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-1", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
correlationDetector := correlation.NewDetector(correlation.Config{
MinOccurrences: 1,
CorrelationWindow: 2 * time.Hour,
RetentionWindow: 24 * time.Hour,
})
corrStart := time.Now().Add(-30 * time.Minute)
for i := 0; i < 2; i++ {
base := corrStart.Add(time.Duration(i) * 10 * time.Minute)
correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-a", ResourceName: "node-a", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: base})
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-b", ResourceName: "vm-b", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: base.Add(1 * time.Minute)})
}
incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 5})
intel.SetSubsystems(nil, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil)
stats := intel.getLearningStats()
if stats.ResourcesWithKnowledge != 2 {
t.Errorf("expected 2 resources with knowledge, got %d", stats.ResourcesWithKnowledge)
}
if stats.TotalNotes != 2 {
t.Errorf("expected 2 total notes, got %d", stats.TotalNotes)
}
if stats.ResourcesWithBaselines != 1 {
t.Errorf("expected 1 resource with baseline, got %d", stats.ResourcesWithBaselines)
}
if stats.PatternsDetected != 1 {
t.Errorf("expected 1 pattern, got %d", stats.PatternsDetected)
}
if stats.CorrelationsLearned != 1 {
t.Errorf("expected 1 correlation, got %d", stats.CorrelationsLearned)
}
}
func TestIntelligence_getResourcesAtRisk(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
findings := NewFindingsStore()
findings.Add(&Finding{
ID: "crit-1",
Key: "crit-1",
Severity: FindingSeverityCritical,
Category: FindingCategoryReliability,
ResourceID: "res-critical",
ResourceName: "critical-vm",
ResourceType: "vm",
Title: "Critical outage",
DetectedAt: time.Now(),
LastSeenAt: time.Now(),
Source: "test",
})
findings.Add(&Finding{
ID: "warn-1",
Key: "warn-1",
Severity: FindingSeverityWarning,
Category: FindingCategoryPerformance,
ResourceID: "res-warning",
ResourceName: "warning-vm",
ResourceType: "vm",
Title: "Warning issue",
DetectedAt: time.Now(),
LastSeenAt: time.Now(),
Source: "test",
})
findings.Add(&Finding{
ID: "watch-1",
Key: "watch-1",
Severity: FindingSeverityWatch,
Category: FindingCategoryPerformance,
ResourceID: "res-warning",
ResourceName: "warning-vm",
ResourceType: "vm",
Title: "Watch issue",
DetectedAt: time.Now(),
LastSeenAt: time.Now(),
Source: "test",
})
findings.Add(&Finding{
ID: "info-1",
Key: "info-1",
Severity: FindingSeverityInfo,
Category: FindingCategoryPerformance,
ResourceID: "res-warning",
ResourceName: "warning-vm",
ResourceType: "vm",
Title: "Info issue",
DetectedAt: time.Now(),
LastSeenAt: time.Now(),
Source: "test",
})
intel.SetSubsystems(findings, nil, nil, nil, nil, nil, nil, nil)
risks := intel.getResourcesAtRisk(1)
if len(risks) != 1 {
t.Fatalf("expected 1 risk, got %d", len(risks))
}
if risks[0].ResourceID != "res-critical" {
t.Errorf("expected critical resource first, got %s", risks[0].ResourceID)
}
if risks[0].TopIssue != "Critical outage" {
t.Errorf("expected top issue to be critical, got %s", risks[0].TopIssue)
}
}
func TestIntelligence_calculateResourceHealth_ClampAndSeverities(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
resourceIntel := &ResourceIntelligence{
ResourceID: "test-vm",
ActiveFindings: []*Finding{
nil,
{Severity: FindingSeverityCritical, Title: "crit-1"},
{Severity: FindingSeverityCritical, Title: "crit-2"},
{Severity: FindingSeverityCritical, Title: "crit-3"},
{Severity: FindingSeverityCritical, Title: "crit-4"},
{Severity: FindingSeverityWatch, Title: "watch"},
{Severity: FindingSeverityInfo, Title: "info"},
},
Anomalies: []AnomalyReport{
{Metric: "cpu", Severity: baseline.AnomalyHigh, Description: "high"},
{Metric: "disk", Severity: baseline.AnomalyLow, Description: "low"},
},
}
health := intel.calculateResourceHealth(resourceIntel)
if health.Score != 0 {
t.Errorf("expected score clamped to 0, got %f", health.Score)
}
}
func TestIntelligence_calculateOverallHealth_Clamps(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
var predictions []patterns.FailurePrediction
for i := 0; i < 10; i++ {
predictions = append(predictions, patterns.FailurePrediction{EventType: patterns.EventHighCPU, DaysUntil: 1, Confidence: 0.9})
}
negative := intel.calculateOverallHealth(&IntelligenceSummary{
FindingsCount: FindingsCounts{Critical: 10, Warning: 10},
UpcomingRisks: predictions,
})
if negative.Score != 0 {
t.Errorf("expected score clamped to 0, got %f", negative.Score)
}
positive := intel.calculateOverallHealth(&IntelligenceSummary{
Learning: LearningStats{ResourcesWithKnowledge: 10},
})
if positive.Score != 100 {
t.Errorf("expected score clamped to 100, got %f", positive.Score)
}
if len(positive.Factors) == 0 {
t.Error("expected learning factor for positive health")
}
}
func TestIntelligence_generateHealthPrediction_Branches(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
gradeA := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeA}, &IntelligenceSummary{})
if !strings.Contains(gradeA, "healthy") {
t.Errorf("expected healthy prediction, got %q", gradeA)
}
critical := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeB}, &IntelligenceSummary{
FindingsCount: FindingsCounts{Critical: 2},
})
if !strings.Contains(critical, "Immediate attention") {
t.Errorf("expected critical prediction, got %q", critical)
}
risk := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeB}, &IntelligenceSummary{
UpcomingRisks: []patterns.FailurePrediction{{EventType: patterns.EventHighCPU, DaysUntil: 2, Confidence: 0.8}},
})
if !strings.Contains(risk, "Predicted") {
t.Errorf("expected prediction text, got %q", risk)
}
stable := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeC}, &IntelligenceSummary{})
if !strings.Contains(stable, "stable") {
t.Errorf("expected stable prediction, got %q", stable)
}
}
func TestIntelligence_FormatContext_AllSubsystems(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
knowledgeStore, _ := knowledge.NewStore(t.TempDir())
knowledgeStore.SaveNote("vm-ctx", "vm-ctx", "vm", "general", "Note", "Content")
baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
baselineStore.Learn("vm-ctx", "vm", "cpu", []baseline.MetricPoint{{Value: 10}})
patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
patternStart := time.Now().Add(-90 * time.Minute)
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-ctx", EventType: patterns.EventHighCPU, Timestamp: patternStart})
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-ctx", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
corrBase := time.Now().Add(-30 * time.Minute)
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-ctx", ResourceName: "vm-ctx", ResourceType: "vm", EventType: correlation.EventHighCPU, Timestamp: corrBase})
correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-ctx", ResourceName: "node-ctx", ResourceType: "node", EventType: correlation.EventRestart, Timestamp: corrBase.Add(1 * time.Minute)})
incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10})
incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-ctx", ResourceID: "vm-ctx", ResourceName: "vm-ctx", Type: "cpu", StartTime: time.Now()})
intel.SetSubsystems(nil, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil)
ctx := intel.FormatContext("vm-ctx")
if !strings.Contains(ctx, "Learned Baselines") {
t.Fatalf("expected baseline context, got %q", ctx)
}
if !strings.Contains(ctx, "Failure Predictions") {
t.Fatalf("expected predictions context, got %q", ctx)
}
if !strings.Contains(ctx, "Resource Correlations") {
t.Fatalf("expected correlations context, got %q", ctx)
}
if !strings.Contains(ctx, "Incident Memory") {
t.Fatalf("expected incident context, got %q", ctx)
}
}
func TestIntelligence_FormatGlobalContext_AllSubsystems(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
knowledgeStore, _ := knowledge.NewStore(t.TempDir())
knowledgeStore.SaveNote("vm-global", "vm-global", "vm", "general", "Note", "Content")
incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10})
incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-global", ResourceID: "vm-global", ResourceName: "vm-global", Type: "cpu", StartTime: time.Now()})
correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
corrBase := time.Now().Add(-30 * time.Minute)
for i := 0; i < 2; i++ {
base := corrBase.Add(time.Duration(i) * 10 * time.Minute)
correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-a", ResourceName: "node-a", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: base})
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-global", ResourceName: "vm-global", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: base.Add(1 * time.Minute)})
}
patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
patternStart := time.Now().Add(-90 * time.Minute)
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-global", EventType: patterns.EventHighCPU, Timestamp: patternStart})
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-global", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
intel.SetSubsystems(nil, patternDetector, correlationDetector, nil, incidentStore, knowledgeStore, nil, nil)
ctx := intel.FormatGlobalContext()
if ctx == "" {
t.Fatal("expected non-empty global context")
}
if !strings.Contains(ctx, "Resource Correlations") {
t.Fatalf("expected correlations in global context, got %q", ctx)
}
}
func TestIntelligence_GetSummary_WithSubsystems(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
findings := NewFindingsStore()
findings.Add(&Finding{
ID: "f1",
Key: "f1",
Severity: FindingSeverityWarning,
Category: FindingCategoryPerformance,
ResourceID: "vm-sum",
ResourceName: "vm-sum",
ResourceType: "vm",
Title: "Warning",
DetectedAt: time.Now(),
LastSeenAt: time.Now(),
Source: "test",
})
patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
patternStart := time.Now().Add(-90 * time.Minute)
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-sum", EventType: patterns.EventHighCPU, Timestamp: patternStart})
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-sum", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
changes := memory.NewChangeDetector(memory.ChangeDetectorConfig{MaxChanges: 10})
changes.DetectChanges([]memory.ResourceSnapshot{{ID: "vm-sum", Name: "vm-sum", Type: "vm", Status: "running", SnapshotTime: time.Now()}})
remediations := memory.NewRemediationLog(memory.RemediationLogConfig{MaxRecords: 10})
remediations.Log(memory.RemediationRecord{ResourceID: "vm-sum", Problem: "cpu", Action: "restart", Outcome: memory.OutcomeResolved})
intel.SetSubsystems(findings, patternDetector, nil, nil, nil, nil, changes, remediations)
summary := intel.GetSummary()
if summary.FindingsCount.Total == 0 {
t.Error("expected findings in summary")
}
if summary.PredictionsCount == 0 {
t.Error("expected predictions in summary")
}
if len(summary.UpcomingRisks) == 0 {
t.Error("expected upcoming risks in summary")
}
if summary.RecentChangesCount == 0 {
t.Error("expected recent changes in summary")
}
if len(summary.RecentRemediations) == 0 {
t.Error("expected recent remediations in summary")
}
if len(summary.ResourcesAtRisk) == 0 {
t.Error("expected resources at risk in summary")
}
}
func TestIntelligence_GetResourceIntelligence_WithAllSubsystems(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
findings := NewFindingsStore()
findings.Add(&Finding{
ID: "f1",
Key: "f1",
Severity: FindingSeverityWarning,
Category: FindingCategoryPerformance,
ResourceID: "vm-intel",
ResourceName: "vm-intel",
ResourceType: "vm",
Title: "Warning",
DetectedAt: time.Now(),
LastSeenAt: time.Now(),
Source: "test",
})
patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
patternStart := time.Now().Add(-90 * time.Minute)
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-intel", EventType: patterns.EventHighCPU, Timestamp: patternStart})
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-intel", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
corrBase := time.Now().Add(-30 * time.Minute)
correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-intel", ResourceName: "node-intel", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: corrBase})
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-intel", ResourceName: "vm-intel", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: corrBase.Add(1 * time.Minute)})
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-intel", ResourceName: "vm-intel", ResourceType: "vm", EventType: correlation.EventHighCPU, Timestamp: corrBase.Add(2 * time.Minute)})
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-child", ResourceName: "vm-child", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: corrBase.Add(3 * time.Minute)})
baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
baselineStore.Learn("vm-intel", "vm", "cpu", []baseline.MetricPoint{{Value: 10}})
incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10})
incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-intel", ResourceID: "vm-intel", ResourceName: "vm-intel", Type: "cpu", StartTime: time.Now()})
knowledgeStore, _ := knowledge.NewStore(t.TempDir())
knowledgeStore.SaveNote("vm-intel", "vm-intel", "vm", "general", "Note", "Content")
intel.SetSubsystems(findings, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil)
res := intel.GetResourceIntelligence("vm-intel")
if len(res.ActiveFindings) == 0 {
t.Error("expected active findings")
}
if len(res.Predictions) == 0 {
t.Error("expected predictions")
}
if len(res.Correlations) == 0 {
t.Error("expected correlations")
}
if len(res.Dependents) == 0 || len(res.Dependencies) == 0 {
t.Error("expected dependencies and dependents")
}
if len(res.Baselines) == 0 {
t.Error("expected baselines")
}
if len(res.RecentIncidents) == 0 {
t.Error("expected incidents")
}
if res.Knowledge == nil || res.NoteCount == 0 {
t.Error("expected knowledge")
}
}
func TestIntelligence_GetResourceIntelligence_KnowledgeFallback(t *testing.T) {
intel := NewIntelligence(IntelligenceConfig{})
knowledgeStore, _ := knowledge.NewStore(t.TempDir())
knowledgeStore.SaveNote("vm-know", "knowledge-vm", "vm", "general", "Note", "Content")
intel.SetSubsystems(nil, nil, nil, nil, nil, knowledgeStore, nil, nil)
res := intel.GetResourceIntelligence("vm-know")
if res.ResourceName != "knowledge-vm" {
t.Errorf("expected resource name from knowledge, got %q", res.ResourceName)
}
if res.ResourceType != "vm" {
t.Errorf("expected resource type from knowledge, got %q", res.ResourceType)
}
}

View File

@@ -167,9 +167,6 @@ func (s *IncidentStore) RecordAlertAcknowledged(alert *alerts.Alert, user string
defer s.mu.Unlock()
incident := s.ensureIncidentForAlertLocked(alert)
if incident == nil {
return
}
incident.Acknowledged = true
if alert.AckTime != nil {
@@ -198,9 +195,6 @@ func (s *IncidentStore) RecordAlertUnacknowledged(alert *alerts.Alert, user stri
defer s.mu.Unlock()
incident := s.ensureIncidentForAlertLocked(alert)
if incident == nil {
return
}
incident.Acknowledged = false
incident.AckTime = nil

File diff suppressed because it is too large

View File

@@ -52,8 +52,8 @@ type mockThresholdProvider struct {
storage float64
}
func (m *mockThresholdProvider) GetNodeCPUThreshold() float64 { return m.nodeCPU }
func (m *mockThresholdProvider) GetNodeMemoryThreshold() float64 { return m.nodeMemory }
func (m *mockThresholdProvider) GetGuestCPUThreshold() float64 { return 0 }
func (m *mockThresholdProvider) GetGuestMemoryThreshold() float64 { return m.guestMem }
func (m *mockThresholdProvider) GetGuestDiskThreshold() float64 { return m.guestDisk }
@@ -61,11 +61,17 @@ func (m *mockThresholdProvider) GetStorageThreshold() float64 { return m.sto
type mockResourceProvider struct {
ResourceProvider
getAllFunc func() []resources.Resource
getStatsFunc func() resources.StoreStats
getSummaryFunc func() resources.ResourceSummary
getInfrastructureFunc func() []resources.Resource
getWorkloadsFunc func() []resources.Resource
getByTypeFunc func(t resources.ResourceType) []resources.Resource
getTopCPUFunc func(limit int, types []resources.ResourceType) []resources.Resource
getTopMemoryFunc func(limit int, types []resources.ResourceType) []resources.Resource
getTopDiskFunc func(limit int, types []resources.ResourceType) []resources.Resource
getRelatedFunc func(resourceID string) map[string][]resources.Resource
findContainerHostFunc func(containerNameOrID string) string
}
func (m *mockResourceProvider) GetAll() []resources.Resource {
@@ -99,23 +105,46 @@ func (m *mockResourceProvider) GetWorkloads() []resources.Resource {
return nil
}
func (m *mockResourceProvider) GetType(t resources.ResourceType) []resources.Resource { return nil }
func (m *mockResourceProvider) GetByType(t resources.ResourceType) []resources.Resource {
if m.getByTypeFunc != nil {
return m.getByTypeFunc(t)
}
return nil
}
func (m *mockResourceProvider) GetTopByCPU(limit int, types []resources.ResourceType) []resources.Resource {
if m.getTopCPUFunc != nil {
return m.getTopCPUFunc(limit, types)
}
return nil
}
func (m *mockResourceProvider) GetTopByMemory(limit int, types []resources.ResourceType) []resources.Resource {
if m.getTopMemoryFunc != nil {
return m.getTopMemoryFunc(limit, types)
}
return nil
}
func (m *mockResourceProvider) GetTopByDisk(limit int, types []resources.ResourceType) []resources.Resource {
if m.getTopDiskFunc != nil {
return m.getTopDiskFunc(limit, types)
}
return nil
}
func (m *mockResourceProvider) GetRelated(resourceID string) map[string][]resources.Resource {
if m.getRelatedFunc != nil {
return m.getRelatedFunc(resourceID)
}
return nil
}
func (m *mockResourceProvider) FindContainerHost(containerNameOrID string) string {
if m.findContainerHostFunc != nil {
return m.findContainerHostFunc(containerNameOrID)
}
return ""
}
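// These mocks follow the common Go pattern of embedding the interface and
// overriding behavior through optional func fields: each method delegates to
// its field when set and falls back to a zero value otherwise. A minimal
// usage sketch (hypothetical, not part of these tests):
func exampleMockResourceProviderUsage() {
	rp := &mockResourceProvider{
		findContainerHostFunc: func(name string) string { return "host-1" },
	}
	_ = rp.FindContainerHost("redis") // returns "host-1"
}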
type mockAgentServer struct {
agents []agentexec.ConnectedAgent
executeFunc func(ctx context.Context, agentID string, cmd agentexec.ExecuteCommandPayload) (*agentexec.CommandResultPayload, error)
}
func (m *mockAgentServer) GetConnectedAgents() []agentexec.ConnectedAgent {

View File

@@ -284,10 +284,6 @@ func (s *Service) buildUnifiedResourceContext() string {
}
}
if len(sections) == 0 {
return ""
}
result := "\n\n" + strings.Join(sections, "\n")
// Limit context size

View File

@@ -0,0 +1,263 @@
package ai
import (
"strings"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/agentexec"
"github.com/rcourtman/pulse-go-rewrite/internal/resources"
)
func TestBuildUnifiedResourceContext_NilProvider(t *testing.T) {
s := &Service{}
if got := s.buildUnifiedResourceContext(); got != "" {
t.Errorf("expected empty context, got %q", got)
}
}
func TestBuildUnifiedResourceContext_FullContext(t *testing.T) {
nodeWithAgent := resources.Resource{
ID: "node-1",
Name: "delly",
Type: resources.ResourceTypeNode,
PlatformType: resources.PlatformProxmoxPVE,
ClusterID: "cluster-a",
Status: resources.StatusOnline,
CPU: metricValue(12.3),
Memory: metricValue(45.6),
}
nodeNoAgent := resources.Resource{
ID: "node-2",
Name: "minipc",
Type: resources.ResourceTypeNode,
PlatformType: resources.PlatformProxmoxPVE,
Status: resources.StatusDegraded,
}
dockerNode := resources.Resource{
ID: "dock-node",
Name: "dock-node",
Type: resources.ResourceTypeNode,
PlatformType: resources.PlatformDocker,
Status: resources.StatusOnline,
}
host := resources.Resource{
ID: "host-1",
Name: "barehost",
Type: resources.ResourceTypeHost,
PlatformType: resources.PlatformHostAgent,
Status: resources.StatusOnline,
Identity: &resources.ResourceIdentity{IPs: []string{"192.168.1.10"}},
CPU: metricValue(5.0),
Memory: metricValue(10.0),
}
dockerHost := resources.Resource{
ID: "docker-1",
Name: "dockhost",
Type: resources.ResourceTypeDockerHost,
PlatformType: resources.PlatformDocker,
Status: resources.StatusRunning,
}
vm := resources.Resource{
ID: "vm-100",
Name: "web-vm",
Type: resources.ResourceTypeVM,
ParentID: "node-1",
PlatformID: "100",
Status: resources.StatusRunning,
Identity: &resources.ResourceIdentity{IPs: []string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}},
CPU: metricValue(65.4),
Memory: metricValue(70.2),
}
vm.Alerts = []resources.ResourceAlert{
{ID: "alert-1", Message: "CPU high", Level: "critical"},
}
ct := resources.Resource{
ID: "ct-200",
Name: "db-ct",
Type: resources.ResourceTypeContainer,
ParentID: "node-1",
PlatformID: "200",
Status: resources.StatusStopped,
}
dockerContainer := resources.Resource{
ID: "dock-300",
Name: "redis",
Type: resources.ResourceTypeDockerContainer,
ParentID: "docker-1",
Status: resources.StatusRunning,
Disk: metricValue(70.0),
}
dockerStopped := resources.Resource{
ID: "dock-301",
Name: "cache",
Type: resources.ResourceTypeDockerContainer,
ParentID: "docker-1",
Status: resources.StatusStopped,
}
unknownParent := resources.Resource{
ID: "vm-999",
Name: "mystery",
Type: resources.ResourceTypeVM,
ParentID: "unknown-parent",
PlatformID: "999",
Status: resources.StatusRunning,
}
orphan := resources.Resource{
ID: "orphan-1",
Name: "orphan",
Type: resources.ResourceTypeContainer,
Status: resources.StatusRunning,
Identity: &resources.ResourceIdentity{IPs: []string{"172.16.0.5"}},
}
infrastructure := []resources.Resource{nodeWithAgent, nodeNoAgent, host, dockerHost, dockerNode}
workloads := []resources.Resource{vm, ct, dockerContainer, dockerStopped, unknownParent, orphan}
all := append(append([]resources.Resource{}, infrastructure...), workloads...)
stats := resources.StoreStats{
TotalResources: len(all),
ByType: map[resources.ResourceType]int{
resources.ResourceTypeNode: 3,
resources.ResourceTypeHost: 1,
resources.ResourceTypeDockerHost: 1,
resources.ResourceTypeVM: 2,
resources.ResourceTypeContainer: 2,
resources.ResourceTypeDockerContainer: 2,
},
}
summary := resources.ResourceSummary{
TotalResources: len(all),
Healthy: 6,
Degraded: 1,
Offline: 4,
WithAlerts: 1,
ByType: map[resources.ResourceType]resources.TypeSummary{
resources.ResourceTypeVM: {Count: 2, AvgCPUPercent: 40.0, AvgMemoryPercent: 55.0},
resources.ResourceTypeContainer: {Count: 2},
},
}
mockRP := &mockResourceProvider{
getStatsFunc: func() resources.StoreStats {
return stats
},
getInfrastructureFunc: func() []resources.Resource {
return infrastructure
},
getWorkloadsFunc: func() []resources.Resource {
return workloads
},
getAllFunc: func() []resources.Resource {
return all
},
getSummaryFunc: func() resources.ResourceSummary {
return summary
},
getTopCPUFunc: func(limit int, types []resources.ResourceType) []resources.Resource {
return []resources.Resource{vm}
},
getTopMemoryFunc: func(limit int, types []resources.ResourceType) []resources.Resource {
return []resources.Resource{host}
},
getTopDiskFunc: func(limit int, types []resources.ResourceType) []resources.Resource {
return []resources.Resource{dockerContainer}
},
}
s := &Service{resourceProvider: mockRP}
s.agentServer = &mockAgentServer{
agents: []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "delly"},
},
}
got := s.buildUnifiedResourceContext()
if got == "" {
t.Fatal("expected non-empty context")
}
assertContains := func(substr string) {
t.Helper()
if !strings.Contains(got, substr) {
t.Fatalf("expected context to contain %q", substr)
}
}
assertContains("## Unified Infrastructure View")
assertContains("Total resources: 11 (Infrastructure: 5, Workloads: 6)")
assertContains("Proxmox VE Nodes")
assertContains("HAS AGENT")
assertContains("NO AGENT")
assertContains("cluster: cluster-a")
assertContains("Standalone Hosts")
assertContains("192.168.1.10")
assertContains("Docker/Podman Hosts")
assertContains("1/2 containers running")
assertContains("Workloads (VMs & Containers)")
assertContains("On delly")
assertContains("On unknown-parent")
assertContains("Other workloads")
assertContains("10.0.0.1, 10.0.0.2")
assertContains("Resources with Active Alerts")
assertContains("CPU high")
assertContains("Infrastructure Summary")
assertContains("Resources with alerts: 1")
assertContains("Average utilization by type")
assertContains("Top CPU Consumers")
assertContains("Top Memory Consumers")
assertContains("Top Disk Usage")
}
func TestBuildUnifiedResourceContext_TruncatesLargeContext(t *testing.T) {
largeName := strings.Repeat("a", 60000)
node := resources.Resource{
ID: "node-1",
Name: "node-1",
DisplayName: largeName,
Type: resources.ResourceTypeNode,
PlatformType: resources.PlatformProxmoxPVE,
Status: resources.StatusOnline,
}
stats := resources.StoreStats{
TotalResources: 1,
ByType: map[resources.ResourceType]int{
resources.ResourceTypeNode: 1,
},
}
mockRP := &mockResourceProvider{
getStatsFunc: func() resources.StoreStats {
return stats
},
getInfrastructureFunc: func() []resources.Resource {
return []resources.Resource{node}
},
getWorkloadsFunc: func() []resources.Resource {
return nil
},
getAllFunc: func() []resources.Resource {
return []resources.Resource{node}
},
getSummaryFunc: func() resources.ResourceSummary {
return resources.ResourceSummary{TotalResources: 1}
},
}
s := &Service{resourceProvider: mockRP}
got := s.buildUnifiedResourceContext()
if !strings.Contains(got, "[... Context truncated ...]") {
t.Fatal("expected context to be truncated")
}
if len(got) <= 50000 {
t.Fatalf("expected truncated context length > 50000, got %d", len(got))
}
}
func metricValue(current float64) *resources.MetricValue {
return &resources.MetricValue{Current: current}
}
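// The truncation test above implies buildUnifiedResourceContext caps its
// output near 50,000 characters and appends a marker, so the final string is
// slightly longer than the cap. A sketch of that behavior, assuming a simple
// prefix cut (the real limit handling may differ):
func truncateContextSketch(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	return s[:limit] + "\n[... Context truncated ...]"
}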

View File

@@ -1,9 +1,14 @@
package ai
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/agentexec"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"github.com/rcourtman/pulse-go-rewrite/internal/models"
)
func TestExtractVMIDFromTargetID(t *testing.T) {
@@ -16,21 +21,21 @@ func TestExtractVMIDFromTargetID(t *testing.T) {
{"plain vmid", "106", 106},
{"node-vmid", "minipc-106", 106},
{"instance-node-vmid", "delly-minipc-106", 106},
// Edge cases with hyphenated names
{"hyphenated-node-vmid", "pve-node-01-106", 106},
{"hyphenated-instance-node-vmid", "my-cluster-pve-node-01-106", 106},
// Type prefixes
{"lxc prefix", "lxc-106", 106},
{"vm prefix", "vm-106", 106},
{"ct prefix", "ct-106", 106},
// Non-numeric - should return 0
{"non-numeric", "mycontainer", 0},
{"no-vmid", "node-name", 0},
{"empty", "", 0},
// Large VMIDs (Proxmox uses up to 999999999)
{"large vmid", "node-999999", 999999},
}
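// A sketch consistent with this table: take the segment after the last
// hyphen (or the whole ID when there is none) and parse it as an integer,
// returning 0 when it is not numeric. Assumes strconv is imported; the real
// extractVMIDFromTargetID may differ.
func extractVMIDSketch(targetID string) int {
	seg := targetID
	if idx := strings.LastIndex(targetID, "-"); idx >= 0 {
		seg = targetID[idx+1:]
	}
	n, err := strconv.Atoi(seg)
	if err != nil || n < 0 {
		return 0
	}
	return n
}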
@@ -53,18 +58,18 @@ func TestRoutingError(t *testing.T) {
Reason: "No agent connected to node \"minipc\"",
Suggestion: "Install pulse-agent on minipc",
}
want := "No agent connected to node \"minipc\". Install pulse-agent on minipc"
if err.Error() != want {
t.Errorf("Error() = %q, want %q", err.Error(), want)
}
})
t.Run("without suggestion", func(t *testing.T) {
err := &RoutingError{
Reason: "No agents connected",
}
want := "No agents connected"
if err.Error() != want {
t.Errorf("Error() = %q, want %q", err.Error(), want)
@@ -74,22 +79,22 @@ func TestRoutingError(t *testing.T) {
func TestRouteToAgent_NoAgents(t *testing.T) {
s := &Service{}
req := ExecuteRequest{
TargetType: "container",
TargetID: "minipc-106",
}
_, err := s.routeToAgent(req, "pct exec 106 -- hostname", nil)
if err == nil {
t.Error("expected error for no agents, got nil")
}
routingErr, ok := err.(*RoutingError)
if !ok {
t.Fatalf("expected RoutingError, got %T", err)
}
if routingErr.Suggestion == "" {
t.Error("expected suggestion in error")
}
@@ -97,19 +102,19 @@ func TestRouteToAgent_NoAgents(t *testing.T) {
func TestRouteToAgent_ExactMatch(t *testing.T) {
s := &Service{}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "delly"},
{AgentID: "agent-2", Hostname: "minipc"},
{AgentID: "agent-3", Hostname: "pimox"},
}
tests := []struct {
name string
req ExecuteRequest
command string
wantAgentID string
wantHostname string
}{
{
name: "route by context node",
@@ -151,11 +156,11 @@ func TestRouteToAgent_ExactMatch(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.AgentID != tt.wantAgentID {
t.Errorf("AgentID = %q, want %q", result.AgentID, tt.wantAgentID)
}
if result.AgentHostname != tt.wantHostname {
t.Errorf("AgentHostname = %q, want %q", result.AgentHostname, tt.wantHostname)
}
@@ -165,28 +170,28 @@ func TestRouteToAgent_ExactMatch(t *testing.T) {
func TestRouteToAgent_NoSubstringMatching(t *testing.T) {
s := &Service{}
// Agent "mini" should NOT match node "minipc"
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "mini"},
{AgentID: "agent-2", Hostname: "pc"},
}
req := ExecuteRequest{
TargetType: "container",
Context: map[string]interface{}{"node": "minipc"},
}
_, err := s.routeToAgent(req, "hostname", agents)
if err == nil {
t.Error("expected error when no exact match, got nil (substring matching may be occurring)")
}
routingErr, ok := err.(*RoutingError)
if !ok {
t.Fatalf("expected RoutingError, got %T", err)
}
if routingErr.TargetNode != "minipc" {
t.Errorf("TargetNode = %q, want %q", routingErr.TargetNode, "minipc")
}
@@ -194,21 +199,21 @@ func TestRouteToAgent_NoSubstringMatching(t *testing.T) {
func TestRouteToAgent_CaseInsensitive(t *testing.T) {
s := &Service{}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "MiniPC"},
}
req := ExecuteRequest{
TargetType: "container",
Context: map[string]interface{}{"node": "minipc"}, // lowercase
}
result, err := s.routeToAgent(req, "hostname", agents)
if err != nil {
t.Fatalf("expected case-insensitive match, got error: %v", err)
}
if result.AgentID != "agent-1" {
t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1")
}
@@ -216,22 +221,22 @@ func TestRouteToAgent_CaseInsensitive(t *testing.T) {
func TestRouteToAgent_HyphenatedNodeNames(t *testing.T) {
s := &Service{}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "pve-node-01"},
{AgentID: "agent-2", Hostname: "pve-node-02"},
}
req := ExecuteRequest{
TargetType: "container",
Context: map[string]interface{}{"node": "pve-node-02"},
}
result, err := s.routeToAgent(req, "hostname", agents)
if err != nil {
t.Fatalf("unexpected error for hyphenated node names: %v", err)
}
if result.AgentID != "agent-2" {
t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-2")
}
@@ -239,38 +244,441 @@ func TestRouteToAgent_HyphenatedNodeNames(t *testing.T) {
func TestRouteToAgent_ActionableErrorMessages(t *testing.T) {
s := &Service{}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "delly"},
}
req := ExecuteRequest{
TargetType: "container",
Context: map[string]interface{}{"node": "minipc"},
}
_, err := s.routeToAgent(req, "hostname", agents)
if err == nil {
t.Fatal("expected error, got nil")
}
routingErr, ok := err.(*RoutingError)
if !ok {
t.Fatalf("expected RoutingError, got %T", err)
}
// Error should mention the target node
if routingErr.TargetNode != "minipc" {
t.Errorf("TargetNode = %q, want %q", routingErr.TargetNode, "minipc")
}
// Error should list available agents
if len(routingErr.AvailableAgents) == 0 {
t.Error("expected available agents in error")
}
// Error should have actionable suggestion
if routingErr.Suggestion == "" {
t.Error("expected suggestion in error message")
}
}
func TestRoutingError_ForAI(t *testing.T) {
t.Run("clarification", func(t *testing.T) {
err := &RoutingError{
Reason: "Cannot determine which host should execute this command",
AvailableAgents: []string{"delly", "pimox"},
AskForClarification: true,
}
msg := err.ForAI()
if !strings.Contains(msg, "ROUTING_CLARIFICATION_NEEDED") {
t.Errorf("expected clarification marker, got %q", msg)
}
if !strings.Contains(msg, "delly, pimox") {
t.Errorf("expected available hosts list, got %q", msg)
}
if !strings.Contains(msg, err.Reason) {
t.Errorf("expected reason to be included, got %q", msg)
}
})
t.Run("fallback", func(t *testing.T) {
err := &RoutingError{
Reason: "No agents connected",
AskForClarification: true,
}
if err.ForAI() != err.Error() {
t.Errorf("expected ForAI to fall back to Error, got %q", err.ForAI())
}
})
}
func TestRouteToAgent_VMIDRoutingWithInstance(t *testing.T) {
stateProvider := &mockStateProvider{
state: models.StateSnapshot{
Containers: []models.Container{
{VMID: 106, Node: "node-b", Name: "ct-b", Instance: "cluster-b"},
{VMID: 106, Node: "node-a", Name: "ct-a", Instance: "cluster-b"},
},
},
}
s := &Service{stateProvider: stateProvider}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-a", Hostname: "node-a"},
{AgentID: "agent-b", Hostname: "node-b"},
}
req := ExecuteRequest{
TargetType: "container",
Context: map[string]interface{}{"instance": "cluster-b"},
}
result, err := s.routeToAgent(req, "pct exec 106 -- hostname", agents)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.AgentID != "agent-b" {
t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-b")
}
if result.RoutingMethod != "vmid_lookup_with_instance" {
t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "vmid_lookup_with_instance")
}
}
func TestRouteToAgent_VMIDCollision(t *testing.T) {
stateProvider := &mockStateProvider{
state: models.StateSnapshot{
VMs: []models.VM{
{VMID: 200, Node: "node-a", Name: "vm-a", Instance: "cluster-a"},
{VMID: 200, Node: "node-b", Name: "vm-b", Instance: "cluster-b"},
},
},
}
s := &Service{stateProvider: stateProvider}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-a", Hostname: "node-a"},
{AgentID: "agent-b", Hostname: "node-b"},
}
req := ExecuteRequest{
TargetType: "vm",
}
_, err := s.routeToAgent(req, "pct exec 200 -- hostname", agents)
if err == nil {
t.Fatal("expected error for VMID collision, got nil")
}
routingErr, ok := err.(*RoutingError)
if !ok {
t.Fatalf("expected RoutingError, got %T", err)
}
if routingErr.TargetVMID != 200 {
t.Errorf("TargetVMID = %d, want %d", routingErr.TargetVMID, 200)
}
if !strings.Contains(routingErr.Reason, "exists on multiple nodes") {
t.Errorf("expected collision reason, got %q", routingErr.Reason)
}
}
func TestRouteToAgent_VMIDNotFoundFallsBackToContext(t *testing.T) {
stateProvider := &mockStateProvider{}
s := &Service{stateProvider: stateProvider}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "minipc"},
}
req := ExecuteRequest{
TargetType: "container",
Context: map[string]interface{}{"node": "minipc"},
}
result, err := s.routeToAgent(req, "pct exec 106 -- hostname", agents)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.RoutingMethod != "context_node" {
t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "context_node")
}
if len(result.Warnings) != 1 || !strings.Contains(result.Warnings[0], "VMID 106 not found") {
t.Errorf("expected warning about missing VMID, got %v", result.Warnings)
}
}
func TestRouteToAgent_ResourceProviderContexts(t *testing.T) {
s := &Service{}
mockRP := &mockResourceProvider{
findContainerHostFunc: func(containerNameOrID string) string {
if containerNameOrID == "" {
return ""
}
return "rp-host"
},
}
s.resourceProvider = mockRP
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "rp-host"},
}
tests := []struct {
name string
key string
}{
{name: "containerName", key: "containerName"},
{name: "name", key: "name"},
{name: "guestName", key: "guestName"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := ExecuteRequest{
TargetType: "container",
Context: map[string]interface{}{tt.key: "workload"},
}
result, err := s.routeToAgent(req, "docker ps", agents)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.RoutingMethod != "resource_provider_lookup" {
t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "resource_provider_lookup")
}
})
}
}
func TestRouteToAgent_TargetIDLookup(t *testing.T) {
stateProvider := &mockStateProvider{
state: models.StateSnapshot{
VMs: []models.VM{
{VMID: 222, Node: "node-vm", Name: "vm-222", Instance: "cluster-a"},
},
},
}
s := &Service{stateProvider: stateProvider}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "node-vm"},
}
req := ExecuteRequest{
TargetType: "vm",
TargetID: "vm-222",
Context: map[string]interface{}{"instance": "cluster-a"},
}
result, err := s.routeToAgent(req, "status", agents)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.RoutingMethod != "target_id_vmid_lookup" {
t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "target_id_vmid_lookup")
}
if result.AgentID != "agent-1" {
t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1")
}
}
func TestRouteToAgent_VMIDLookupSingleMatch(t *testing.T) {
stateProvider := &mockStateProvider{
state: models.StateSnapshot{
VMs: []models.VM{
{VMID: 101, Node: "node-1", Name: "vm-one"},
},
},
}
s := &Service{stateProvider: stateProvider}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "node-1"},
}
req := ExecuteRequest{
TargetType: "vm",
}
result, err := s.routeToAgent(req, "qm start 101", agents)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.RoutingMethod != "vmid_lookup" {
t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "vmid_lookup")
}
if result.AgentID != "agent-1" {
t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1")
}
}
func TestRouteToAgent_ClusterPeer(t *testing.T) {
tmp := t.TempDir()
persistence := config.NewConfigPersistence(tmp)
err := persistence.SaveNodesConfig([]config.PVEInstance{
{
Name: "node-a",
IsCluster: true,
ClusterName: "cluster-a",
ClusterEndpoints: []config.ClusterEndpoint{
{NodeName: "node-a"},
{NodeName: "node-b"},
},
},
}, nil, nil)
if err != nil {
t.Fatalf("SaveNodesConfig: %v", err)
}
s := &Service{persistence: persistence}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-b", Hostname: "node-b"},
}
req := ExecuteRequest{
TargetType: "vm",
Context: map[string]interface{}{"node": "node-a"},
}
result, err := s.routeToAgent(req, "hostname", agents)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !result.ClusterPeer {
t.Error("expected cluster peer routing")
}
if result.AgentID != "agent-b" {
t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-b")
}
}
func TestRouteToAgent_SingleAgentFallbackForHost(t *testing.T) {
s := &Service{}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "only"},
}
req := ExecuteRequest{
TargetType: "host",
}
result, err := s.routeToAgent(req, "uptime", agents)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.RoutingMethod != "single_agent_fallback" {
t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "single_agent_fallback")
}
if len(result.Warnings) != 1 {
t.Errorf("expected warning for single agent fallback, got %v", result.Warnings)
}
}
func TestRouteToAgent_AsksForClarification(t *testing.T) {
s := &Service{}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "delly"},
{AgentID: "agent-2", Hostname: "pimox"},
}
req := ExecuteRequest{
TargetType: "vm",
}
_, err := s.routeToAgent(req, "hostname", agents)
if err == nil {
t.Fatal("expected error, got nil")
}
routingErr, ok := err.(*RoutingError)
if !ok {
t.Fatalf("expected RoutingError, got %T", err)
}
if !routingErr.AskForClarification {
t.Error("expected AskForClarification to be true")
}
if len(routingErr.AvailableAgents) != 2 {
t.Errorf("expected available agents, got %v", routingErr.AvailableAgents)
}
}
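// Taken together, the cases above sketch the routing strategy these tests
// assume (inferred from the RoutingMethod assertions, not from the
// implementation): VMID lookups, with or without an instance hint and
// erroring on cross-cluster collisions, then TargetID lookup, then
// resource-provider lookup, then context-node matching with cluster-peer
// fallback, then a single-agent fallback for host targets, and finally a
// RoutingError asking for clarification when several agents remain plausible.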
func TestFindClusterPeerAgent_NoPersistence(t *testing.T) {
s := &Service{}
if got := s.findClusterPeerAgent("node-a", nil); got != "" {
t.Errorf("expected empty result, got %q", got)
}
}
func TestFindClusterPeerAgent_LoadError(t *testing.T) {
tmp := t.TempDir()
persistence := config.NewConfigPersistence(tmp)
nodesPath := filepath.Join(tmp, "nodes.enc")
if err := os.Mkdir(nodesPath, 0700); err != nil {
t.Fatalf("Mkdir nodes.enc: %v", err)
}
s := &Service{persistence: persistence}
if got := s.findClusterPeerAgent("node-a", nil); got != "" {
t.Errorf("expected empty result, got %q", got)
}
}
func TestFindClusterPeerAgent_NotCluster(t *testing.T) {
tmp := t.TempDir()
persistence := config.NewConfigPersistence(tmp)
err := persistence.SaveNodesConfig([]config.PVEInstance{
{Name: "node-a"},
}, nil, nil)
if err != nil {
t.Fatalf("SaveNodesConfig: %v", err)
}
s := &Service{persistence: persistence}
if got := s.findClusterPeerAgent("node-a", nil); got != "" {
t.Errorf("expected empty result, got %q", got)
}
}
func TestFindClusterPeerAgent_NoAgentMatch(t *testing.T) {
tmp := t.TempDir()
persistence := config.NewConfigPersistence(tmp)
err := persistence.SaveNodesConfig([]config.PVEInstance{
{
Name: "node-a",
IsCluster: true,
ClusterName: "cluster-a",
ClusterEndpoints: []config.ClusterEndpoint{
{NodeName: "node-a"},
{NodeName: "node-b"},
},
},
}, nil, nil)
if err != nil {
t.Fatalf("SaveNodesConfig: %v", err)
}
s := &Service{persistence: persistence}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-1", Hostname: "node-c"},
}
if got := s.findClusterPeerAgent("node-a", agents); got != "" {
t.Errorf("expected empty result, got %q", got)
}
}
func TestFindClusterPeerAgent_EndpointMatch(t *testing.T) {
tmp := t.TempDir()
persistence := config.NewConfigPersistence(tmp)
err := persistence.SaveNodesConfig([]config.PVEInstance{
{
Name: "cluster-master",
IsCluster: true,
ClusterName: "cluster-a",
ClusterEndpoints: []config.ClusterEndpoint{
{NodeName: "node-a"},
{NodeName: "node-b"},
},
},
}, nil, nil)
if err != nil {
t.Fatalf("SaveNodesConfig: %v", err)
}
s := &Service{persistence: persistence}
agents := []agentexec.ConnectedAgent{
{AgentID: "agent-b", Hostname: "node-b"},
}
if got := s.findClusterPeerAgent("node-a", agents); got != "agent-b" {
t.Errorf("expected peer agent, got %q", got)
}
}

View File

@@ -12,24 +12,33 @@ import (
"time"
)
var commandRunner = func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
return stdout.Bytes(), stderr.Bytes(), err
}
// ClusterStatus represents the complete Ceph cluster status as collected by the agent.
type ClusterStatus struct {
FSID string `json:"fsid"`
Health HealthStatus `json:"health"`
MonMap MonitorMap `json:"monMap,omitempty"`
MgrMap ManagerMap `json:"mgrMap,omitempty"`
OSDMap OSDMap `json:"osdMap"`
PGMap PGMap `json:"pgMap"`
Pools []Pool `json:"pools,omitempty"`
Services []ServiceInfo `json:"services,omitempty"`
CollectedAt time.Time `json:"collectedAt"`
FSID string `json:"fsid"`
Health HealthStatus `json:"health"`
MonMap MonitorMap `json:"monMap,omitempty"`
MgrMap ManagerMap `json:"mgrMap,omitempty"`
OSDMap OSDMap `json:"osdMap"`
PGMap PGMap `json:"pgMap"`
Pools []Pool `json:"pools,omitempty"`
Services []ServiceInfo `json:"services,omitempty"`
CollectedAt time.Time `json:"collectedAt"`
}
// HealthStatus represents Ceph cluster health.
type HealthStatus struct {
Status string `json:"status"` // HEALTH_OK, HEALTH_WARN, HEALTH_ERR
Checks map[string]Check `json:"checks,omitempty"`
Summary []HealthSummary `json:"summary,omitempty"`
Status string `json:"status"` // HEALTH_OK, HEALTH_WARN, HEALTH_ERR
Checks map[string]Check `json:"checks,omitempty"`
Summary []HealthSummary `json:"summary,omitempty"`
}
// Check represents a health check detail.
@@ -70,28 +79,28 @@ type ManagerMap struct {
// OSDMap represents OSD status summary.
type OSDMap struct {
Epoch int `json:"epoch"`
NumOSDs int `json:"numOsds"`
NumUp int `json:"numUp"`
NumIn int `json:"numIn"`
NumDown int `json:"numDown,omitempty"`
NumOut int `json:"numOut,omitempty"`
Epoch int `json:"epoch"`
NumOSDs int `json:"numOsds"`
NumUp int `json:"numUp"`
NumIn int `json:"numIn"`
NumDown int `json:"numDown,omitempty"`
NumOut int `json:"numOut,omitempty"`
}
// PGMap represents placement group statistics.
type PGMap struct {
NumPGs int `json:"numPgs"`
BytesTotal uint64 `json:"bytesTotal"`
BytesUsed uint64 `json:"bytesUsed"`
BytesAvailable uint64 `json:"bytesAvailable"`
DataBytes uint64 `json:"dataBytes,omitempty"`
UsagePercent float64 `json:"usagePercent"`
DegradedRatio float64 `json:"degradedRatio,omitempty"`
MisplacedRatio float64 `json:"misplacedRatio,omitempty"`
ReadBytesPerSec uint64 `json:"readBytesPerSec,omitempty"`
WriteBytesPerSec uint64 `json:"writeBytesPerSec,omitempty"`
ReadOpsPerSec uint64 `json:"readOpsPerSec,omitempty"`
WriteOpsPerSec uint64 `json:"writeOpsPerSec,omitempty"`
NumPGs int `json:"numPgs"`
BytesTotal uint64 `json:"bytesTotal"`
BytesUsed uint64 `json:"bytesUsed"`
BytesAvailable uint64 `json:"bytesAvailable"`
DataBytes uint64 `json:"dataBytes,omitempty"`
UsagePercent float64 `json:"usagePercent"`
DegradedRatio float64 `json:"degradedRatio,omitempty"`
MisplacedRatio float64 `json:"misplacedRatio,omitempty"`
ReadBytesPerSec uint64 `json:"readBytesPerSec,omitempty"`
WriteBytesPerSec uint64 `json:"writeBytesPerSec,omitempty"`
ReadOpsPerSec uint64 `json:"readOpsPerSec,omitempty"`
WriteOpsPerSec uint64 `json:"writeOpsPerSec,omitempty"`
}
// Pool represents a Ceph pool.
@@ -106,16 +115,16 @@ type Pool struct {
// ServiceInfo represents a Ceph service summary.
type ServiceInfo struct {
Type string `json:"type"` // mon, mgr, osd, mds, rgw
Running int `json:"running"`
Total int `json:"total"`
Daemons []string `json:"daemons,omitempty"`
Type string `json:"type"` // mon, mgr, osd, mds, rgw
Running int `json:"running"`
Total int `json:"total"`
Daemons []string `json:"daemons,omitempty"`
}
// IsAvailable checks if the ceph CLI is available on the system.
func IsAvailable(ctx context.Context) bool {
cmd := exec.CommandContext(ctx, "which", "ceph")
return cmd.Run() == nil
_, _, err := commandRunner(ctx, "which", "ceph")
return err == nil
}
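// Funnelling the probe through commandRunner keeps it injectable in tests. An
// equivalent pure-Go check (an alternative sketch, not what this package
// does) would avoid the subprocess entirely:
//
//	if _, err := exec.LookPath("ceph"); err == nil {
//		// ceph binary found on PATH
//	}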
// Collect gathers Ceph cluster status using the ceph CLI.
@@ -160,17 +169,13 @@ func runCephCommand(ctx context.Context, args ...string) ([]byte, error) {
cmdCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
cmd := exec.CommandContext(cmdCtx, "ceph", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("ceph %s failed: %w (stderr: %s)",
strings.Join(args, " "), err, stderr.String())
stdout, stderr, err := commandRunner(cmdCtx, "ceph", args...)
if err != nil {
return nil, fmt.Errorf("ceph %s failed: %w (stderr: %s)",
strings.Join(args, " "), err, string(stderr))
}
return stdout.Bytes(), nil
return stdout, nil
}
// parseStatus parses the output of `ceph status --format json`.
@@ -178,8 +183,8 @@ func parseStatus(data []byte) (*ClusterStatus, error) {
var raw struct {
FSID string `json:"fsid"`
Health struct {
Status string `json:"status"`
Checks map[string]struct {
Status string `json:"status"`
Checks map[string]struct {
Severity string `json:"severity"`
Summary struct {
Message string `json:"message"`
@@ -198,10 +203,10 @@ func parseStatus(data []byte) (*ClusterStatus, error) {
} `json:"mons"`
} `json:"monmap"`
MgrMap struct {
Available bool `json:"available"`
NumActive int `json:"num_active_name,omitempty"`
Available bool `json:"available"`
NumActive int `json:"num_active_name,omitempty"`
ActiveName string `json:"active_name"`
Standbys []struct {
Standbys []struct {
Name string `json:"name"`
} `json:"standbys"`
} `json:"mgrmap"`
@@ -212,17 +217,17 @@ func parseStatus(data []byte) (*ClusterStatus, error) {
NumIn int `json:"num_in_osds"`
} `json:"osdmap"`
PGMap struct {
NumPGs int `json:"num_pgs"`
BytesTotal uint64 `json:"bytes_total"`
BytesUsed uint64 `json:"bytes_used"`
BytesAvail uint64 `json:"bytes_avail"`
DataBytes uint64 `json:"data_bytes"`
DegradedRatio float64 `json:"degraded_ratio"`
MisplacedRatio float64 `json:"misplaced_ratio"`
ReadBytesPerSec uint64 `json:"read_bytes_sec"`
WriteBytesPerSec uint64 `json:"write_bytes_sec"`
ReadOpsPerSec uint64 `json:"read_op_per_sec"`
WriteOpsPerSec uint64 `json:"write_op_per_sec"`
NumPGs int `json:"num_pgs"`
BytesTotal uint64 `json:"bytes_total"`
BytesUsed uint64 `json:"bytes_used"`
BytesAvail uint64 `json:"bytes_avail"`
DataBytes uint64 `json:"data_bytes"`
DegradedRatio float64 `json:"degraded_ratio"`
MisplacedRatio float64 `json:"misplaced_ratio"`
ReadBytesPerSec uint64 `json:"read_bytes_sec"`
WriteBytesPerSec uint64 `json:"write_bytes_sec"`
ReadOpsPerSec uint64 `json:"read_op_per_sec"`
WriteOpsPerSec uint64 `json:"write_op_per_sec"`
} `json:"pgmap"`
}
@@ -255,13 +260,13 @@ func parseStatus(data []byte) (*ClusterStatus, error) {
NumOut: raw.OSDMap.NumOSD - raw.OSDMap.NumIn,
},
PGMap: PGMap{
NumPGs: raw.PGMap.NumPGs,
BytesTotal: raw.PGMap.BytesTotal,
BytesUsed: raw.PGMap.BytesUsed,
BytesAvailable: raw.PGMap.BytesAvail,
DataBytes: raw.PGMap.DataBytes,
DegradedRatio: raw.PGMap.DegradedRatio,
MisplacedRatio: raw.PGMap.MisplacedRatio,
NumPGs: raw.PGMap.NumPGs,
BytesTotal: raw.PGMap.BytesTotal,
BytesUsed: raw.PGMap.BytesUsed,
BytesAvailable: raw.PGMap.BytesAvail,
DataBytes: raw.PGMap.DataBytes,
DegradedRatio: raw.PGMap.DegradedRatio,
MisplacedRatio: raw.PGMap.MisplacedRatio,
ReadBytesPerSec: raw.PGMap.ReadBytesPerSec,
WriteBytesPerSec: raw.PGMap.WriteBytesPerSec,
ReadOpsPerSec: raw.PGMap.ReadOpsPerSec,

View File

@@ -1,9 +1,32 @@
package ceph
import (
"context"
"errors"
"strings"
"testing"
)
func withCommandRunner(t *testing.T, fn func(ctx context.Context, name string, args ...string) ([]byte, []byte, error)) {
t.Helper()
orig := commandRunner
commandRunner = fn
t.Cleanup(func() { commandRunner = orig })
}
func TestCommandRunner_Default(t *testing.T) {
stdout, stderr, err := commandRunner(context.Background(), "sh", "-c", "true")
if err != nil {
t.Fatalf("commandRunner error: %v", err)
}
if len(stdout) != 0 {
t.Fatalf("unexpected stdout: %q", string(stdout))
}
if len(stderr) != 0 {
t.Fatalf("unexpected stderr: %q", string(stderr))
}
}
func TestParseStatus(t *testing.T) {
data := []byte(`{
"fsid":"fsid-123",
@@ -70,6 +93,68 @@ func TestParseStatus(t *testing.T) {
}
}
func TestIsAvailable(t *testing.T) {
t.Run("available", func(t *testing.T) {
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
if name != "which" || len(args) != 1 || args[0] != "ceph" {
t.Fatalf("unexpected command: %s %v", name, args)
}
return nil, nil, nil
})
if !IsAvailable(context.Background()) {
t.Fatalf("expected available")
}
})
t.Run("missing", func(t *testing.T) {
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
return nil, nil, errors.New("missing")
})
if IsAvailable(context.Background()) {
t.Fatalf("expected unavailable")
}
})
}
func TestRunCephCommand(t *testing.T) {
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
if name != "ceph" {
t.Fatalf("unexpected command name: %s", name)
}
if len(args) < 1 || args[0] != "status" {
t.Fatalf("unexpected args: %v", args)
}
return []byte(`{"ok":true}`), nil, nil
})
out, err := runCephCommand(context.Background(), "status", "--format", "json")
if err != nil {
t.Fatalf("runCephCommand error: %v", err)
}
if string(out) != `{"ok":true}` {
t.Fatalf("unexpected output: %s", string(out))
}
}
func TestRunCephCommandError(t *testing.T) {
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
return nil, []byte("bad"), errors.New("boom")
})
_, err := runCephCommand(context.Background(), "status", "--format", "json")
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), "ceph status --format json failed") {
t.Fatalf("unexpected error message: %v", err)
}
if !strings.Contains(err.Error(), "stderr: bad") {
t.Fatalf("expected stderr in error: %v", err)
}
}
func TestParseStatusInvalidJSON(t *testing.T) {
_, err := parseStatus([]byte(`{not-json}`))
if err == nil {
@@ -104,6 +189,181 @@ func TestParseDFInvalidJSON(t *testing.T) {
}
}
func TestCollect_NotAvailable(t *testing.T) {
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
if name == "which" {
return nil, nil, errors.New("missing")
}
t.Fatalf("unexpected command: %s %v", name, args)
return nil, nil, nil
})
status, err := Collect(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if status != nil {
t.Fatalf("expected nil status")
}
}
func TestCollect_StatusError(t *testing.T) {
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
if name == "which" {
return nil, nil, nil
}
if name == "ceph" && len(args) > 0 && args[0] == "status" {
return nil, []byte("boom"), errors.New("status failed")
}
t.Fatalf("unexpected command: %s %v", name, args)
return nil, nil, nil
})
status, err := Collect(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if status != nil {
t.Fatalf("expected nil status")
}
}
func TestCollect_ParseStatusError(t *testing.T) {
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
if name == "which" {
return nil, nil, nil
}
if name == "ceph" && len(args) > 0 && args[0] == "status" {
return []byte(`{not-json}`), nil, nil
}
t.Fatalf("unexpected command: %s %v", name, args)
return nil, nil, nil
})
_, err := Collect(context.Background())
if err == nil {
t.Fatalf("expected parse error")
}
}
func TestCollect_DFCommandError(t *testing.T) {
statusJSON := []byte(`{
"fsid":"fsid-1",
"health":{"status":"HEALTH_OK","checks":{}},
"monmap":{"epoch":1,"mons":[]},
"mgrmap":{"available":false,"active_name":"","standbys":[]},
"osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0},
"pgmap":{"num_pgs":0,"bytes_total":100,"bytes_used":50,"bytes_avail":50,
"data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0,
"read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0}
}`)
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
if name == "which" {
return nil, nil, nil
}
if name == "ceph" && len(args) > 0 && args[0] == "status" {
return statusJSON, nil, nil
}
if name == "ceph" && len(args) > 0 && args[0] == "df" {
return nil, []byte("df failed"), errors.New("df error")
}
t.Fatalf("unexpected command: %s %v", name, args)
return nil, nil, nil
})
status, err := Collect(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if status == nil {
t.Fatalf("expected status")
}
if status.PGMap.UsagePercent == 0 {
t.Fatalf("expected usage percent from status")
}
}
func TestCollect_DFParseError(t *testing.T) {
statusJSON := []byte(`{
"fsid":"fsid-1",
"health":{"status":"HEALTH_OK","checks":{}},
"monmap":{"epoch":1,"mons":[]},
"mgrmap":{"available":false,"active_name":"","standbys":[]},
"osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0},
"pgmap":{"num_pgs":0,"bytes_total":0,"bytes_used":0,"bytes_avail":0,
"data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0,
"read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0}
}`)
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
if name == "which" {
return nil, nil, nil
}
if name == "ceph" && len(args) > 0 && args[0] == "status" {
return statusJSON, nil, nil
}
if name == "ceph" && len(args) > 0 && args[0] == "df" {
return []byte(`{not-json}`), nil, nil
}
t.Fatalf("unexpected command: %s %v", name, args)
return nil, nil, nil
})
status, err := Collect(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if status == nil {
t.Fatalf("expected status")
}
if status.PGMap.UsagePercent != 0 {
t.Fatalf("expected usage percent to remain 0, got %v", status.PGMap.UsagePercent)
}
}
func TestCollect_UsagePercentFromDF(t *testing.T) {
statusJSON := []byte(`{
"fsid":"fsid-1",
"health":{"status":"HEALTH_OK","checks":{}},
"monmap":{"epoch":1,"mons":[]},
"mgrmap":{"available":false,"active_name":"","standbys":[]},
"osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0},
"pgmap":{"num_pgs":0,"bytes_total":0,"bytes_used":0,"bytes_avail":0,
"data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0,
"read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0}
}`)
dfJSON := []byte(`{
"stats":{"total_bytes":1000,"total_used_bytes":500,"percent_used":0.5},
"pools":[]
}`)
withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
if name == "which" {
return nil, nil, nil
}
if name == "ceph" && len(args) > 0 && args[0] == "status" {
return statusJSON, nil, nil
}
if name == "ceph" && len(args) > 0 && args[0] == "df" {
return dfJSON, nil, nil
}
t.Fatalf("unexpected command: %s %v", name, args)
return nil, nil, nil
})
status, err := Collect(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if status == nil {
t.Fatalf("expected status")
}
if status.PGMap.UsagePercent != 50 {
t.Fatalf("expected usage percent from df, got %v", status.PGMap.UsagePercent)
}
if status.CollectedAt.IsZero() {
t.Fatalf("expected collected timestamp set")
}
}
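// The df payload above reports percent_used as a ratio (0.5), while the test
// expects a percentage (50). Presumably Collect scales the ratio when it
// prefers the df figure, along the lines of:
//
//	usagePercent := dfStats.PercentUsed * 100 // 0.5 -> 50
//
// and otherwise derives it from bytes_used/bytes_total in the status payload,
// as TestCollect_DFCommandError exercises.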
func TestBoolToInt(t *testing.T) {
if boolToInt(true) != 1 {
t.Fatalf("expected boolToInt(true)=1")

View File

@@ -14,6 +14,16 @@ import (
"github.com/rs/zerolog/log"
)
var defaultDataDirFn = utils.GetDataDir
var legacyKeyPath = "/etc/pulse/.encryption.key"
var randReader = rand.Reader
var newCipher = aes.NewCipher
var newGCM = cipher.NewGCM
// CryptoManager handles encryption/decryption of sensitive data
type CryptoManager struct {
key []byte
@@ -23,7 +33,7 @@ type CryptoManager struct {
// NewCryptoManagerAt creates a new crypto manager with an explicit data directory override.
func NewCryptoManagerAt(dataDir string) (*CryptoManager, error) {
if dataDir == "" {
dataDir = utils.GetDataDir()
dataDir = defaultDataDirFn()
}
keyPath := filepath.Join(dataDir, ".encryption.key")
@@ -41,11 +51,12 @@ func NewCryptoManagerAt(dataDir string) (*CryptoManager, error) {
// getOrCreateKeyAt gets the encryption key or creates one if it doesn't exist
func getOrCreateKeyAt(dataDir string) ([]byte, error) {
if dataDir == "" {
dataDir = utils.GetDataDir()
dataDir = defaultDataDirFn()
}
keyPath := filepath.Join(dataDir, ".encryption.key")
oldKeyPath := "/etc/pulse/.encryption.key"
oldKeyPath := legacyKeyPath
oldKeyDir := filepath.Dir(oldKeyPath)
log.Debug().
Str("dataDir", dataDir).
@@ -78,7 +89,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) {
// Check for key in old location and migrate if found (only if paths differ)
// CRITICAL: This code deletes the encryption key at oldKeyPath after migrating it.
// Extensive logging has been added to diagnose a recurring key-deletion bug.

if dataDir != "/etc/pulse" && keyPath != oldKeyPath {
if dataDir != oldKeyDir && keyPath != oldKeyPath {
log.Warn().
Str("dataDir", dataDir).
Str("keyPath", keyPath).
@@ -135,7 +146,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) {
log.Debug().
Str("dataDir", dataDir).
Str("keyPath", keyPath).
Bool("sameAsOldPath", dataDir == "/etc/pulse").
Bool("sameAsOldPath", dataDir == oldKeyDir).
Msg("Skipping key migration check (dataDir is /etc/pulse or paths match)")
}
@@ -171,7 +182,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) {
// Generate new key (only if no encrypted data exists)
key := make([]byte, 32) // AES-256
if _, err := io.ReadFull(rand.Reader, key); err != nil {
if _, err := io.ReadFull(randReader, key); err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
@@ -205,18 +216,18 @@ func (c *CryptoManager) Encrypt(plaintext []byte) ([]byte, error) {
}
}
block, err := aes.NewCipher(c.key)
block, err := newCipher(c.key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
gcm, err := newGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
if _, err := io.ReadFull(randReader, nonce); err != nil {
return nil, err
}
@@ -226,12 +237,12 @@ func (c *CryptoManager) Encrypt(plaintext []byte) ([]byte, error) {
// Decrypt decrypts data using AES-GCM
func (c *CryptoManager) Decrypt(ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(c.key)
block, err := newCipher(c.key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
gcm, err := newGCM(block)
if err != nil {
return nil, err
}

View File

@@ -2,11 +2,51 @@ package crypto
import (
"bytes"
"crypto/cipher"
"encoding/base64"
"errors"
"io"
"os"
"path/filepath"
"testing"
)
type errReader struct {
err error
}
func (e errReader) Read(p []byte) (int, error) {
return 0, e.err
}
func withDefaultDataDir(t *testing.T, dir string) {
t.Helper()
orig := defaultDataDirFn
defaultDataDirFn = func() string { return dir }
t.Cleanup(func() { defaultDataDirFn = orig })
}
func withLegacyKeyPath(t *testing.T, path string) {
t.Helper()
orig := legacyKeyPath
legacyKeyPath = path
t.Cleanup(func() { legacyKeyPath = orig })
}
func withRandReader(t *testing.T, r io.Reader) {
t.Helper()
orig := randReader
randReader = r
t.Cleanup(func() { randReader = orig })
}
func withNewGCM(t *testing.T, fn func(cipher.Block) (cipher.AEAD, error)) {
t.Helper()
orig := newGCM
newGCM = fn
t.Cleanup(func() { newGCM = orig })
}
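// These helpers all follow the same pattern: swap the package-level seam,
// then restore it via t.Cleanup so no test leaks state into the next. A
// typical composition:
//
//	withDefaultDataDir(t, t.TempDir())
//	withRandReader(t, errReader{err: errors.New("read failed")})
//	// ... exercise the code path that hits the stubbed seams ...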
func TestEncryptDecrypt(t *testing.T) {
// Create a temp directory for the test
tmpDir := t.TempDir()
@@ -149,6 +189,33 @@ func TestEncryptionKeyFilePermissions(t *testing.T) {
}
}
func TestNewCryptoManagerAt_DefaultDataDir(t *testing.T) {
tmpDir := t.TempDir()
withDefaultDataDir(t, tmpDir)
cm, err := NewCryptoManagerAt("")
if err != nil {
t.Fatalf("NewCryptoManagerAt() error: %v", err)
}
if cm.keyPath != filepath.Join(tmpDir, ".encryption.key") {
t.Fatalf("keyPath = %q, want %q", cm.keyPath, filepath.Join(tmpDir, ".encryption.key"))
}
}
func TestNewCryptoManagerAt_KeyError(t *testing.T) {
tmpDir := t.TempDir()
withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
err := os.WriteFile(filepath.Join(tmpDir, "nodes.enc"), []byte("data"), 0600)
if err != nil {
t.Fatalf("Failed to create encrypted file: %v", err)
}
_, err = NewCryptoManagerAt(tmpDir)
if err == nil {
t.Fatal("Expected error when encrypted data exists without a key")
}
}
func TestDecryptInvalidData(t *testing.T) {
tmpDir := t.TempDir()
cm, err := NewCryptoManagerAt(tmpDir)
@@ -175,6 +242,217 @@ func TestDecryptInvalidData(t *testing.T) {
}
}
func TestGetOrCreateKeyAt_InvalidBase64(t *testing.T) {
tmpDir := t.TempDir()
withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
keyPath := filepath.Join(tmpDir, ".encryption.key")
if err := os.WriteFile(keyPath, []byte("not-base64"), 0600); err != nil {
t.Fatalf("Failed to write key file: %v", err)
}
key, err := getOrCreateKeyAt(tmpDir)
if err != nil {
t.Fatalf("getOrCreateKeyAt() error: %v", err)
}
if len(key) != 32 {
t.Fatalf("expected 32-byte key, got %d", len(key))
}
}
func TestGetOrCreateKeyAt_DefaultDataDir(t *testing.T) {
tmpDir := t.TempDir()
withDefaultDataDir(t, tmpDir)
withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
key, err := getOrCreateKeyAt("")
if err != nil {
t.Fatalf("getOrCreateKeyAt() error: %v", err)
}
if len(key) != 32 {
t.Fatalf("expected 32-byte key, got %d", len(key))
}
}
func TestGetOrCreateKeyAt_InvalidLength(t *testing.T) {
tmpDir := t.TempDir()
withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
shortKey := make([]byte, 16)
for i := range shortKey {
shortKey[i] = byte(i)
}
encoded := base64.StdEncoding.EncodeToString(shortKey)
if err := os.WriteFile(filepath.Join(tmpDir, ".encryption.key"), []byte(encoded), 0600); err != nil {
t.Fatalf("Failed to write key file: %v", err)
}
key, err := getOrCreateKeyAt(tmpDir)
if err != nil {
t.Fatalf("getOrCreateKeyAt() error: %v", err)
}
if len(key) != 32 {
t.Fatalf("expected 32-byte key, got %d", len(key))
}
}
func TestGetOrCreateKeyAt_SkipMigrationWhenPathsMatch(t *testing.T) {
tmpDir := t.TempDir()
withLegacyKeyPath(t, filepath.Join(tmpDir, ".encryption.key"))
key, err := getOrCreateKeyAt(tmpDir)
if err != nil {
t.Fatalf("getOrCreateKeyAt() error: %v", err)
}
if len(key) != 32 {
t.Fatalf("expected 32-byte key, got %d", len(key))
}
}
func TestGetOrCreateKeyAt_MigrateSuccess(t *testing.T) {
legacyDir := t.TempDir()
legacyPath := filepath.Join(legacyDir, ".encryption.key")
withLegacyKeyPath(t, legacyPath)
oldKey := make([]byte, 32)
for i := range oldKey {
oldKey[i] = byte(i)
}
encoded := base64.StdEncoding.EncodeToString(oldKey)
if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil {
t.Fatalf("Failed to write legacy key: %v", err)
}
newDir := t.TempDir()
key, err := getOrCreateKeyAt(newDir)
if err != nil {
t.Fatalf("getOrCreateKeyAt() error: %v", err)
}
if !bytes.Equal(key, oldKey) {
t.Fatalf("migrated key mismatch")
}
contents, err := os.ReadFile(filepath.Join(newDir, ".encryption.key"))
if err != nil {
t.Fatalf("Failed to read migrated key: %v", err)
}
if string(contents) != encoded {
t.Fatalf("migrated key contents mismatch")
}
}
func TestGetOrCreateKeyAt_MigrateMkdirError(t *testing.T) {
legacyDir := t.TempDir()
legacyPath := filepath.Join(legacyDir, ".encryption.key")
withLegacyKeyPath(t, legacyPath)
oldKey := make([]byte, 32)
for i := range oldKey {
oldKey[i] = byte(i)
}
encoded := base64.StdEncoding.EncodeToString(oldKey)
if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil {
t.Fatalf("Failed to write legacy key: %v", err)
}
tmpDir := t.TempDir()
dataFile := filepath.Join(tmpDir, "datafile")
if err := os.WriteFile(dataFile, []byte("x"), 0600); err != nil {
t.Fatalf("Failed to write data file: %v", err)
}
key, err := getOrCreateKeyAt(dataFile)
if err != nil {
t.Fatalf("getOrCreateKeyAt() error: %v", err)
}
if !bytes.Equal(key, oldKey) {
t.Fatalf("expected legacy key on mkdir error")
}
}
func TestGetOrCreateKeyAt_MigrateWriteError(t *testing.T) {
legacyDir := t.TempDir()
legacyPath := filepath.Join(legacyDir, ".encryption.key")
withLegacyKeyPath(t, legacyPath)
oldKey := make([]byte, 32)
for i := range oldKey {
oldKey[i] = byte(i)
}
encoded := base64.StdEncoding.EncodeToString(oldKey)
if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil {
t.Fatalf("Failed to write legacy key: %v", err)
}
newDir := t.TempDir()
keyPath := filepath.Join(newDir, ".encryption.key")
if err := os.MkdirAll(keyPath, 0700); err != nil {
t.Fatalf("Failed to create key path dir: %v", err)
}
key, err := getOrCreateKeyAt(newDir)
if err != nil {
t.Fatalf("getOrCreateKeyAt() error: %v", err)
}
if !bytes.Equal(key, oldKey) {
t.Fatalf("expected legacy key on write error")
}
}
func TestGetOrCreateKeyAt_EncryptedDataExists(t *testing.T) {
tmpDir := t.TempDir()
withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
if err := os.WriteFile(filepath.Join(tmpDir, "nodes.enc"), []byte("data"), 0600); err != nil {
t.Fatalf("Failed to write encrypted file: %v", err)
}
_, err := getOrCreateKeyAt(tmpDir)
if err == nil {
t.Fatal("Expected error when encrypted data exists")
}
}
func TestGetOrCreateKeyAt_RandReaderError(t *testing.T) {
tmpDir := t.TempDir()
withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
withRandReader(t, errReader{err: errors.New("read failed")})
_, err := getOrCreateKeyAt(tmpDir)
if err == nil {
t.Fatal("Expected error from rand reader")
}
}
func TestGetOrCreateKeyAt_CreateDirError(t *testing.T) {
tmpDir := t.TempDir()
withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
dataFile := filepath.Join(tmpDir, "datafile")
if err := os.WriteFile(dataFile, []byte("x"), 0600); err != nil {
t.Fatalf("Failed to write data file: %v", err)
}
_, err := getOrCreateKeyAt(dataFile)
if err == nil {
t.Fatal("Expected error when creating directory")
}
}
func TestGetOrCreateKeyAt_SaveKeyError(t *testing.T) {
tmpDir := t.TempDir()
withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
keyPath := filepath.Join(tmpDir, ".encryption.key")
if err := os.MkdirAll(keyPath, 0700); err != nil {
t.Fatalf("Failed to create key path dir: %v", err)
}
_, err := getOrCreateKeyAt(tmpDir)
if err == nil {
t.Fatal("Expected error when saving key")
}
}
func TestDecryptStringInvalidBase64(t *testing.T) {
tmpDir := t.TempDir()
cm, err := NewCryptoManagerAt(tmpDir)
@@ -216,6 +494,57 @@ func TestEncryptionUniqueness(t *testing.T) {
}
}
func TestEncryptInvalidKey(t *testing.T) {
cm := &CryptoManager{key: []byte("short")}
if _, err := cm.Encrypt([]byte("data")); err == nil {
t.Fatal("Expected error for invalid key length")
}
}
func TestDecryptInvalidKey(t *testing.T) {
cm := &CryptoManager{key: []byte("short")}
if _, err := cm.Decrypt([]byte("data")); err == nil {
t.Fatal("Expected error for invalid key length")
}
}
func TestEncryptNonceReadError(t *testing.T) {
withRandReader(t, errReader{err: errors.New("nonce read error")})
cm := &CryptoManager{key: make([]byte, 32)}
if _, err := cm.Encrypt([]byte("data")); err == nil {
t.Fatal("Expected error reading nonce")
}
}
func TestEncryptDecryptGCMError(t *testing.T) {
withNewGCM(t, func(cipher.Block) (cipher.AEAD, error) {
return nil, errors.New("gcm error")
})
cm := &CryptoManager{key: make([]byte, 32)}
if _, err := cm.Encrypt([]byte("data")); err == nil {
t.Fatal("Expected Encrypt error from GCM")
}
if _, err := cm.Decrypt([]byte("data")); err == nil {
t.Fatal("Expected Decrypt error from GCM")
}
}
func TestEncryptStringError(t *testing.T) {
cm := &CryptoManager{key: []byte("short")}
if _, err := cm.EncryptString("data"); err == nil {
t.Fatal("Expected EncryptString error")
}
}
func TestDecryptStringError(t *testing.T) {
cm := &CryptoManager{key: make([]byte, 32)}
encoded := base64.StdEncoding.EncodeToString([]byte("short"))
if _, err := cm.DecryptString(encoded); err == nil {
t.Fatal("Expected DecryptString error")
}
}
func TestNewCryptoManagerRefusesOrphanedData(t *testing.T) {
// Skip if production key exists - migration code will always find and use it
if _, err := os.Stat("/etc/pulse/.encryption.key"); err == nil {

View File

@@ -44,6 +44,28 @@ var (
defaultTimeFmt = time.RFC3339
)
var (
nowFn = time.Now
isTerminalFn = term.IsTerminal
mkdirAllFn = os.MkdirAll
openFileFn = os.OpenFile
openFn = os.Open
statFn = os.Stat
readDirFn = os.ReadDir
renameFn = os.Rename
removeFn = os.Remove
copyFn = io.Copy
gzipNewWriterFn = gzip.NewWriter
statFileFn = func(file *os.File) (os.FileInfo, error) { return file.Stat() }
closeFileFn = func(file *os.File) error { return file.Close() }
compressFn = compressAndRemove
)
var (
defaultStatFileFn = statFileFn
defaultCloseFileFn = closeFileFn
)
func init() {
baseLogger = zerolog.New(baseWriter).With().Timestamp().Logger()
log.Logger = baseLogger
@@ -137,7 +159,7 @@ func isTerminal(file *os.File) bool {
if file == nil {
return false
}
return term.IsTerminal(int(file.Fd()))
return isTerminalFn(int(file.Fd()))
}
type rollingFileWriter struct {
@@ -157,7 +179,7 @@ func newRollingFileWriter(cfg Config) (io.Writer, error) {
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
if err := mkdirAllFn(dir, 0o755); err != nil {
return nil, fmt.Errorf("create log directory: %w", err)
}
@@ -205,13 +227,13 @@ func (w *rollingFileWriter) openOrCreateLocked() error {
return nil
}
file, err := os.OpenFile(w.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
file, err := openFileFn(w.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
if err != nil {
return fmt.Errorf("open log file: %w", err)
}
w.file = file
info, err := file.Stat()
info, err := statFileFn(file)
if err != nil {
w.currentSize = 0
return nil
@@ -225,11 +247,11 @@ func (w *rollingFileWriter) rotateLocked() error {
return err
}
if _, err := os.Stat(w.path); err == nil {
rotated := fmt.Sprintf("%s.%s", w.path, time.Now().Format("20060102-150405"))
if err := os.Rename(w.path, rotated); err == nil {
if _, err := statFn(w.path); err == nil {
rotated := fmt.Sprintf("%s.%s", w.path, nowFn().Format("20060102-150405"))
if err := renameFn(w.path, rotated); err == nil {
if w.compress {
go compressAndRemove(rotated)
go compressFn(rotated)
}
}
}
@@ -242,7 +264,7 @@ func (w *rollingFileWriter) closeLocked() error {
if w.file == nil {
return nil
}
err := w.file.Close()
err := closeFileFn(w.file)
w.file = nil
w.currentSize = 0
return err
@@ -256,9 +278,9 @@ func (w *rollingFileWriter) cleanupOldFiles() {
dir := filepath.Dir(w.path)
base := filepath.Base(w.path)
prefix := base + "."
cutoff := time.Now().Add(-w.maxAge)
cutoff := nowFn().Add(-w.maxAge)
entries, err := os.ReadDir(dir)
entries, err := readDirFn(dir)
if err != nil {
return
}
@@ -273,26 +295,26 @@ func (w *rollingFileWriter) cleanupOldFiles() {
continue
}
if info.ModTime().Before(cutoff) {
_ = os.Remove(filepath.Join(dir, name))
_ = removeFn(filepath.Join(dir, name))
}
}
}
func compressAndRemove(path string) {
in, err := os.Open(path)
in, err := openFn(path)
if err != nil {
return
}
defer in.Close()
outPath := path + ".gz"
out, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
out, err := openFileFn(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
if err != nil {
return
}
gw := gzip.NewWriter(out)
if _, err = io.Copy(gw, in); err != nil {
gw := gzipNewWriterFn(out)
if _, err = copyFn(gw, in); err != nil {
gw.Close()
out.Close()
return
@@ -302,5 +324,5 @@ func compressAndRemove(path string) {
return
}
out.Close()
_ = os.Remove(path)
_ = removeFn(path)
}
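// Note that on a copy or close failure the function returns with the partial
// .gz still on disk. A variant that also discards the half-written archive (a
// sketch, not what this code does) would be:
//
//	if _, err = copyFn(gw, in); err != nil {
//		gw.Close()
//		out.Close()
//		_ = removeFn(outPath) // drop the partial archive
//		return
//	}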

View File

@@ -2,13 +2,19 @@ package logging
import (
"bytes"
"compress/gzip"
"errors"
"io"
"os"
"path/filepath"
"reflect"
"sync"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/term"
)
func resetLoggingState() {
@@ -21,6 +27,20 @@ func resetLoggingState() {
log.Logger = baseLogger
zerolog.TimeFieldFormat = defaultTimeFmt
zerolog.SetGlobalLevel(zerolog.InfoLevel)
nowFn = time.Now
isTerminalFn = term.IsTerminal
mkdirAllFn = os.MkdirAll
openFileFn = os.OpenFile
openFn = os.Open
statFn = os.Stat
readDirFn = os.ReadDir
renameFn = os.Rename
removeFn = os.Remove
copyFn = io.Copy
gzipNewWriterFn = gzip.NewWriter
statFileFn = defaultStatFileFn
closeFileFn = defaultCloseFileFn
compressFn = compressAndRemove
}
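// statFileFn and closeFileFn are restored via the captured defaults
// (defaultStatFileFn, defaultCloseFileFn) because, unlike os.MkdirAll or
// io.Copy, the originals are anonymous closures with no package-level name to
// reassign from.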
func TestInitJSONFormatSetsLevelAndComponent(t *testing.T) {
@@ -267,6 +287,393 @@ func TestSelectWriter(t *testing.T) {
}
}
func TestSelectWriterAutoTerminal(t *testing.T) {
t.Cleanup(resetLoggingState)
isTerminalFn = func(int) bool { return true }
w := selectWriter("auto")
if _, ok := w.(zerolog.ConsoleWriter); !ok {
t.Fatalf("expected console writer, got %#v", w)
}
}
func TestSelectWriterDefault(t *testing.T) {
t.Cleanup(resetLoggingState)
w := selectWriter("unknown")
if w != os.Stderr {
t.Fatalf("expected default writer to be os.Stderr, got %#v", w)
}
}
func TestIsTerminalNil(t *testing.T) {
t.Cleanup(resetLoggingState)
if isTerminal(nil) {
t.Fatal("expected nil file to report false")
}
}
func TestNewRollingFileWriter_EmptyPath(t *testing.T) {
t.Cleanup(resetLoggingState)
writer, err := newRollingFileWriter(Config{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if writer != nil {
t.Fatalf("expected nil writer, got %#v", writer)
}
}
func TestNewRollingFileWriter_MkdirError(t *testing.T) {
t.Cleanup(resetLoggingState)
mkdirAllFn = func(string, os.FileMode) error {
return errors.New("mkdir failed")
}
_, err := newRollingFileWriter(Config{FilePath: "/tmp/logs/test.log"})
if err == nil {
t.Fatal("expected error from mkdir")
}
}
func TestInitFileWriterError(t *testing.T) {
t.Cleanup(resetLoggingState)
mkdirAllFn = func(string, os.FileMode) error {
return errors.New("mkdir failed")
}
Init(Config{
Format: "json",
FilePath: "/tmp/logs/test.log",
})
}
func TestNewRollingFileWriter_DefaultMaxSize(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
writer, err := newRollingFileWriter(Config{
FilePath: filepath.Join(dir, "app.log"),
MaxSizeMB: 0,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
w, ok := writer.(*rollingFileWriter)
if !ok {
t.Fatalf("expected rollingFileWriter, got %#v", writer)
}
if w.maxBytes != 100*1024*1024 {
t.Fatalf("expected default max bytes, got %d", w.maxBytes)
}
_ = w.closeLocked()
}
func TestNewRollingFileWriter_OpenError(t *testing.T) {
t.Cleanup(resetLoggingState)
openFileFn = func(string, int, os.FileMode) (*os.File, error) {
return nil, errors.New("open failed")
}
_, err := newRollingFileWriter(Config{FilePath: filepath.Join(t.TempDir(), "app.log")})
if err == nil {
t.Fatal("expected error from openOrCreateLocked")
}
}
func TestOpenOrCreateLocked_StatError(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
w := &rollingFileWriter{path: filepath.Join(dir, "app.log")}
statFileFn = func(*os.File) (os.FileInfo, error) {
return nil, errors.New("stat failed")
}
if err := w.openOrCreateLocked(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if w.currentSize != 0 {
t.Fatalf("expected current size 0, got %d", w.currentSize)
}
_ = w.closeLocked()
}
func TestOpenOrCreateLocked_AlreadyOpen(t *testing.T) {
t.Cleanup(resetLoggingState)
file, err := os.CreateTemp(t.TempDir(), "log")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
w := &rollingFileWriter{path: file.Name(), file: file}
if err := w.openOrCreateLocked(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
_ = w.closeLocked()
}
func TestRollingFileWriter_WriteOpenError(t *testing.T) {
t.Cleanup(resetLoggingState)
openFileFn = func(string, int, os.FileMode) (*os.File, error) {
return nil, errors.New("open failed")
}
w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log")}
if _, err := w.Write([]byte("data")); err == nil {
t.Fatal("expected write error")
}
}
func TestRollingFileWriter_WriteRotateError(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
path := filepath.Join(dir, "app.log")
callCount := 0
openFileFn = func(name string, flag int, perm os.FileMode) (*os.File, error) {
callCount++
if callCount == 1 {
return os.OpenFile(name, flag, perm)
}
return nil, errors.New("open failed")
}
w := &rollingFileWriter{path: path, maxBytes: 1}
if _, err := w.Write([]byte("too big")); err == nil {
t.Fatal("expected rotate error")
}
}
func TestRollingFileWriter_RotateCompress(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
path := filepath.Join(dir, "app.log")
if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
t.Fatalf("failed to write file: %v", err)
}
w := &rollingFileWriter{path: path, maxBytes: 1, compress: true}
if err := w.openOrCreateLocked(); err != nil {
t.Fatalf("openOrCreateLocked error: %v", err)
}
ch := make(chan string, 1)
compressFn = func(p string) { ch <- p }
if err := w.rotateLocked(); err != nil {
t.Fatalf("rotateLocked error: %v", err)
}
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatal("expected compress to be triggered")
}
_ = w.closeLocked()
}
func TestRotateLockedCloseError(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
path := filepath.Join(dir, "app.log")
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
t.Fatalf("failed to open file: %v", err)
}
w := &rollingFileWriter{path: path, file: file}
closeFileFn = func(*os.File) error {
return errors.New("close failed")
}
if err := w.rotateLocked(); err == nil {
t.Fatal("expected close error")
}
_ = file.Close()
}
func TestCloseLocked(t *testing.T) {
t.Cleanup(resetLoggingState)
w := &rollingFileWriter{}
if err := w.closeLocked(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
file, err := os.CreateTemp(t.TempDir(), "log")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
w.file = file
if err := w.closeLocked(); err != nil {
t.Fatalf("unexpected close error: %v", err)
}
if w.file != nil {
t.Fatal("expected file to be cleared")
}
if w.currentSize != 0 {
t.Fatalf("expected size reset, got %d", w.currentSize)
}
}
func TestCleanupOldFilesNoMaxAge(t *testing.T) {
t.Cleanup(resetLoggingState)
w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: 0}
w.cleanupOldFiles()
}
func TestCleanupOldFilesReadDirError(t *testing.T) {
t.Cleanup(resetLoggingState)
readDirFn = func(string) ([]os.DirEntry, error) {
return nil, errors.New("read dir failed")
}
w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: time.Hour}
w.cleanupOldFiles()
}
func TestCleanupOldFilesInfoError(t *testing.T) {
t.Cleanup(resetLoggingState)
readDirFn = func(string) ([]os.DirEntry, error) {
return []os.DirEntry{errDirEntry{name: "app.log.20200101"}}, nil
}
w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: time.Hour}
w.cleanupOldFiles()
}
func TestCleanupOldFilesRemovesOld(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
path := filepath.Join(dir, "app.log")
oldFile := filepath.Join(dir, "app.log.20200101-000000")
newFile := filepath.Join(dir, "app.log.20250101-000000")
otherFile := filepath.Join(dir, "other.log.20200101")
if err := os.WriteFile(oldFile, []byte("old"), 0600); err != nil {
t.Fatalf("failed to write old file: %v", err)
}
if err := os.WriteFile(newFile, []byte("new"), 0600); err != nil {
t.Fatalf("failed to write new file: %v", err)
}
if err := os.WriteFile(otherFile, []byte("other"), 0600); err != nil {
t.Fatalf("failed to write other file: %v", err)
}
fixedNow := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
nowFn = func() time.Time { return fixedNow }
if err := os.Chtimes(oldFile, fixedNow.Add(-48*time.Hour), fixedNow.Add(-48*time.Hour)); err != nil {
t.Fatalf("failed to set old file time: %v", err)
}
if err := os.Chtimes(newFile, fixedNow.Add(-time.Hour), fixedNow.Add(-time.Hour)); err != nil {
t.Fatalf("failed to set new file time: %v", err)
}
w := &rollingFileWriter{path: path, maxAge: 24 * time.Hour}
w.cleanupOldFiles()
if _, err := os.Stat(oldFile); !os.IsNotExist(err) {
t.Fatalf("expected old file to be removed")
}
if _, err := os.Stat(newFile); err != nil {
t.Fatalf("expected new file to remain: %v", err)
}
if _, err := os.Stat(otherFile); err != nil {
t.Fatalf("expected other file to remain: %v", err)
}
}
func TestStatFileFnDefault(t *testing.T) {
t.Cleanup(resetLoggingState)
file, err := os.CreateTemp(t.TempDir(), "log")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
t.Cleanup(func() { _ = file.Close() })
if _, err := statFileFn(file); err != nil {
t.Fatalf("statFileFn error: %v", err)
}
}
func TestCompressAndRemove(t *testing.T) {
t.Run("OpenError", func(t *testing.T) {
t.Cleanup(resetLoggingState)
openFn = func(string) (*os.File, error) {
return nil, errors.New("open failed")
}
compressAndRemove("/does/not/exist")
})
t.Run("OpenFileError", func(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
path := filepath.Join(dir, "app.log")
if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
t.Fatalf("failed to write file: %v", err)
}
openFileFn = func(string, int, os.FileMode) (*os.File, error) {
return nil, errors.New("open file failed")
}
compressAndRemove(path)
})
t.Run("CopyError", func(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
path := filepath.Join(dir, "app.log")
if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
t.Fatalf("failed to write file: %v", err)
}
copyFn = func(io.Writer, io.Reader) (int64, error) {
return 0, errors.New("copy failed")
}
compressAndRemove(path)
})
t.Run("CloseError", func(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
path := filepath.Join(dir, "app.log")
if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
t.Fatalf("failed to write file: %v", err)
}
errWriter := errWriteCloser{err: errors.New("write failed")}
gzipNewWriterFn = func(io.Writer) *gzip.Writer {
return gzip.NewWriter(errWriter)
}
copyFn = func(io.Writer, io.Reader) (int64, error) {
return 0, nil
}
compressAndRemove(path)
})
t.Run("Success", func(t *testing.T) {
t.Cleanup(resetLoggingState)
dir := t.TempDir()
path := filepath.Join(dir, "app.log")
if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
t.Fatalf("failed to write file: %v", err)
}
compressAndRemove(path)
if _, err := os.Stat(path); !os.IsNotExist(err) {
t.Fatal("expected original file to be removed")
}
if _, err := os.Stat(path + ".gz"); err != nil {
t.Fatalf("expected gzip file to exist: %v", err)
}
})
}
// Test that the logging package doesn't panic under concurrent use
func TestConcurrentLogging(t *testing.T) {
t.Cleanup(resetLoggingState)
@@ -292,3 +699,24 @@ func TestConcurrentLogging(t *testing.T) {
t.Fatal("expected log output from concurrent logging")
}
}
type errDirEntry struct {
name string
}
func (e errDirEntry) Name() string { return e.name }
func (e errDirEntry) IsDir() bool { return false }
func (e errDirEntry) Type() os.FileMode { return 0 }
func (e errDirEntry) Info() (os.FileInfo, error) { return nil, errors.New("info error") }
type errWriteCloser struct {
err error
}
func (e errWriteCloser) Write(p []byte) (int, error) {
return 0, e.err
}

View File

@@ -14,9 +14,13 @@ import (
// Pre-compiled regexes for performance (avoid recompilation on each call)
var (
mdDeviceRe = regexp.MustCompile(`^(md\d+)\s*:`)
slotRe = regexp.MustCompile(`^\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(.+?)\s+(/dev/.+)$`)
speedRe = regexp.MustCompile(`speed=(\S+)`)
mdDeviceRe = regexp.MustCompile(`^(md\d+)\s*:`)
slotRe = regexp.MustCompile(`^\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(.+?)\s+(/dev/.+)$`)
speedRe = regexp.MustCompile(`speed=(\S+)`)
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
return cmd.Output()
}
)
// CollectArrays discovers and collects status for all mdadm RAID arrays on the system.
@@ -56,8 +60,8 @@ func isMdadmAvailable(ctx context.Context) bool {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "mdadm", "--version")
return cmd.Run() == nil
_, err := runCommandOutput(ctx, "mdadm", "--version")
return err == nil
}
// listArrayDevices scans /proc/mdstat to find all md devices
@@ -65,8 +69,7 @@ func listArrayDevices(ctx context.Context) ([]string, error) {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "cat", "/proc/mdstat")
output, err := cmd.Output()
output, err := runCommandOutput(ctx, "cat", "/proc/mdstat")
if err != nil {
return nil, fmt.Errorf("read /proc/mdstat: %w", err)
}
@@ -89,8 +92,7 @@ func collectArrayDetail(ctx context.Context, device string) (host.RAIDArray, err
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "mdadm", "--detail", device)
output, err := cmd.Output()
output, err := runCommandOutput(ctx, "mdadm", "--detail", device)
if err != nil {
return host.RAIDArray{}, fmt.Errorf("mdadm --detail %s: %w", device, err)
}
@@ -220,8 +222,7 @@ func getRebuildSpeed(device string) string {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "cat", "/proc/mdstat")
output, err := cmd.Output()
output, err := runCommandOutput(ctx, "cat", "/proc/mdstat")
if err != nil {
return ""
}

View File

@@ -1,11 +1,20 @@
package mdadm
import (
"context"
"errors"
"testing"
"github.com/rcourtman/pulse-go-rewrite/pkg/agents/host"
)
func withRunCommandOutput(t *testing.T, fn func(ctx context.Context, name string, args ...string) ([]byte, error)) {
t.Helper()
orig := runCommandOutput
runCommandOutput = fn
t.Cleanup(func() { runCommandOutput = orig })
}
func TestParseDetail(t *testing.T) {
tests := []struct {
name string
@@ -540,3 +549,239 @@ Consistency Policy : resync
})
}
}
func TestIsMdadmAvailable(t *testing.T) {
t.Run("available", func(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return []byte("mdadm"), nil
})
if !isMdadmAvailable(context.Background()) {
t.Fatal("expected mdadm available")
}
})
t.Run("missing", func(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return nil, errors.New("missing")
})
if isMdadmAvailable(context.Background()) {
t.Fatal("expected mdadm unavailable")
}
})
}
func TestListArrayDevices(t *testing.T) {
mdstat := `Personalities : [raid1] [raid6]
md0 : active raid1 sdb1[1] sda1[0]
md1 : active raid6 sdc1[2] sdb1[1] sda1[0]
unused devices: <none>`
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return []byte(mdstat), nil
})
devices, err := listArrayDevices(context.Background())
if err != nil {
t.Fatalf("listArrayDevices error: %v", err)
}
if len(devices) != 2 || devices[0] != "/dev/md0" || devices[1] != "/dev/md1" {
t.Fatalf("unexpected devices: %v", devices)
}
}
func TestListArrayDevicesError(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return nil, errors.New("read failed")
})
if _, err := listArrayDevices(context.Background()); err == nil {
t.Fatal("expected error")
}
}
func TestCollectArrayDetailError(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return nil, errors.New("detail failed")
})
if _, err := collectArrayDetail(context.Background(), "/dev/md0"); err == nil {
t.Fatal("expected error")
}
}
func TestCollectArraysNotAvailable(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return nil, errors.New("missing")
})
arrays, err := CollectArrays(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if arrays != nil {
t.Fatalf("expected nil arrays, got %v", arrays)
}
}
func TestCollectArraysListError(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
if name == "mdadm" {
return []byte("mdadm"), nil
}
return nil, errors.New("read failed")
})
if _, err := CollectArrays(context.Background()); err == nil {
t.Fatal("expected error")
}
}
func TestCollectArraysNoDevices(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
if name == "mdadm" {
return []byte("mdadm"), nil
}
return []byte("unused devices: <none>"), nil
})
arrays, err := CollectArrays(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if arrays != nil {
t.Fatalf("expected nil arrays, got %v", arrays)
}
}
func TestCollectArraysSkipsDetailError(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
switch name {
case "mdadm":
if len(args) > 0 && args[0] == "--version" {
return []byte("mdadm"), nil
}
return nil, errors.New("detail failed")
case "cat":
return []byte("md0 : active raid1 sda1[0]"), nil
default:
return nil, errors.New("unexpected")
}
})
arrays, err := CollectArrays(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(arrays) != 0 {
t.Fatalf("expected empty arrays, got %v", arrays)
}
}
func TestCollectArraysSuccess(t *testing.T) {
detail := `/dev/md0:
Raid Level : raid1
State : clean
Total Devices : 2
Active Devices : 2
Working Devices : 2
Failed Devices : 0
Spare Devices : 0
Number Major Minor RaidDevice State
0 8 1 0 active sync /dev/sda1`
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
switch name {
case "mdadm":
if len(args) > 0 && args[0] == "--version" {
return []byte("mdadm"), nil
}
return []byte(detail), nil
case "cat":
return []byte("md0 : active raid1 sda1[0]"), nil
default:
return nil, errors.New("unexpected")
}
})
arrays, err := CollectArrays(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(arrays) != 1 || arrays[0].Device != "/dev/md0" {
t.Fatalf("unexpected arrays: %v", arrays)
}
}
func TestGetRebuildSpeed(t *testing.T) {
mdstat := `md0 : active raid1 sda1[0] sdb1[1]
[>....................] recovery = 12.6% (37043392/293039104) finish=127.5min speed=33440K/sec
`
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return []byte(mdstat), nil
})
if speed := getRebuildSpeed("/dev/md0"); speed != "33440K/sec" {
t.Fatalf("unexpected speed: %s", speed)
}
}
func TestGetRebuildSpeedNoMatch(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return []byte("md0 : active raid1 sda1[0]"), nil
})
if speed := getRebuildSpeed("/dev/md0"); speed != "" {
t.Fatalf("expected empty speed, got %s", speed)
}
}
func TestGetRebuildSpeedError(t *testing.T) {
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return nil, errors.New("read failed")
})
if speed := getRebuildSpeed("/dev/md0"); speed != "" {
t.Fatalf("expected empty speed, got %s", speed)
}
}
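// TestParseDetailSetsRebuildSpeed checks that a "Rebuild Status" line in
// mdadm detail output triggers a follow-up speed lookup against mdstat.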
func TestParseDetailSetsRebuildSpeed(t *testing.T) {
output := `/dev/md0:
Raid Level : raid1
State : clean
Rebuild Status : 12% complete
Number Major Minor RaidDevice State
0 8 1 0 active sync /dev/sda1`
mdstat := `md0 : active raid1 sda1[0]
[>....................] recovery = 12.6% (37043392/293039104) finish=127.5min speed=1234K/sec
`
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return []byte(mdstat), nil
})
array, err := parseDetail("/dev/md0", output)
if err != nil {
t.Fatalf("parseDetail error: %v", err)
}
if array.RebuildSpeed != "1234K/sec" {
t.Fatalf("expected rebuild speed, got %s", array.RebuildSpeed)
}
}
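// TestGetRebuildSpeedSectionExit checks that the mdstat scan stops at the
// next array section rather than picking up another device's speed.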
func TestGetRebuildSpeedSectionExit(t *testing.T) {
mdstat := `md0 : active raid1 sda1[0]
[>....................] recovery = 12.6% (37043392/293039104) finish=127.5min
md1 : active raid1 sdb1[0]
`
withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
return []byte(mdstat), nil
})
if speed := getRebuildSpeed("/dev/md0"); speed != "" {
t.Fatalf("expected empty speed, got %s", speed)
}
}

View File

@@ -0,0 +1,578 @@
package resources
import (
"strings"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/models"
)
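// TestFromNodeTemperatureAndCluster covers the CPU package temperature path,
// the averaged-core fallback, and cluster ID derivation for cluster members.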
func TestFromNodeTemperatureAndCluster(t *testing.T) {
now := time.Now()
node := models.Node{
ID: "node-1",
Name: "node-1",
Instance: "pve1",
Status: "online",
CPU: 0.5,
Memory: models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
Disk: models.Disk{Total: 200, Used: 100, Free: 100, Usage: 50},
Uptime: 10,
LastSeen: now,
IsClusterMember: true,
ClusterName: "cluster-a",
Temperature: &models.Temperature{
Available: true,
HasCPU: true,
CPUPackage: 72.5,
},
}
r := FromNode(node)
if r.Temperature == nil || *r.Temperature != 72.5 {
t.Fatalf("expected CPU package temperature")
}
if r.ClusterID != "pve-cluster/cluster-a" {
t.Fatalf("expected cluster ID")
}
node2 := node
node2.ID = "node-2"
node2.ClusterName = ""
node2.IsClusterMember = false
node2.Temperature = &models.Temperature{
Available: true,
HasCPU: true,
Cores: []models.CoreTemp{
{Core: 0, Temp: 60},
{Core: 1, Temp: 80},
},
}
r2 := FromNode(node2)
if r2.Temperature == nil || *r2.Temperature != 70 {
t.Fatalf("expected averaged core temperature")
}
if r2.ClusterID != "" {
t.Fatalf("expected empty cluster ID when not a member")
}
}
func TestFromHostAndDockerHost(t *testing.T) {
now := time.Now()
host := models.Host{
ID: "host-1",
Hostname: "host-1",
Status: "degraded",
Memory: models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
Disks: []models.Disk{
{Total: 100, Used: 60, Free: 40, Usage: 60},
},
NetworkInterfaces: []models.HostNetworkInterface{
{Addresses: []string{"10.0.0.1"}, RXBytes: 10, TXBytes: 20},
},
Sensors: models.HostSensorSummary{
TemperatureCelsius: map[string]float64{"cpu": 55},
},
LastSeen: now,
}
hr := FromHost(host)
if hr.Temperature == nil || *hr.Temperature != 55 {
t.Fatalf("expected host temperature")
}
if hr.Network == nil || hr.Network.RXBytes != 10 || hr.Network.TXBytes != 20 {
t.Fatalf("expected host network totals")
}
if hr.Disk == nil || hr.Disk.Current == 0 {
t.Fatalf("expected host disk metrics")
}
if hr.Status != StatusDegraded {
t.Fatalf("expected degraded host status")
}
dockerHost := models.DockerHost{
ID: "docker-1",
AgentID: "agent-1",
Hostname: "docker-1",
DisplayName: "docker-1",
CustomDisplayName: "custom",
Status: "offline",
Memory: models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
Disks: []models.Disk{
{Total: 100, Used: 50, Free: 50, Usage: 50},
},
NetworkInterfaces: []models.HostNetworkInterface{
{Addresses: []string{"10.0.0.2"}, RXBytes: 1, TXBytes: 2},
},
Swarm: &models.DockerSwarmInfo{
ClusterID: "swarm-1",
},
LastSeen: now,
}
dr := FromDockerHost(dockerHost)
if dr.DisplayName != "custom" {
t.Fatalf("expected custom display name")
}
if dr.ClusterID != "docker-swarm/swarm-1" {
t.Fatalf("expected docker swarm cluster ID")
}
if dr.Status != StatusOffline {
t.Fatalf("expected docker host status offline")
}
if dr.Identity == nil || len(dr.Identity.IPs) != 1 || dr.Identity.IPs[0] != "10.0.0.2" {
t.Fatalf("expected docker host identity IPs")
}
}
func TestFromHostPlatformData(t *testing.T) {
now := time.Now()
host := models.Host{
ID: "host-2",
Hostname: "host-2",
Status: "online",
DiskIO: []models.DiskIO{
{Device: "sda", ReadBytes: 100, WriteBytes: 200, ReadOps: 1, WriteOps: 2, ReadTime: 3, WriteTime: 4, IOTime: 5},
},
RAID: []models.HostRAIDArray{
{
Device: "md0",
Level: "raid1",
State: "clean",
TotalDevices: 2,
ActiveDevices: 2,
WorkingDevices: 2,
FailedDevices: 0,
SpareDevices: 0,
Devices: []models.HostRAIDDevice{
{Device: "sda1", State: "active", Slot: 0},
},
RebuildPercent: 0,
},
},
LastSeen: now,
}
r := FromHost(host)
var pd HostPlatformData
if err := r.GetPlatformData(&pd); err != nil {
t.Fatalf("failed to get platform data: %v", err)
}
if len(pd.DiskIO) != 1 || len(pd.RAID) != 1 {
t.Fatalf("expected disk IO and RAID entries")
}
}
func TestFromDockerContainerPodman(t *testing.T) {
now := time.Now()
container := models.DockerContainer{
ID: "container-1",
Name: "app",
Image: "app:latest",
State: "paused",
Status: "Paused",
CPUPercent: 5,
MemoryUsage: 50,
MemoryLimit: 100,
MemoryPercent: 50,
UptimeSeconds: 10,
Ports: []models.DockerContainerPort{
{PrivatePort: 80, PublicPort: 8080, Protocol: "tcp", IP: "0.0.0.0"},
},
Networks: []models.DockerContainerNetworkLink{
{Name: "bridge", IPv4: "172.17.0.2"},
},
Podman: &models.DockerPodmanContainer{
PodName: "pod",
},
CreatedAt: now,
}
r := FromDockerContainer(container, "host-1", "host-1")
if r.Status != StatusPaused {
t.Fatalf("expected paused status")
}
if r.Memory == nil || r.Memory.Current != 50 {
t.Fatalf("expected container memory metrics")
}
if r.ID != "host-1/container-1" {
t.Fatalf("expected container resource ID to be host-1/container-1")
}
}
func TestFromVMWithBackup(t *testing.T) {
now := time.Now()
vm := models.VM{
ID: "vm-1",
VMID: 100,
Name: "vm-1",
Node: "node-1",
Instance: "pve1",
Status: "running",
CPU: 0.25,
LastBackup: now,
LastSeen: now,
}
r := FromVM(vm)
var pd VMPlatformData
if err := r.GetPlatformData(&pd); err != nil {
t.Fatalf("failed to get platform data: %v", err)
}
if pd.LastBackup == nil || !pd.LastBackup.Equal(now) {
t.Fatalf("expected last backup to be set")
}
}
func TestFromContainerOCI(t *testing.T) {
now := time.Now()
ct := models.Container{
ID: "ct-1",
VMID: 200,
Name: "ct-1",
Node: "node-1",
Instance: "pve1",
Status: "paused",
Type: "lxc",
IsOCI: true,
CPU: 0.1,
Memory: models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
Disk: models.Disk{Total: 200, Used: 100, Free: 100, Usage: 50},
LastBackup: now,
LastSeen: now,
}
r := FromContainer(ct)
if r.Type != ResourceTypeOCIContainer {
t.Fatalf("expected OCI container type")
}
var pd ContainerPlatformData
if err := r.GetPlatformData(&pd); err != nil {
t.Fatalf("failed to get platform data: %v", err)
}
if !pd.IsOCI || pd.LastBackup == nil {
t.Fatalf("expected OCI platform data with backup")
}
}
func TestKubernetesConversions(t *testing.T) {
now := time.Now()
cluster := models.KubernetesCluster{
ID: "cluster-1",
AgentID: "agent-1",
CustomDisplayName: "custom",
Status: "online",
LastSeen: now,
Nodes: []models.KubernetesNode{
{Name: "node-1", Ready: true, Unschedulable: true},
},
Pods: []models.KubernetesPod{
{Name: "pod-1", Namespace: "default", Phase: "Succeeded", NodeName: "node-1"},
},
Deployments: []models.KubernetesDeployment{
{Name: "dep-1", Namespace: "default", DesiredReplicas: 1, AvailableReplicas: 1},
},
}
cr := FromKubernetesCluster(cluster)
if cr.Name != "custom" || cr.DisplayName != "custom" {
t.Fatalf("expected custom display name to be used")
}
node := cluster.Nodes[0]
nr := FromKubernetesNode(node, cluster)
if nr.Status != StatusDegraded {
t.Fatalf("expected unschedulable node to be degraded")
}
if !strings.Contains(nr.ID, "cluster-1/node/") {
t.Fatalf("expected node ID to include cluster and node")
}
pod := cluster.Pods[0]
pr := FromKubernetesPod(pod, cluster)
if pr.Status != StatusStopped {
t.Fatalf("expected succeeded pod to be stopped")
}
if !strings.Contains(pr.ParentID, "cluster-1/node/") {
t.Fatalf("expected pod parent to be node")
}
dep := cluster.Deployments[0]
dr := FromKubernetesDeployment(dep, cluster)
if dr.Status != StatusRunning {
t.Fatalf("expected deployment running")
}
emptyCluster := models.KubernetesCluster{
ID: "cluster-2",
AgentID: "agent-2",
Status: "offline",
LastSeen: now,
}
er := FromKubernetesCluster(emptyCluster)
if er.Name != "cluster-2" || er.DisplayName != "cluster-2" {
t.Fatalf("expected cluster name fallback to ID")
}
if er.Status != StatusOffline {
t.Fatalf("expected offline cluster status")
}
}
func TestKubernetesNodeNotReady(t *testing.T) {
now := time.Now()
cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}
node := models.KubernetesNode{
UID: "node-uid",
Name: "node-offline",
Ready: false,
}
r := FromKubernetesNode(node, cluster)
if r.Status != StatusOffline {
t.Fatalf("expected offline node status")
}
if !strings.Contains(r.ID, "node-uid") {
t.Fatalf("expected node UID in resource ID")
}
}
func TestKubernetesPodParentCluster(t *testing.T) {
now := time.Now()
cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}
pod := models.KubernetesPod{
UID: "pod-uid",
Name: "pod-1",
Namespace: "default",
Phase: "Pending",
}
r := FromKubernetesPod(pod, cluster)
if r.ParentID != cluster.ID {
t.Fatalf("expected pod parent to be cluster ID")
}
if !strings.Contains(r.ID, "pod-uid") {
t.Fatalf("expected pod UID in resource ID")
}
}
func TestKubernetesPodContainers(t *testing.T) {
now := time.Now()
cluster := models.KubernetesCluster{ID: "cluster-3", AgentID: "agent-3", LastSeen: now}
pod := models.KubernetesPod{
Name: "pod-2",
Namespace: "default",
NodeName: "node-2",
Phase: "Running",
Containers: []models.KubernetesPodContainer{
{Name: "c1", Image: "busybox", Ready: true, RestartCount: 1, State: "running"},
},
}
r := FromKubernetesPod(pod, cluster)
if r.Status != StatusRunning {
t.Fatalf("expected running pod status")
}
if !strings.Contains(r.ID, "cluster-3/pod/") {
t.Fatalf("expected pod ID to include cluster")
}
}
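// TestKubernetesDeploymentStatuses covers the remaining deployment mappings:
// zero desired replicas is stopped, partial availability is degraded, and no
// available replicas is unknown.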
func TestKubernetesDeploymentStatuses(t *testing.T) {
now := time.Now()
cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}
depStopped := FromKubernetesDeployment(models.KubernetesDeployment{
Name: "dep-stop",
Namespace: "default",
DesiredReplicas: 0,
}, cluster)
if depStopped.Status != StatusStopped {
t.Fatalf("expected stopped deployment")
}
depDegraded := FromKubernetesDeployment(models.KubernetesDeployment{
Name: "dep-degraded",
Namespace: "default",
DesiredReplicas: 3,
AvailableReplicas: 1,
}, cluster)
if depDegraded.Status != StatusDegraded {
t.Fatalf("expected degraded deployment")
}
depUnknown := FromKubernetesDeployment(models.KubernetesDeployment{
Name: "dep-unknown",
Namespace: "default",
DesiredReplicas: 2,
AvailableReplicas: 0,
}, cluster)
if depUnknown.Status != StatusUnknown {
t.Fatalf("expected unknown deployment")
}
}
func TestFromPBSInstanceAndStorage(t *testing.T) {
now := time.Now()
pbs := models.PBSInstance{
ID: "pbs-1",
Name: "pbs",
Host: "pbs.local",
Status: "online",
ConnectionHealth: "unhealthy",
CPU: 20,
Memory: 50,
MemoryTotal: 100,
MemoryUsed: 50,
Uptime: 10,
LastSeen: now,
}
pr := FromPBSInstance(pbs)
if pr.Status != StatusDegraded {
t.Fatalf("expected degraded status for unhealthy pbs")
}
pbs2 := pbs
pbs2.ID = "pbs-2"
pbs2.ConnectionHealth = "healthy"
pbs2.Status = "offline"
pr2 := FromPBSInstance(pbs2)
if pr2.Status != StatusOffline {
t.Fatalf("expected offline status for pbs")
}
storageOnline := FromStorage(models.Storage{
ID: "storage-1",
Name: "local",
Instance: "pve1",
Node: "node1",
Total: 100,
Used: 50,
Free: 50,
Usage: 50,
Active: true,
Enabled: true,
})
if storageOnline.Status != StatusOnline {
t.Fatalf("expected online storage")
}
storageStopped := FromStorage(models.Storage{
ID: "storage-2",
Name: "local",
Instance: "pve1",
Node: "node1",
Active: true,
Enabled: false,
})
if storageStopped.Status != StatusStopped {
t.Fatalf("expected stopped storage")
}
storageOffline := FromStorage(models.Storage{
ID: "storage-3",
Name: "local",
Instance: "pve1",
Node: "node1",
Active: false,
Enabled: true,
})
if storageOffline.Status != StatusOffline {
t.Fatalf("expected offline storage")
}
}
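// TestStatusMappings exercises the status translation tables for guests,
// hosts, Docker hosts and containers, Kubernetes clusters and pods, and PBS.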
func TestStatusMappings(t *testing.T) {
if mapGuestStatus("running") != StatusRunning {
t.Fatalf("expected running guest status")
}
if mapGuestStatus("stopped") != StatusStopped {
t.Fatalf("expected stopped guest status")
}
if mapGuestStatus("paused") != StatusPaused {
t.Fatalf("expected paused guest status")
}
if mapGuestStatus("unknown") != StatusUnknown {
t.Fatalf("expected unknown guest status")
}
if mapHostStatus("online") != StatusOnline {
t.Fatalf("expected online host status")
}
if mapHostStatus("offline") != StatusOffline {
t.Fatalf("expected offline host status")
}
if mapHostStatus("degraded") != StatusDegraded {
t.Fatalf("expected degraded host status")
}
if mapHostStatus("unknown") != StatusUnknown {
t.Fatalf("expected unknown host status")
}
if mapDockerHostStatus("online") != StatusOnline {
t.Fatalf("expected online docker host status")
}
if mapDockerHostStatus("offline") != StatusOffline {
t.Fatalf("expected offline docker host status")
}
if mapDockerHostStatus("unknown") != StatusUnknown {
t.Fatalf("expected unknown docker host status")
}
if mapDockerContainerStatus("running") != StatusRunning {
t.Fatalf("expected running container status")
}
if mapDockerContainerStatus("exited") != StatusStopped {
t.Fatalf("expected exited container status")
}
if mapDockerContainerStatus("dead") != StatusStopped {
t.Fatalf("expected dead container status")
}
if mapDockerContainerStatus("paused") != StatusPaused {
t.Fatalf("expected paused container status")
}
if mapDockerContainerStatus("restarting") != StatusUnknown {
t.Fatalf("expected restarting container status unknown")
}
if mapDockerContainerStatus("created") != StatusUnknown {
t.Fatalf("expected created container status unknown")
}
if mapDockerContainerStatus("other") != StatusUnknown {
t.Fatalf("expected unknown container status")
}
if mapKubernetesClusterStatus(" online ") != StatusOnline {
t.Fatalf("expected online k8s cluster status")
}
if mapKubernetesClusterStatus("offline") != StatusOffline {
t.Fatalf("expected offline k8s cluster status")
}
if mapKubernetesClusterStatus("unknown") != StatusUnknown {
t.Fatalf("expected unknown k8s cluster status")
}
if mapKubernetesPodStatus("running") != StatusRunning {
t.Fatalf("expected running pod status")
}
if mapKubernetesPodStatus("succeeded") != StatusStopped {
t.Fatalf("expected succeeded pod status")
}
if mapKubernetesPodStatus("failed") != StatusStopped {
t.Fatalf("expected failed pod status")
}
if mapKubernetesPodStatus("pending") != StatusUnknown {
t.Fatalf("expected pending pod status")
}
if mapKubernetesPodStatus("unknown") != StatusUnknown {
t.Fatalf("expected unknown pod status")
}
if mapPBSStatus("online", "healthy") != StatusOnline {
t.Fatalf("expected online PBS status")
}
if mapPBSStatus("offline", "healthy") != StatusOffline {
t.Fatalf("expected offline PBS status")
}
if mapPBSStatus("other", "healthy") != StatusUnknown {
t.Fatalf("expected unknown PBS status")
}
if mapPBSStatus("online", "unhealthy") != StatusDegraded {
t.Fatalf("expected degraded PBS status")
}
}

View File

@@ -0,0 +1,834 @@
package resources
import (
"strings"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/models"
)
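// TestResourcePlatformDataAndMetrics covers platform data round-tripping,
// display name fallback, the CPU/memory/disk percentage helpers, and the
// infrastructure/workload classification.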
func TestResourcePlatformDataAndMetrics(t *testing.T) {
var out struct {
Value string `json:"value"`
}
r := Resource{Name: "name"}
if err := r.GetPlatformData(&out); err != nil {
t.Fatalf("expected nil error for empty platform data, got %v", err)
}
if err := r.SetPlatformData(make(chan int)); err == nil {
t.Fatal("expected error for unserializable platform data")
}
if err := r.SetPlatformData(struct {
Value string `json:"value"`
}{Value: "ok"}); err != nil {
t.Fatalf("unexpected error setting platform data: %v", err)
}
if err := r.GetPlatformData(&out); err != nil {
t.Fatalf("unexpected error getting platform data: %v", err)
}
if out.Value != "ok" {
t.Fatalf("expected platform data value to be ok, got %q", out.Value)
}
r.PlatformData = []byte("{")
if err := r.GetPlatformData(&out); err == nil {
t.Fatal("expected error for invalid platform data")
}
r.DisplayName = "display"
if r.EffectiveDisplayName() != "display" {
t.Fatalf("expected display name to be used")
}
r.DisplayName = ""
if r.EffectiveDisplayName() != "name" {
t.Fatalf("expected name fallback")
}
if r.CPUPercent() != 0 {
t.Fatalf("expected CPUPercent 0 for nil CPU")
}
r.CPU = &MetricValue{Current: 12.5}
if r.CPUPercent() != 12.5 {
t.Fatalf("expected CPUPercent 12.5")
}
if r.MemoryPercent() != 0 {
t.Fatalf("expected MemoryPercent 0 for nil memory")
}
total := int64(100)
used := int64(25)
r.Memory = &MetricValue{Total: &total, Used: &used}
if r.MemoryPercent() != 25 {
t.Fatalf("expected MemoryPercent 25")
}
r.Memory = &MetricValue{Current: 40}
if r.MemoryPercent() != 40 {
t.Fatalf("expected MemoryPercent 40")
}
if r.DiskPercent() != 0 {
t.Fatalf("expected DiskPercent 0 for nil disk")
}
totalDisk := int64(200)
usedDisk := int64(50)
r.Disk = &MetricValue{Total: &totalDisk, Used: &usedDisk}
if r.DiskPercent() != 25 {
t.Fatalf("expected DiskPercent 25")
}
r.Disk = &MetricValue{Current: 33}
if r.DiskPercent() != 33 {
t.Fatalf("expected DiskPercent 33")
}
r.Type = ResourceTypeNode
if !r.IsInfrastructure() {
t.Fatalf("expected node to be infrastructure")
}
if r.IsWorkload() {
t.Fatalf("expected node to not be workload")
}
r.Type = ResourceTypeVM
if r.IsInfrastructure() {
t.Fatalf("expected vm to not be infrastructure")
}
if !r.IsWorkload() {
t.Fatalf("expected vm to be workload")
}
}
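// TestStorePreferredAndSuppressed verifies that an agent-sourced resource is
// preferred over an API-sourced duplicate of the same host, that the losing
// ID is suppressed and mapped to the winner, and that suppression is cleared
// once the preferred resource is removed.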
func TestStorePreferredAndSuppressed(t *testing.T) {
store := NewStore()
now := time.Now()
existing := Resource{
ID: "agent-1",
Type: ResourceTypeHost,
SourceType: SourceAgent,
LastSeen: now,
Identity: &ResourceIdentity{
Hostname: "host1",
},
}
store.Upsert(existing)
incoming := Resource{
ID: "api-1",
Type: ResourceTypeHost,
SourceType: SourceAPI,
LastSeen: now.Add(1 * time.Minute),
Identity: &ResourceIdentity{
Hostname: "host1",
},
}
preferred := store.Upsert(incoming)
if preferred != existing.ID {
t.Fatalf("expected preferred ID to be %s, got %s", existing.ID, preferred)
}
if !store.IsSuppressed(incoming.ID) {
t.Fatalf("expected incoming to be suppressed")
}
if store.GetPreferredID(incoming.ID) != existing.ID {
t.Fatalf("expected preferred ID to map to %s", existing.ID)
}
if store.GetPreferredID(existing.ID) != existing.ID {
t.Fatalf("expected preferred ID to return itself")
}
got, ok := store.Get(incoming.ID)
if !ok || got.ID != existing.ID {
t.Fatalf("expected Get to return preferred resource")
}
if store.GetPreferredResourceFor(incoming.ID) == nil {
t.Fatalf("expected preferred resource for suppressed ID")
}
if store.GetPreferredResourceFor(existing.ID) == nil {
t.Fatalf("expected preferred resource for existing ID")
}
if store.GetPreferredResourceFor("missing") != nil {
t.Fatalf("expected nil for missing resource")
}
if !store.IsSamePhysicalMachine(existing.ID, incoming.ID) {
t.Fatalf("expected IDs to be same physical machine")
}
if !store.IsSamePhysicalMachine(incoming.ID, existing.ID) {
t.Fatalf("expected merged ID to match preferred")
}
if store.IsSamePhysicalMachine(existing.ID, "other") {
t.Fatalf("expected different IDs to not match")
}
if !store.IsSamePhysicalMachine(existing.ID, existing.ID) {
t.Fatalf("expected same ID to match")
}
store.Remove(existing.ID)
if store.IsSuppressed(incoming.ID) {
t.Fatalf("expected suppression to be cleared after removal")
}
if _, ok := store.Get(incoming.ID); ok {
t.Fatalf("expected Get to fail for removed preferred resource")
}
}
func TestStoreStatsAndHelpers(t *testing.T) {
store := NewStore()
now := time.Now()
store.Upsert(Resource{
ID: "host-1",
Type: ResourceTypeHost,
Status: StatusOffline,
SourceType: SourceAgent,
LastSeen: now,
Alerts: []ResourceAlert{{ID: "a1"}},
Identity: &ResourceIdentity{
Hostname: "host-1",
},
})
store.Upsert(Resource{
ID: "host-2",
Type: ResourceTypeHost,
Status: StatusOnline,
SourceType: SourceAPI,
LastSeen: now.Add(time.Minute),
Identity: &ResourceIdentity{
Hostname: "host-1",
},
})
stats := store.GetStats()
if stats.SuppressedResources != 1 {
t.Fatalf("expected 1 suppressed resource")
}
if stats.WithAlerts != 1 {
t.Fatalf("expected 1 resource with alerts")
}
if store.sourceScore(SourceType("other")) != 0 {
t.Fatalf("expected default source score")
}
a := &Resource{ID: "a", SourceType: SourceAPI, LastSeen: now}
b := &Resource{ID: "b", SourceType: SourceAgent, LastSeen: now.Add(time.Second)}
if store.preferredResource(a, b) != b {
t.Fatalf("expected agent resource to be preferred")
}
c := &Resource{ID: "c", SourceType: SourceAPI, LastSeen: now.Add(time.Second)}
if store.preferredResource(a, c) != c {
t.Fatalf("expected newer resource to be preferred")
}
d := &Resource{ID: "d", SourceType: SourceAgent, LastSeen: now}
e := &Resource{ID: "e", SourceType: SourceAPI, LastSeen: now}
if store.preferredResource(d, e) != d {
t.Fatalf("expected higher score resource to be preferred")
}
f := &Resource{ID: "f", SourceType: SourceAPI, LastSeen: now.Add(2 * time.Second)}
g := &Resource{ID: "g", SourceType: SourceAPI, LastSeen: now.Add(1 * time.Second)}
if store.preferredResource(f, g) != f {
t.Fatalf("expected newer resource to be preferred when scores equal")
}
if store.findDuplicate(&Resource{}) != "" {
t.Fatalf("expected no duplicate for nil identity")
}
if store.findDuplicate(&Resource{Type: ResourceTypeVM, Identity: &ResourceIdentity{Hostname: "host-1"}}) != "" {
t.Fatalf("expected no duplicate for workload type")
}
}
func TestStorePreferredSourceAndPolling(t *testing.T) {
store := NewStore()
now := time.Now()
store.Upsert(Resource{
ID: "api-1",
Type: ResourceTypeHost,
SourceType: SourceAPI,
LastSeen: now,
Identity: &ResourceIdentity{
Hostname: "host1",
},
})
if store.HasPreferredSourceForHostname("host1") {
t.Fatalf("expected no preferred source for API-only hostname")
}
store.Upsert(Resource{
ID: "hybrid-1",
Type: ResourceTypeHost,
SourceType: SourceHybrid,
LastSeen: now.Add(time.Second),
Identity: &ResourceIdentity{
Hostname: "host1",
},
})
if !store.HasPreferredSourceForHostname("HOST1") {
t.Fatalf("expected preferred source for hostname")
}
if store.HasPreferredSourceForHostname("") {
t.Fatalf("expected empty hostname to be false")
}
if !store.ShouldSkipAPIPolling("host1") {
t.Fatalf("expected skip polling for preferred source")
}
if store.HasPreferredSourceForHostname("missing") {
t.Fatalf("expected missing hostname to be false")
}
}
func TestStoreAgentHostnamesAndRecommendations(t *testing.T) {
store := NewStore()
now := time.Now()
store.Upsert(Resource{
ID: "agent-1",
Type: ResourceTypeHost,
SourceType: SourceAgent,
LastSeen: now,
Identity: &ResourceIdentity{
Hostname: "host1",
},
})
store.Upsert(Resource{
ID: "hybrid-1",
Type: ResourceTypeHost,
SourceType: SourceHybrid,
LastSeen: now,
Identity: &ResourceIdentity{
Hostname: "HOST2",
},
})
store.Upsert(Resource{
ID: "api-1",
Type: ResourceTypeHost,
SourceType: SourceAPI,
LastSeen: now,
Identity: &ResourceIdentity{
Hostname: "host3",
},
})
store.Upsert(Resource{
ID: "no-identity",
Type: ResourceTypeHost,
SourceType: SourceAgent,
LastSeen: now,
})
store.Upsert(Resource{
ID: "node-agent",
Type: ResourceTypeNode,
SourceType: SourceAgent,
LastSeen: now,
Identity: &ResourceIdentity{
Hostname: "host1",
},
})
hostnames := store.GetAgentMonitoredHostnames()
seen := make(map[string]bool)
for _, h := range hostnames {
seen[strings.ToLower(h)] = true
}
if !seen["host1"] || !seen["host2"] {
t.Fatalf("expected host1 and host2 to be monitored, got %v", hostnames)
}
if len(seen) != 2 {
t.Fatalf("expected two unique hostnames")
}
recs := store.GetPollingRecommendations()
if recs["host1"] != 0 {
t.Fatalf("expected host1 recommendation 0, got %v", recs["host1"])
}
if recs["host2"] != 0.5 {
t.Fatalf("expected host2 recommendation 0.5, got %v", recs["host2"])
}
if _, ok := recs["host3"]; ok {
t.Fatalf("did not expect API-only host to have recommendation")
}
}
func TestStoreFindContainerHost(t *testing.T) {
store := NewStore()
now := time.Now()
host := Resource{
ID: "docker-host-1",
Type: ResourceTypeDockerHost,
Name: "docker-host",
SourceType: SourceAgent,
LastSeen: now,
Identity: &ResourceIdentity{
Hostname: "docker1",
},
}
store.Upsert(host)
store.Upsert(Resource{
ID: "docker-host-1/container-1",
Type: ResourceTypeDockerContainer,
Name: "web-app",
ParentID: "docker-host-1",
SourceType: SourceAgent,
LastSeen: now,
})
if store.FindContainerHost("") != "" {
t.Fatalf("expected empty search to return empty")
}
if store.FindContainerHost("missing") != "" {
t.Fatalf("expected missing container to return empty")
}
if store.FindContainerHost("web-app") != "docker1" {
t.Fatalf("expected host name from identity")
}
if store.FindContainerHost("CONTAINER-1") != "docker1" {
t.Fatalf("expected match by ID")
}
if store.FindContainerHost("web") != "docker1" {
t.Fatalf("expected match by substring")
}
store2 := NewStore()
store2.Upsert(Resource{
ID: "container-2",
Type: ResourceTypeDockerContainer,
Name: "db",
ParentID: "missing-host",
SourceType: SourceAgent,
LastSeen: now,
})
if store2.FindContainerHost("db") != "" {
t.Fatalf("expected missing parent to return empty")
}
store3 := NewStore()
store3.Upsert(Resource{
ID: "host-no-identity",
Type: ResourceTypeDockerHost,
Name: "host-name",
SourceType: SourceAgent,
LastSeen: now,
})
store3.Upsert(Resource{
ID: "host-no-identity/container-3",
Type: ResourceTypeDockerContainer,
Name: "cache",
ParentID: "host-no-identity",
SourceType: SourceAgent,
LastSeen: now,
})
if store3.FindContainerHost("cache") != "host-name" {
t.Fatalf("expected fallback to host name")
}
}
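// TestResourceQueryFiltersAndSorting exercises cluster and alert filters,
// name/CPU sorting, and offset/limit pagination on the query builder.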
func TestResourceQueryFiltersAndSorting(t *testing.T) {
store := NewStore()
base := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
store.Upsert(Resource{
ID: "r1",
Type: ResourceTypeVM,
Name: "b",
Status: StatusRunning,
ClusterID: "c1",
CPU: &MetricValue{Current: 10},
Memory: &MetricValue{Current: 60},
Disk: &MetricValue{Current: 20},
LastSeen: base.Add(1 * time.Hour),
Alerts: []ResourceAlert{{ID: "a1"}},
SourceType: SourceAPI,
})
store.Upsert(Resource{
ID: "r2",
Type: ResourceTypeNode,
Name: "a",
Status: StatusOnline,
ClusterID: "c1",
CPU: &MetricValue{Current: 50},
Memory: &MetricValue{Current: 10},
Disk: &MetricValue{Current: 90},
LastSeen: base.Add(2 * time.Hour),
SourceType: SourceAPI,
})
store.Upsert(Resource{
ID: "r3",
Type: ResourceTypeVM,
Name: "c",
Status: StatusOffline,
ClusterID: "c2",
CPU: &MetricValue{Current: 5},
Memory: &MetricValue{Current: 30},
Disk: &MetricValue{Current: 10},
LastSeen: base.Add(3 * time.Hour),
SourceType: SourceAPI,
})
clustered := store.Query().InCluster("c1").Execute()
if len(clustered) != 2 {
t.Fatalf("expected 2 clustered resources, got %d", len(clustered))
}
withAlerts := store.Query().WithAlerts().Execute()
if len(withAlerts) != 1 || withAlerts[0].ID != "r1" {
t.Fatalf("expected only r1 with alerts")
}
sorted := store.Query().SortBy("name", false).Execute()
if len(sorted) < 2 || sorted[0].Name != "a" {
t.Fatalf("expected sorted results by name")
}
limited := store.Query().SortBy("cpu", true).Offset(1).Limit(1).Execute()
if len(limited) != 1 {
t.Fatalf("expected limited results")
}
empty := store.Query().Offset(10).Execute()
if len(empty) != 0 {
t.Fatalf("expected empty results for large offset")
}
}
func TestSortResourcesFields(t *testing.T) {
base := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
resources := []Resource{
{
ID: "r1",
Type: ResourceTypeVM,
Name: "b",
Status: StatusRunning,
CPU: &MetricValue{Current: 20},
Memory: &MetricValue{Current: 40},
Disk: &MetricValue{Current: 30},
LastSeen: base.Add(1 * time.Hour),
},
{
ID: "r2",
Type: ResourceTypeNode,
Name: "a",
Status: StatusOffline,
CPU: &MetricValue{Current: 50},
Memory: &MetricValue{Current: 10},
Disk: &MetricValue{Current: 90},
LastSeen: base.Add(2 * time.Hour),
},
{
ID: "r3",
Type: ResourceTypeContainer,
Name: "c",
Status: StatusDegraded,
CPU: &MetricValue{Current: 5},
Memory: &MetricValue{Current: 80},
Disk: &MetricValue{Current: 10},
LastSeen: base.Add(3 * time.Hour),
},
}
single := []Resource{{ID: "only"}}
sortResources(single, "name", false)
cases := []struct {
field string
desc bool
want string
}{
{"name", false, "r2"},
{"name", true, "r3"},
{"type", false, "r3"},
{"type", true, "r1"},
{"status", false, "r3"},
{"status", true, "r1"},
{"cpu", true, "r2"},
{"cpu", false, "r3"},
{"memory", true, "r3"},
{"memory", false, "r2"},
{"disk", true, "r2"},
{"disk", false, "r3"},
{"last_seen", true, "r3"},
{"last_seen", false, "r1"},
{"lastseen", true, "r3"},
{"mem", true, "r3"},
}
for _, tc := range cases {
sorted := append([]Resource(nil), resources...)
sortResources(sorted, tc.field, tc.desc)
if sorted[0].ID != tc.want {
t.Fatalf("sort %s desc=%v expected %s, got %s", tc.field, tc.desc, tc.want, sorted[0].ID)
}
}
}
func TestGetTopByMemoryAndDisk(t *testing.T) {
store := NewStore()
now := time.Now()
store.Upsert(Resource{
ID: "vm1",
Type: ResourceTypeVM,
Memory: &MetricValue{Current: 80},
Disk: &MetricValue{Current: 30},
LastSeen: now,
})
store.Upsert(Resource{
ID: "vm2",
Type: ResourceTypeVM,
Memory: &MetricValue{Current: 20},
Disk: &MetricValue{Current: 90},
LastSeen: now,
})
store.Upsert(Resource{
ID: "node1",
Type: ResourceTypeNode,
Memory: &MetricValue{Current: 60},
Disk: &MetricValue{Current: 10},
LastSeen: now,
})
store.Upsert(Resource{
ID: "skip-memory",
Type: ResourceTypeVM,
LastSeen: now,
})
store.Upsert(Resource{
ID: "skip-disk",
Type: ResourceTypeVM,
Disk: &MetricValue{Current: 0},
LastSeen: now,
})
topMem := store.GetTopByMemory(1, nil)
if len(topMem) != 1 || topMem[0].ID != "vm1" {
t.Fatalf("expected vm1 to be top memory")
}
topMemVMs := store.GetTopByMemory(10, []ResourceType{ResourceTypeVM})
if len(topMemVMs) != 2 {
t.Fatalf("expected 2 VM memory results, got %d", len(topMemVMs))
}
topDisk := store.GetTopByDisk(1, nil)
if len(topDisk) != 1 || topDisk[0].ID != "vm2" {
t.Fatalf("expected vm2 to be top disk")
}
topDiskNodes := store.GetTopByDisk(10, []ResourceType{ResourceTypeNode})
if len(topDiskNodes) != 1 || topDiskNodes[0].ID != "node1" {
t.Fatalf("expected node1 to be top disk node")
}
}
func TestGetTopByCPU_SkipsZero(t *testing.T) {
store := NewStore()
now := time.Now()
store.Upsert(Resource{
ID: "skip-cpu-1",
Type: ResourceTypeVM,
LastSeen: now,
})
store.Upsert(Resource{
ID: "skip-cpu-2",
Type: ResourceTypeVM,
CPU: &MetricValue{Current: 0},
LastSeen: now,
})
store.Upsert(Resource{
ID: "cpu-1",
Type: ResourceTypeVM,
CPU: &MetricValue{Current: 10},
LastSeen: now,
})
top := store.GetTopByCPU(10, nil)
if len(top) != 1 || top[0].ID != "cpu-1" {
t.Fatalf("expected only cpu-1 to be returned")
}
}
func TestGetRelatedWithChildren(t *testing.T) {
store := NewStore()
now := time.Now()
store.Upsert(Resource{
ID: "parent",
Type: ResourceTypeNode,
Name: "parent",
LastSeen: now,
})
store.Upsert(Resource{
ID: "child-1",
Type: ResourceTypeVM,
Name: "child-1",
ParentID: "parent",
LastSeen: now,
})
related := store.GetRelated("parent")
if children, ok := related["children"]; !ok || len(children) != 1 {
t.Fatalf("expected children for parent")
}
if len(store.GetRelated("missing")) != 0 {
t.Fatalf("expected no related resources for missing ID")
}
}
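// TestResourceSummaryWithDegradedAndAlerts checks the summary counters; note
// the status-unknown resource is the one counted in the Offline bucket here.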
func TestResourceSummaryWithDegradedAndAlerts(t *testing.T) {
store := NewStore()
now := time.Now()
store.Upsert(Resource{
ID: "healthy",
Type: ResourceTypeNode,
Status: StatusOnline,
LastSeen: now,
CPU: &MetricValue{Current: 20},
Memory: &MetricValue{Current: 50},
})
store.Upsert(Resource{
ID: "degraded",
Type: ResourceTypeVM,
Status: StatusDegraded,
LastSeen: now,
Alerts: []ResourceAlert{{ID: "a1"}},
})
store.Upsert(Resource{
ID: "unknown",
Type: ResourceTypeVM,
Status: StatusUnknown,
LastSeen: now,
})
summary := store.GetResourceSummary()
if summary.Degraded != 1 {
t.Fatalf("expected degraded count 1")
}
if summary.Offline != 1 {
t.Fatalf("expected offline count 1")
}
if summary.WithAlerts != 1 {
t.Fatalf("expected alerts count 1")
}
}
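// TestPopulateFromSnapshotFull verifies that populating from a snapshot
// replaces previously stored resources and creates one resource per
// supported type in the snapshot.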
func TestPopulateFromSnapshotFull(t *testing.T) {
store := NewStore()
store.Upsert(Resource{ID: "old-resource", Type: ResourceTypeNode, LastSeen: time.Now()})
now := time.Now()
cluster := models.KubernetesCluster{
ID: "cluster-1",
AgentID: "agent-1",
Status: "online",
LastSeen: now,
Nodes: []models.KubernetesNode{
{Name: "node-1", Ready: true},
},
Pods: []models.KubernetesPod{
{Name: "pod-1", Namespace: "default", Phase: "Running", NodeName: "node-1"},
},
Deployments: []models.KubernetesDeployment{
{Name: "dep-1", Namespace: "default", DesiredReplicas: 1, AvailableReplicas: 1},
},
}
dockerHost := models.DockerHost{
ID: "docker-1",
AgentID: "agent-docker",
Hostname: "docker-host",
Status: "online",
Memory: models.Memory{
Total: 100,
Used: 50,
Free: 50,
Usage: 50,
},
Disks: []models.Disk{
{Total: 100, Used: 60, Free: 40, Usage: 60},
},
NetworkInterfaces: []models.HostNetworkInterface{
{Addresses: []string{"10.0.0.1"}, RXBytes: 1, TXBytes: 2},
},
Containers: []models.DockerContainer{
{
ID: "container-1",
Name: "web",
State: "running",
Status: "Up",
CPUPercent: 5,
MemoryUsage: 50,
MemoryLimit: 100,
MemoryPercent: 50,
},
},
LastSeen: now,
}
snapshot := models.StateSnapshot{
DockerHosts: []models.DockerHost{dockerHost},
KubernetesClusters: []models.KubernetesCluster{
cluster,
},
PBSInstances: []models.PBSInstance{
{
ID: "pbs-1",
Name: "pbs",
Host: "pbs.local",
Status: "online",
ConnectionHealth: "healthy",
CPU: 20,
Memory: 50,
MemoryTotal: 100,
MemoryUsed: 50,
Uptime: 10,
LastSeen: now,
},
},
Storage: []models.Storage{
{
ID: "storage-1",
Name: "local",
Instance: "pve1",
Node: "node1",
Total: 100,
Used: 50,
Free: 50,
Usage: 50,
Active: true,
Enabled: true,
},
},
}
store.PopulateFromSnapshot(snapshot)
if _, ok := store.Get("old-resource"); ok {
t.Fatalf("expected old resource to be removed")
}
if len(store.Query().OfType(ResourceTypeDockerHost).Execute()) != 1 {
t.Fatalf("expected 1 docker host")
}
if len(store.Query().OfType(ResourceTypeDockerContainer).Execute()) != 1 {
t.Fatalf("expected 1 docker container")
}
if len(store.Query().OfType(ResourceTypeK8sCluster).Execute()) != 1 {
t.Fatalf("expected 1 k8s cluster")
}
if len(store.Query().OfType(ResourceTypeK8sNode).Execute()) != 1 {
t.Fatalf("expected 1 k8s node")
}
if len(store.Query().OfType(ResourceTypePod).Execute()) != 1 {
t.Fatalf("expected 1 k8s pod")
}
if len(store.Query().OfType(ResourceTypeK8sDeployment).Execute()) != 1 {
t.Fatalf("expected 1 k8s deployment")
}
if len(store.Query().OfType(ResourceTypePBS).Execute()) != 1 {
t.Fatalf("expected 1 PBS instance")
}
if len(store.Query().OfType(ResourceTypeStorage).Execute()) != 1 {
t.Fatalf("expected 1 storage resource")
}
}

View File

@@ -13,17 +13,25 @@ import (
"github.com/rs/zerolog/log"
)
var (
execLookPath = exec.LookPath
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
return exec.CommandContext(ctx, name, args...).Output()
}
timeNow = time.Now
)
// DiskSMART represents S.M.A.R.T. data for a single disk.
type DiskSMART struct {
Device string `json:"device"` // Device path (e.g., /dev/sda)
Model string `json:"model,omitempty"` // Disk model
Serial string `json:"serial,omitempty"` // Serial number
WWN string `json:"wwn,omitempty"` // World Wide Name
Type string `json:"type,omitempty"` // Transport type: sata, sas, nvme
Temperature int `json:"temperature"` // Temperature in Celsius
Health string `json:"health,omitempty"` // PASSED, FAILED, UNKNOWN
Standby bool `json:"standby,omitempty"` // True if disk was in standby
LastUpdated time.Time `json:"lastUpdated"` // When this reading was taken
Device string `json:"device"` // Device path (e.g., /dev/sda)
Model string `json:"model,omitempty"` // Disk model
Serial string `json:"serial,omitempty"` // Serial number
WWN string `json:"wwn,omitempty"` // World Wide Name
Type string `json:"type,omitempty"` // Transport type: sata, sas, nvme
Temperature int `json:"temperature"` // Temperature in Celsius
Health string `json:"health,omitempty"` // PASSED, FAILED, UNKNOWN
Standby bool `json:"standby,omitempty"` // True if disk was in standby
LastUpdated time.Time `json:"lastUpdated"` // When this reading was taken
}
// smartctlJSON represents the JSON output from smartctl --json.
@@ -85,8 +93,7 @@ func CollectLocal(ctx context.Context) ([]DiskSMART, error) {
// listBlockDevices returns a list of block devices suitable for SMART queries.
func listBlockDevices(ctx context.Context) ([]string, error) {
// Use lsblk to find disks (not partitions)
cmd := exec.CommandContext(ctx, "lsblk", "-d", "-n", "-o", "NAME,TYPE")
output, err := cmd.Output()
output, err := runCommandOutput(ctx, "lsblk", "-d", "-n", "-o", "NAME,TYPE")
if err != nil {
return nil, err
}
@@ -114,7 +121,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
defer cancel()
// Check if smartctl is available
smartctlPath, err := exec.LookPath("smartctl")
smartctlPath, err := execLookPath("smartctl")
if err != nil {
return nil, err
}
@@ -124,8 +131,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
// -i: device info
// -A: attributes (for temperature)
// --json=o: output original smartctl JSON format
cmd := exec.CommandContext(cmdCtx, smartctlPath, "-n", "standby", "-i", "-A", "-H", "--json=o", device)
output, err := cmd.Output()
output, err := runCommandOutput(cmdCtx, smartctlPath, "-n", "standby", "-i", "-A", "-H", "--json=o", device)
// smartctl returns non-zero exit codes for various conditions
// Exit code 2 means drive is in standby - that's okay
@@ -137,7 +143,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
return &DiskSMART{
Device: filepath.Base(device),
Standby: true,
LastUpdated: time.Now(),
LastUpdated: timeNow(),
}, nil
}
// Other exit codes might still have valid JSON output
@@ -161,7 +167,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
Model: smartData.ModelName,
Serial: smartData.SerialNumber,
Type: detectDiskType(smartData),
LastUpdated: time.Now(),
LastUpdated: timeNow(),
}
// Build WWN string if available

View File

@@ -0,0 +1,303 @@
package smartctl
import (
"context"
"encoding/json"
"errors"
"os/exec"
"strings"
"testing"
"time"
)
func TestListBlockDevices(t *testing.T) {
origRun := runCommandOutput
t.Cleanup(func() { runCommandOutput = origRun })
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
if name != "lsblk" {
return nil, errors.New("unexpected command")
}
out := "sda disk\nsda1 part\nsr0 rom\nnvme0n1 disk\n\n"
return []byte(out), nil
}
devices, err := listBlockDevices(context.Background())
if err != nil {
t.Fatalf("listBlockDevices error: %v", err)
}
if len(devices) != 2 || devices[0] != "/dev/sda" || devices[1] != "/dev/nvme0n1" {
t.Fatalf("unexpected devices: %#v", devices)
}
}
func TestListBlockDevicesError(t *testing.T) {
origRun := runCommandOutput
t.Cleanup(func() { runCommandOutput = origRun })
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
return nil, errors.New("boom")
}
if _, err := listBlockDevices(context.Background()); err == nil {
t.Fatalf("expected error")
}
}
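// TestDefaultRunCommandOutput exercises the real runCommandOutput hook; the
// save/restore is kept for symmetry with the other tests in this file.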
func TestDefaultRunCommandOutput(t *testing.T) {
origRun := runCommandOutput
t.Cleanup(func() { runCommandOutput = origRun })
runCommandOutput = origRun
if _, err := runCommandOutput(context.Background(), "sh", "-c", "printf ''"); err != nil {
t.Fatalf("expected default runCommandOutput to work, got %v", err)
}
}
func TestCollectLocalNoDevices(t *testing.T) {
origRun := runCommandOutput
t.Cleanup(func() { runCommandOutput = origRun })
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
if name == "lsblk" {
return []byte(""), nil
}
return nil, errors.New("unexpected command")
}
result, err := CollectLocal(context.Background())
if err != nil {
t.Fatalf("CollectLocal error: %v", err)
}
if result != nil {
t.Fatalf("expected nil result for no devices")
}
}
func TestCollectLocalListDevicesError(t *testing.T) {
origRun := runCommandOutput
t.Cleanup(func() { runCommandOutput = origRun })
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
return nil, errors.New("lsblk failed")
}
if _, err := CollectLocal(context.Background()); err == nil {
t.Fatalf("expected list error")
}
}
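// TestCollectLocalSkipsErrors confirms that a device whose smartctl read
// fails is skipped while the remaining devices are still collected.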
func TestCollectLocalSkipsErrors(t *testing.T) {
origRun := runCommandOutput
origLook := execLookPath
origNow := timeNow
t.Cleanup(func() {
runCommandOutput = origRun
execLookPath = origLook
timeNow = origNow
})
fixed := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
timeNow = func() time.Time { return fixed }
execLookPath = func(string) (string, error) { return "smartctl", nil }
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
if name == "lsblk" {
return []byte("sda disk\nsdb disk\n"), nil
}
if name == "smartctl" {
device := args[len(args)-1]
if strings.Contains(device, "sda") {
return nil, errors.New("read error")
}
payload := smartctlJSON{
ModelName: "Model",
SerialNumber: "Serial",
}
payload.Device.Protocol = "ATA"
payload.SmartStatus.Passed = true
payload.Temperature.Current = 30
out, _ := json.Marshal(payload)
return out, nil
}
return nil, errors.New("unexpected command")
}
result, err := CollectLocal(context.Background())
if err != nil {
t.Fatalf("CollectLocal error: %v", err)
}
if len(result) != 1 {
t.Fatalf("expected 1 result, got %d", len(result))
}
if result[0].Device != "sdb" || !result[0].LastUpdated.Equal(fixed) {
t.Fatalf("unexpected result: %#v", result[0])
}
}
func TestCollectDeviceSMARTLookPathError(t *testing.T) {
origLook := execLookPath
t.Cleanup(func() { execLookPath = origLook })
execLookPath = func(string) (string, error) { return "", errors.New("missing") }
if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
t.Fatalf("expected lookpath error")
}
}
func TestCollectDeviceSMARTStandby(t *testing.T) {
if _, err := exec.LookPath("sh"); err != nil {
t.Skip("sh not available")
}
origRun := runCommandOutput
origLook := execLookPath
origNow := timeNow
t.Cleanup(func() {
runCommandOutput = origRun
execLookPath = origLook
timeNow = origNow
})
fixed := time.Date(2024, 2, 3, 4, 5, 6, 0, time.UTC)
timeNow = func() time.Time { return fixed }
execLookPath = func(string) (string, error) { return "smartctl", nil }
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
return exec.CommandContext(ctx, "sh", "-c", "exit 2").Output()
}
result, err := collectDeviceSMART(context.Background(), "/dev/sda")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == nil || !result.Standby || result.Device != "sda" || !result.LastUpdated.Equal(fixed) {
t.Fatalf("unexpected standby result: %#v", result)
}
}
func TestCollectDeviceSMARTExitErrorNoOutput(t *testing.T) {
if _, err := exec.LookPath("sh"); err != nil {
t.Skip("sh not available")
}
origRun := runCommandOutput
origLook := execLookPath
t.Cleanup(func() {
runCommandOutput = origRun
execLookPath = origLook
})
execLookPath = func(string) (string, error) { return "smartctl", nil }
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
return exec.CommandContext(ctx, "sh", "-c", "exit 1").Output()
}
if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
t.Fatalf("expected error for exit code without output")
}
}
func TestCollectDeviceSMARTExitErrorWithOutput(t *testing.T) {
if _, err := exec.LookPath("sh"); err != nil {
t.Skip("sh not available")
}
origRun := runCommandOutput
origLook := execLookPath
t.Cleanup(func() {
runCommandOutput = origRun
execLookPath = origLook
})
execLookPath = func(string) (string, error) { return "smartctl", nil }
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
payload := `{"model_name":"Model","serial_number":"Serial","device":{"protocol":"ATA"},"smart_status":{"passed":false},"temperature":{"current":45}}`
return exec.CommandContext(ctx, "sh", "-c", "echo '"+payload+"'; exit 1").Output()
}
result, err := collectDeviceSMART(context.Background(), "/dev/sda")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == nil || result.Health != "FAILED" || result.Temperature != 45 {
t.Fatalf("unexpected result: %#v", result)
}
}
func TestCollectDeviceSMARTJSONError(t *testing.T) {
origRun := runCommandOutput
origLook := execLookPath
t.Cleanup(func() {
runCommandOutput = origRun
execLookPath = origLook
})
execLookPath = func(string) (string, error) { return "smartctl", nil }
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
return []byte("{"), nil
}
if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
t.Fatalf("expected json error")
}
}
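// TestCollectDeviceSMARTNVMeTempFallback checks that the NVMe health log
// temperature is used when the generic temperature field is absent.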
func TestCollectDeviceSMARTNVMeTempFallback(t *testing.T) {
origRun := runCommandOutput
origLook := execLookPath
origNow := timeNow
t.Cleanup(func() {
runCommandOutput = origRun
execLookPath = origLook
timeNow = origNow
})
fixed := time.Date(2024, 4, 5, 6, 7, 8, 0, time.UTC)
timeNow = func() time.Time { return fixed }
execLookPath = func(string) (string, error) { return "smartctl", nil }
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
payload := smartctlJSON{}
payload.Device.Protocol = "NVMe"
payload.NVMeSmartHealthInformationLog.Temperature = 55
payload.SmartStatus.Passed = true
out, _ := json.Marshal(payload)
return out, nil
}
result, err := collectDeviceSMART(context.Background(), "/dev/nvme0n1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Temperature != 55 || result.Type != "nvme" || result.Health != "PASSED" || !result.LastUpdated.Equal(fixed) {
t.Fatalf("unexpected result: %#v", result)
}
}
func TestCollectDeviceSMARTWWN(t *testing.T) {
origRun := runCommandOutput
origLook := execLookPath
t.Cleanup(func() {
runCommandOutput = origRun
execLookPath = origLook
})
execLookPath = func(string) (string, error) { return "smartctl", nil }
runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
payload := smartctlJSON{}
payload.WWN.NAA = 5
payload.WWN.OUI = 0xabc
payload.WWN.ID = 0x1234
payload.Device.Protocol = "SAS"
payload.SmartStatus.Passed = true
out, _ := json.Marshal(payload)
return out, nil
}
result, err := collectDeviceSMART(context.Background(), "/dev/sda")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.WWN != "5-abc-1234" {
t.Fatalf("unexpected WWN: %q", result.WWN)
}
}

View File

@@ -6,6 +6,7 @@ import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
@@ -44,12 +45,33 @@ const (
)
var (
mkdirAllFn = os.MkdirAll
statFn = os.Stat
openFileFn = os.OpenFile
openFn = os.Open
appendOpenFileFn = func(path string) (io.WriteCloser, error) {
return openFileFn(path, os.O_APPEND|os.O_WRONLY, 0o600)
}
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
cmd := exec.CommandContext(ctx, "ssh-keyscan", args...)
return cmd.CombinedOutput()
}
// ErrNoHostKeys is returned when ssh-keyscan yields no usable entries.
ErrNoHostKeys = errors.New("knownhosts: no host keys discovered")
// ErrHostKeyChanged signals that a host key already exists with a different fingerprint.
ErrHostKeyChanged = errors.New("knownhosts: host key changed")
)
var (
defaultMkdirAllFn = mkdirAllFn
defaultStatFn = statFn
defaultOpenFileFn = openFileFn
defaultOpenFn = openFn
defaultAppendOpenFileFn = appendOpenFileFn
defaultKeyscanCmdRunner = keyscanCmdRunner
)
// HostKeyChangeError describes a detected host key mismatch.
type HostKeyChangeError struct {
Host string
@@ -214,17 +236,17 @@ func (m *manager) Path() string {
func (m *manager) ensureKnownHostsFile() error {
dir := filepath.Dir(m.path)
if err := os.MkdirAll(dir, 0o700); err != nil {
if err := mkdirAllFn(dir, 0o700); err != nil {
return fmt.Errorf("knownhosts: mkdir %s: %w", dir, err)
}
if _, err := os.Stat(m.path); err == nil {
if _, err := statFn(m.path); err == nil {
return nil
} else if !os.IsNotExist(err) {
return err
}
f, err := os.OpenFile(m.path, os.O_CREATE|os.O_WRONLY, 0o600)
f, err := openFileFn(m.path, os.O_CREATE|os.O_WRONLY, 0o600)
if err != nil {
return fmt.Errorf("knownhosts: create %s: %w", m.path, err)
}
@@ -232,7 +254,7 @@ func (m *manager) ensureKnownHostsFile() error {
}
func appendHostKey(path string, entries [][]byte) error {
f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0o600)
f, err := appendOpenFileFn(path)
if err != nil {
return fmt.Errorf("knownhosts: open %s: %w", path, err)
}
@@ -287,7 +309,7 @@ func normalizeHostEntry(host string, entry []byte) ([]byte, string, error) {
}
func findHostKeyLine(path, host, keyType string) (string, error) {
f, err := os.Open(path)
f, err := openFn(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
@@ -328,10 +350,6 @@ func hostLineMatches(host, line string) bool {
}
fields := strings.Fields(trimmed)
if len(fields) == 0 {
return false
}
return hostFieldMatches(host, fields[0])
}
@@ -396,8 +414,7 @@ func defaultKeyscan(ctx context.Context, host string, port int, timeout time.Dur
}
args = append(args, host)
cmd := exec.CommandContext(scanCtx, "ssh-keyscan", args...)
output, err := cmd.CombinedOutput()
output, err := keyscanCmdRunner(scanCtx, args...)
if err != nil {
return nil, fmt.Errorf("%w (output: %s)", err, strings.TrimSpace(string(output)))
}

View File

@@ -1,15 +1,27 @@
package knownhosts
import (
"bytes"
"context"
"errors"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
)
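// resetKnownHostsFns restores the package-level filesystem and keyscan hooks
// to their defaults after tests override them.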
func resetKnownHostsFns() {
mkdirAllFn = defaultMkdirAllFn
statFn = defaultStatFn
openFileFn = defaultOpenFileFn
openFn = defaultOpenFn
appendOpenFileFn = defaultAppendOpenFileFn
keyscanCmdRunner = defaultKeyscanCmdRunner
}
func TestEnsureCreatesFileAndCaches(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
@@ -110,6 +122,167 @@ func TestEnsureRespectsContextCancellation(t *testing.T) {
}
}
func TestNewManagerEmptyPath(t *testing.T) {
if _, err := NewManager(""); err == nil {
t.Fatal("expected error for empty path")
}
}
func TestEnsureWithPortMissingHost(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
mgr, err := NewManager(path)
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithPort(context.Background(), "", 22); err == nil {
t.Fatal("expected error for missing host")
}
}
func TestEnsureWithPortDefaultsPort(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
var gotPort int
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
gotPort = port
return []byte(host + " ssh-ed25519 AAAA"), nil
}
mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithPort(context.Background(), "example.com", 0); err != nil {
t.Fatalf("EnsureWithPort: %v", err)
}
if gotPort != 22 {
t.Fatalf("expected port 22, got %d", gotPort)
}
}
func TestEnsureWithPortCustomPort(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
return []byte("[example.com]:2222 ssh-ed25519 AAAA"), nil
}
mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithPort(context.Background(), "example.com", 2222); err != nil {
t.Fatalf("EnsureWithPort: %v", err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile: %v", err)
}
if got := strings.TrimSpace(string(data)); got != "[example.com]:2222 ssh-ed25519 AAAA" {
t.Fatalf("unexpected known_hosts contents: %s", got)
}
}
func TestEnsureWithPortKeyscanError(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
return nil, errors.New("scan failed")
}
mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithPort(context.Background(), "example.com", 22); err == nil {
t.Fatal("expected keyscan error")
}
}
func TestEnsureWithEntriesMissingHost(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
mgr, err := NewManager(path)
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithEntries(context.Background(), "", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
t.Fatal("expected missing host error")
}
}
func TestEnsureWithEntriesNoEntries(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
mgr, err := NewManager(path)
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, nil); err == nil {
t.Fatal("expected no entries error")
}
}
func TestEnsureWithEntriesNormalizeError(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
mgr, err := NewManager(path)
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("invalid")}); err == nil {
t.Fatal("expected normalize error")
}
}
func TestEnsureWithEntriesEnsureKnownHostsFileError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
mkdirAllFn = func(string, os.FileMode) error {
return errors.New("mkdir failed")
}
mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
t.Fatal("expected ensureKnownHostsFile error")
}
}
func TestEnsureWithEntriesFindHostKeyLineError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
openFn = func(string) (*os.File, error) {
return nil, errors.New("open failed")
}
mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
t.Fatal("expected open error")
}
}
func TestHostCandidates(t *testing.T) {
tests := []struct {
input string
@@ -293,3 +466,312 @@ func TestHostFieldMatches(t *testing.T) {
})
}
}
func TestEnsureKnownHostsFileMkdirError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
mkdirAllFn = func(string, os.FileMode) error {
return errors.New("mkdir failed")
}
m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
if err := m.ensureKnownHostsFile(); err == nil {
t.Fatal("expected mkdir error")
}
}
func TestEnsureKnownHostsFileStatError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
statFn = func(string) (os.FileInfo, error) {
return nil, errors.New("stat failed")
}
m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
if err := m.ensureKnownHostsFile(); err == nil {
t.Fatal("expected stat error")
}
}
func TestEnsureKnownHostsFileCreateError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
statFn = func(string) (os.FileInfo, error) {
return nil, os.ErrNotExist
}
openFileFn = func(string, int, os.FileMode) (*os.File, error) {
return nil, errors.New("open failed")
}
m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
if err := m.ensureKnownHostsFile(); err == nil {
t.Fatal("expected create error")
}
}
func TestAppendHostKeyOpenError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
appendOpenFileFn = func(string) (io.WriteCloser, error) {
return nil, errors.New("open failed")
}
if err := appendHostKey("ignored", [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
t.Fatal("expected open error")
}
}
func TestAppendHostKeyWriteError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
appendOpenFileFn = func(string) (io.WriteCloser, error) {
return errWriteCloser{err: errors.New("write failed")}, nil
}
if err := appendHostKey("ignored", [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
t.Fatal("expected write error")
}
}
func TestNormalizeHostEntryWithComment(t *testing.T) {
entry := []byte("example.com ssh-ed25519 AAAA comment here")
normalized, keyType, err := normalizeHostEntry("example.com", entry)
if err != nil {
t.Fatalf("normalizeHostEntry error: %v", err)
}
if keyType != "ssh-ed25519" {
t.Fatalf("expected key type ssh-ed25519, got %s", keyType)
}
if string(normalized) != "example.com ssh-ed25519 AAAA comment here" {
t.Fatalf("unexpected normalized entry: %s", string(normalized))
}
}
func TestFindHostKeyLineNotExists(t *testing.T) {
line, err := findHostKeyLine(filepath.Join(t.TempDir(), "missing"), "example.com", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if line != "" {
t.Fatalf("expected empty line, got %q", line)
}
}
func TestFindHostKeyLineOpenError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
openFn = func(string) (*os.File, error) {
return nil, errors.New("open failed")
}
if _, err := findHostKeyLine("ignored", "example.com", ""); err == nil {
t.Fatal("expected open error")
}
}
func TestFindHostKeyLineScannerError(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
longLine := strings.Repeat("a", 70000)
if err := os.WriteFile(path, []byte(longLine+"\n"), 0600); err != nil {
t.Fatalf("failed to write file: %v", err)
}
if _, err := findHostKeyLine(path, "example.com", ""); err == nil {
t.Fatal("expected scanner error")
}
}
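A note on the 70000-byte line: bufio.Scanner caps a single token at bufio.MaxScanTokenSize (64 KiB) unless given a larger buffer, so a longer line surfaces bufio.ErrTooLong, which is what this test relies on, assuming findHostKeyLine scans lines with a default Scanner. A standalone sketch of the failure mode:
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// 70000 bytes exceeds bufio.MaxScanTokenSize (65536), so Scan stops
	// early and Err reports bufio.ErrTooLong ("token too long").
	s := bufio.NewScanner(strings.NewReader(strings.Repeat("a", 70000)))
	for s.Scan() {
	}
	fmt.Println(s.Err())
}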
func TestHostLineMatchesSkips(t *testing.T) {
if hostLineMatches("example.com", "") {
t.Fatal("expected empty line to be false")
}
if hostLineMatches("example.com", "# comment") {
t.Fatal("expected comment line to be false")
}
if hostLineMatches("example.com", "|1|hash|salt ssh-ed25519 AAAA") {
t.Fatal("expected hashed entry to be false")
}
}
func TestDefaultKeyscanSuccess(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
return []byte("example.com ssh-ed25519 AAAA"), nil
}
out, err := defaultKeyscan(context.Background(), "example.com", 22, time.Second)
if err != nil {
t.Fatalf("defaultKeyscan error: %v", err)
}
if string(out) != "example.com ssh-ed25519 AAAA" {
t.Fatalf("unexpected output: %s", string(out))
}
}
func TestDefaultKeyscanError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
return []byte("boom"), errors.New("scan failed")
}
if _, err := defaultKeyscan(context.Background(), "example.com", 22, time.Second); err == nil {
t.Fatal("expected error")
}
}
func TestEnsureWithEntriesDefaultsPort(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
mgr, err := NewManager(path)
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 0, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err != nil {
t.Fatalf("EnsureWithEntries: %v", err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile: %v", err)
}
if got := strings.TrimSpace(string(data)); got != "example.com ssh-ed25519 AAAA" {
t.Fatalf("unexpected known_hosts contents: %s", got)
}
}
func TestEnsureWithEntriesAppendError(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
appendOpenFileFn = func(string) (io.WriteCloser, error) {
return nil, errors.New("open failed")
}
mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
if err != nil {
t.Fatalf("NewManager: %v", err)
}
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
t.Fatal("expected append error")
}
}
func TestAppendHostKeySkipsEmptyEntry(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
buf := &bufferWriteCloser{}
appendOpenFileFn = func(string) (io.WriteCloser, error) {
return buf, nil
}
if err := appendHostKey("ignored", [][]byte{nil, []byte("example.com ssh-ed25519 AAAA")}); err != nil {
t.Fatalf("appendHostKey error: %v", err)
}
if !strings.Contains(buf.String(), "example.com ssh-ed25519 AAAA") {
t.Fatalf("expected entry to be written, got %q", buf.String())
}
}
func TestFindHostKeyLineSkipsInvalidLines(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "known_hosts")
contents := strings.Join([]string{
"other.com ssh-ed25519 AAAA",
"example.com ssh-ed25519",
"example.com ssh-ed25519 AAAA",
}, "\n") + "\n"
if err := os.WriteFile(path, []byte(contents), 0600); err != nil {
t.Fatalf("failed to write file: %v", err)
}
line, err := findHostKeyLine(path, "example.com", "ssh-ed25519")
if err != nil {
t.Fatalf("findHostKeyLine error: %v", err)
}
if line != "example.com ssh-ed25519 AAAA" {
t.Fatalf("unexpected line: %q", line)
}
}
func TestDefaultKeyscanArgs(t *testing.T) {
t.Cleanup(resetKnownHostsFns)
var gotArgs []string
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
gotArgs = append([]string{}, args...)
return []byte("ok"), nil
}
if _, err := defaultKeyscan(context.Background(), "example.com", 0, 0); err != nil {
t.Fatalf("defaultKeyscan error: %v", err)
}
for _, arg := range gotArgs {
if arg == "-p" {
t.Fatal("did not expect -p for default port")
}
}
if len(gotArgs) < 3 || gotArgs[len(gotArgs)-1] != "example.com" {
t.Fatalf("unexpected args: %v", gotArgs)
}
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
gotArgs = append([]string{}, args...)
return []byte("ok"), nil
}
if _, err := defaultKeyscan(context.Background(), "example.com", 2222, time.Second); err != nil {
t.Fatalf("defaultKeyscan error: %v", err)
}
hasPort := false
for i := 0; i < len(gotArgs)-1; i++ {
if gotArgs[i] == "-p" && gotArgs[i+1] == "2222" {
hasPort = true
break
}
}
if !hasPort {
t.Fatalf("expected -p 2222 in args, got %v", gotArgs)
}
}
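The assertions above pin down the flag layout without showing defaultKeyscan itself. A sketch of argument assembly consistent with them; the helper name, the -T flag, and the zero-timeout fallback are assumptions rather than code from this package:
package main

import (
	"fmt"
	"strconv"
	"time"
)

// buildKeyscanArgs is hypothetical; it mirrors what the test observes:
// no -p for port 0, an explicit -p otherwise, and the host as the final arg.
func buildKeyscanArgs(host string, port int, timeout time.Duration) []string {
	if timeout <= 0 {
		timeout = 5 * time.Second // assumed fallback for a zero timeout
	}
	args := []string{"-T", strconv.Itoa(int(timeout.Seconds()))}
	if port > 0 {
		args = append(args, "-p", strconv.Itoa(port))
	}
	return append(args, host)
}

func main() {
	fmt.Println(buildKeyscanArgs("example.com", 0, 0))              // [-T 5 example.com]
	fmt.Println(buildKeyscanArgs("example.com", 2222, time.Second)) // [-T 1 -p 2222 example.com]
}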
func TestKeyscanCmdRunnerDefault(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("ssh-keyscan helper script requires sh")
}
t.Cleanup(resetKnownHostsFns)
dir := t.TempDir()
scriptPath := filepath.Join(dir, "ssh-keyscan")
script := []byte("#!/bin/sh\necho example.com ssh-ed25519 AAAA\n")
if err := os.WriteFile(scriptPath, script, 0700); err != nil {
t.Fatalf("failed to write script: %v", err)
}
oldPath := os.Getenv("PATH")
if err := os.Setenv("PATH", dir+string(os.PathListSeparator)+oldPath); err != nil {
t.Fatalf("failed to set PATH: %v", err)
}
t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) })
output, err := keyscanCmdRunner(context.Background(), "example.com")
if err != nil {
t.Fatalf("keyscanCmdRunner error: %v", err)
}
if strings.TrimSpace(string(output)) != "example.com ssh-ed25519 AAAA" {
t.Fatalf("unexpected output: %s", string(output))
}
}
type errWriteCloser struct {
err error
}
func (e errWriteCloser) Write(p []byte) (int, error) {
return 0, e.err
}
func (e errWriteCloser) Close() error {
return nil
}
type bufferWriteCloser struct {
bytes.Buffer
}
func (b *bufferWriteCloser) Close() error {
return nil
}
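For orientation: resetKnownHostsFns itself sits outside this hunk. Given the seams the tests override, it presumably restores package defaults along these lines; the appendOpenFileFn and keyscanCmdRunner defaults below are guesses, not code from the package:
// Hypothetical reconstruction of the reset helper; the real one may differ.
func resetKnownHostsFnsSketch() {
	mkdirAllFn = os.MkdirAll
	statFn = os.Stat
	openFn = os.Open
	openFileFn = os.OpenFile
	appendOpenFileFn = func(path string) (io.WriteCloser, error) { // assumed default
		return os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
	}
	keyscanCmdRunner = defaultKeyscanCmd // assumed name for the exec-backed runner
}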

View File

@@ -5,6 +5,13 @@ import (
"strings"
)
var (
envGetFn = os.Getenv
statFn = os.Stat
readFileFn = os.ReadFile
hostnameFn = os.Hostname
)
var containerMarkers = []string{
"docker",
"lxc",
@@ -19,24 +26,24 @@ var containerMarkers = []string{
// InContainer reports whether Pulse is running inside a containerised environment.
func InContainer() bool {
// Allow operators to force container behaviour when automatic detection falls short.
if isTruthy(os.Getenv("PULSE_FORCE_CONTAINER")) {
if isTruthy(envGetFn("PULSE_FORCE_CONTAINER")) {
return true
}
if _, err := os.Stat("/.dockerenv"); err == nil {
if _, err := statFn("/.dockerenv"); err == nil {
return true
}
if _, err := os.Stat("/run/.containerenv"); err == nil {
if _, err := statFn("/run/.containerenv"); err == nil {
return true
}
// Check common environment hints provided by systemd/nspawn, LXC, etc.
if val := strings.ToLower(strings.TrimSpace(os.Getenv("container"))); val != "" && val != "host" {
if val := strings.ToLower(strings.TrimSpace(envGetFn("container"))); val != "" && val != "host" {
return true
}
// Some distros expose the container hint through PID 1's environment.
if data, err := os.ReadFile("/proc/1/environ"); err == nil {
if data, err := readFileFn("/proc/1/environ"); err == nil {
lower := strings.ToLower(string(data))
if strings.Contains(lower, "container=") && !strings.Contains(lower, "container=host") {
return true
@@ -44,7 +51,7 @@ func InContainer() bool {
}
// Fall back to cgroup inspection which covers older Docker/LXC setups.
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
if data, err := readFileFn("/proc/1/cgroup"); err == nil {
content := strings.ToLower(string(data))
for _, marker := range containerMarkers {
if strings.Contains(content, marker) {
@@ -60,7 +67,7 @@ func InContainer() bool {
// Returns empty string if not in Docker or name cannot be determined.
func DetectDockerContainerName() string {
// Method 1: Check hostname (Docker uses container ID or name as hostname)
if hostname, err := os.Hostname(); err == nil && hostname != "" {
if hostname, err := hostnameFn(); err == nil && hostname != "" {
// Docker hostnames are either short container ID (12 chars) or custom name
// If it looks like a container ID (hex), skip it - user needs to use name
if !isHexString(hostname) || len(hostname) > 12 {
@@ -69,7 +76,7 @@ func DetectDockerContainerName() string {
}
// Method 2: Try reading from /proc/self/cgroup
if data, err := os.ReadFile("/proc/self/cgroup"); err == nil {
if data, err := readFileFn("/proc/self/cgroup"); err == nil {
// Look for patterns like: 0::/docker/<container-id>
// But we can't get name from cgroup, only ID
_ = data // placeholder for future enhancement
@@ -91,7 +98,7 @@ func isHexString(s string) bool {
// Returns empty string if not in an LXC container or CTID cannot be determined.
func DetectLXCCTID() string {
// Method 1: Parse /proc/1/cgroup for LXC container ID
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
if data, err := readFileFn("/proc/1/cgroup"); err == nil {
lines := strings.Split(string(data), "\n")
for _, line := range lines {
// Look for patterns like: 0::/lxc/123 or 0::/lxc.payload.123
@@ -119,7 +126,7 @@ func DetectLXCCTID() string {
}
// Method 2: Check hostname (some LXC containers use CTID as hostname)
if hostname, err := os.Hostname(); err == nil && isNumeric(hostname) {
if hostname, err := hostnameFn(); err == nil && isNumeric(hostname) {
return hostname
}

View File

@@ -1,9 +1,18 @@
package system
import (
"errors"
"os"
"testing"
)
func resetSystemFns() {
envGetFn = os.Getenv
statFn = os.Stat
readFileFn = os.ReadFile
hostnameFn = os.Hostname
}
func TestIsHexString(t *testing.T) {
tests := []struct {
input string
@@ -133,20 +142,226 @@ func TestContainerMarkers(t *testing.T) {
}
func TestInContainer(t *testing.T) {
// This is an integration test that depends on the runtime environment
// We can't mock the file system easily, but we can verify it doesn't panic
result := InContainer()
t.Logf("InContainer() = %v (depends on test environment)", result)
t.Run("Forced", func(t *testing.T) {
t.Cleanup(resetSystemFns)
envGetFn = func(key string) string {
if key == "PULSE_FORCE_CONTAINER" {
return "true"
}
return ""
}
if !InContainer() {
t.Fatal("expected forced container")
}
})
t.Run("DockerEnvFile", func(t *testing.T) {
t.Cleanup(resetSystemFns)
envGetFn = func(string) string { return "" }
statFn = func(path string) (os.FileInfo, error) {
if path == "/.dockerenv" {
return nil, nil
}
return nil, errors.New("missing")
}
readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") }
if !InContainer() {
t.Fatal("expected docker env file")
}
})
t.Run("ContainerEnvFile", func(t *testing.T) {
t.Cleanup(resetSystemFns)
envGetFn = func(string) string { return "" }
statFn = func(path string) (os.FileInfo, error) {
if path == "/run/.containerenv" {
return nil, nil
}
return nil, errors.New("missing")
}
readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") }
if !InContainer() {
t.Fatal("expected container env file")
}
})
t.Run("ContainerEnvVar", func(t *testing.T) {
t.Cleanup(resetSystemFns)
envGetFn = func(key string) string {
if key == "container" {
return "lxc"
}
return ""
}
statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") }
readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") }
if !InContainer() {
t.Fatal("expected container env var")
}
})
t.Run("ProcEnviron", func(t *testing.T) {
t.Cleanup(resetSystemFns)
envGetFn = func(string) string { return "" }
statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") }
readFileFn = func(path string) ([]byte, error) {
if path == "/proc/1/environ" {
return []byte("container=lxc\x00"), nil
}
return nil, errors.New("missing")
}
if !InContainer() {
t.Fatal("expected container environ")
}
})
t.Run("CgroupMarker", func(t *testing.T) {
t.Cleanup(resetSystemFns)
envGetFn = func(string) string { return "" }
statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") }
readFileFn = func(path string) ([]byte, error) {
if path == "/proc/1/environ" {
return []byte("container=host"), nil
}
if path == "/proc/1/cgroup" {
return []byte("0::/docker/abc"), nil
}
return nil, errors.New("missing")
}
if !InContainer() {
t.Fatal("expected container cgroup marker")
}
})
t.Run("NotContainer", func(t *testing.T) {
t.Cleanup(resetSystemFns)
envGetFn = func(string) string { return "" }
statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") }
readFileFn = func(path string) ([]byte, error) {
if path == "/proc/1/environ" {
return []byte("container=host"), nil
}
if path == "/proc/1/cgroup" {
return []byte("0::/user.slice"), nil
}
return nil, errors.New("missing")
}
if InContainer() {
t.Fatal("expected non-container")
}
})
}
func TestDetectDockerContainerName(t *testing.T) {
// This is an integration test that depends on the runtime environment
result := DetectDockerContainerName()
t.Logf("DetectDockerContainerName() = %q (depends on test environment)", result)
t.Run("HostnameName", func(t *testing.T) {
t.Cleanup(resetSystemFns)
hostnameFn = func() (string, error) { return "my-container", nil }
if got := DetectDockerContainerName(); got != "my-container" {
t.Fatalf("expected hostname name, got %q", got)
}
})
t.Run("HostnameHexShort", func(t *testing.T) {
t.Cleanup(resetSystemFns)
hostnameFn = func() (string, error) { return "abcdef123456", nil }
readFileFn = func(string) ([]byte, error) { return []byte("0::/docker/abcdef"), nil }
if got := DetectDockerContainerName(); got != "" {
t.Fatalf("expected empty name, got %q", got)
}
})
t.Run("HostnameHexLong", func(t *testing.T) {
t.Cleanup(resetSystemFns)
hostnameFn = func() (string, error) { return "abcdef1234567890abcdef1234567890", nil }
if got := DetectDockerContainerName(); got == "" {
t.Fatal("expected hostname for long hex")
}
})
t.Run("HostnameError", func(t *testing.T) {
t.Cleanup(resetSystemFns)
hostnameFn = func() (string, error) { return "", errors.New("fail") }
readFileFn = func(string) ([]byte, error) { return []byte("0::/docker/abcdef"), nil }
if got := DetectDockerContainerName(); got != "" {
t.Fatalf("expected empty name, got %q", got)
}
})
}
func TestDetectLXCCTID(t *testing.T) {
// This is an integration test that depends on the runtime environment
result := DetectLXCCTID()
t.Logf("DetectLXCCTID() = %q (depends on test environment)", result)
t.Run("FromCgroupLXC", func(t *testing.T) {
t.Cleanup(resetSystemFns)
readFileFn = func(path string) ([]byte, error) {
if path == "/proc/1/cgroup" {
return []byte("0::/lxc/123"), nil
}
return nil, errors.New("missing")
}
hostnameFn = func() (string, error) { return "999", nil }
if got := DetectLXCCTID(); got != "123" {
t.Fatalf("expected CTID 123, got %q", got)
}
})
t.Run("FromCgroupPayload", func(t *testing.T) {
t.Cleanup(resetSystemFns)
readFileFn = func(path string) ([]byte, error) {
if path == "/proc/1/cgroup" {
return []byte("0::/lxc.payload.456"), nil
}
return nil, errors.New("missing")
}
hostnameFn = func() (string, error) { return "999", nil }
if got := DetectLXCCTID(); got != "456" {
t.Fatalf("expected CTID 456, got %q", got)
}
})
t.Run("FromMachineLXC", func(t *testing.T) {
t.Cleanup(resetSystemFns)
readFileFn = func(path string) ([]byte, error) {
if path == "/proc/1/cgroup" {
return []byte("0::/machine.slice/machine-lxc-789"), nil
}
return nil, errors.New("missing")
}
hostnameFn = func() (string, error) { return "999", nil }
if got := DetectLXCCTID(); got != "789" {
t.Fatalf("expected CTID 789, got %q", got)
}
})
t.Run("FromHostname", func(t *testing.T) {
t.Cleanup(resetSystemFns)
readFileFn = func(string) ([]byte, error) { return []byte("0::/user.slice"), nil }
hostnameFn = func() (string, error) { return "321", nil }
if got := DetectLXCCTID(); got != "321" {
t.Fatalf("expected CTID 321, got %q", got)
}
})
t.Run("NoMatch", func(t *testing.T) {
t.Cleanup(resetSystemFns)
readFileFn = func(string) ([]byte, error) { return []byte("0::/user.slice"), nil }
hostnameFn = func() (string, error) { return "host", nil }
if got := DetectLXCCTID(); got != "" {
t.Fatalf("expected empty CTID, got %q", got)
}
})
}

View File

@@ -24,6 +24,13 @@ const (
maxRetries = 3
)
var statFn = os.Stat
var dialContextFn = func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) {
dialer := net.Dialer{Timeout: timeout}
return dialer.DialContext(ctx, network, address)
}
// ErrorType classifies proxy errors for better error handling
type ErrorType int
@@ -66,9 +73,9 @@ type Client struct {
func NewClient() *Client {
socketPath := os.Getenv("PULSE_SENSOR_PROXY_SOCKET")
if socketPath == "" {
if _, err := os.Stat(defaultSocketPath); err == nil {
if _, err := statFn(defaultSocketPath); err == nil {
socketPath = defaultSocketPath
} else if _, err := os.Stat(containerSocketPath); err == nil {
} else if _, err := statFn(containerSocketPath); err == nil {
socketPath = containerSocketPath
} else {
socketPath = defaultSocketPath
@@ -83,7 +90,7 @@ func NewClient() *Client {
// IsAvailable checks if the proxy is running and accessible
func (c *Client) IsAvailable() bool {
_, err := os.Stat(c.socketPath)
_, err := statFn(c.socketPath)
return err == nil
}
@@ -312,13 +319,8 @@ func (c *Client) callWithContext(ctx context.Context, method string, params map[
// callOnce sends a single RPC request without retries
func (c *Client) callOnce(ctx context.Context, method string, params map[string]interface{}) (*RPCResponse, error) {
// Create a dialer with context
dialer := net.Dialer{
Timeout: c.timeout,
}
// Connect to unix socket with context
conn, err := dialer.DialContext(ctx, "unix", c.socketPath)
conn, err := dialContextFn(ctx, "unix", c.socketPath, c.timeout)
if err != nil {
return nil, fmt.Errorf("failed to connect to proxy: %w", err)
}
@@ -366,10 +368,6 @@ func (c *Client) GetStatus() (map[string]interface{}, error) {
return nil, err
}
if !resp.Success {
return nil, fmt.Errorf("proxy error: %s", resp.Error)
}
return resp.Data, nil
}
@@ -380,13 +378,6 @@ func (c *Client) RegisterNodes() ([]map[string]interface{}, error) {
return nil, err
}
if !resp.Success {
if proxyErr := classifyError(nil, resp.Error); proxyErr != nil {
return nil, proxyErr
}
return nil, fmt.Errorf("proxy error: %s", resp.Error)
}
// Extract nodes array from data
nodesRaw, ok := resp.Data["nodes"]
if !ok {
@@ -422,18 +413,6 @@ func (c *Client) GetTemperature(nodeHost string) (string, error) {
return "", err
}
if !resp.Success {
if proxyErr := classifyError(nil, resp.Error); proxyErr != nil {
return "", proxyErr
}
return "", &ProxyError{
Type: ErrorTypeUnknown,
Message: resp.Error,
Retryable: false,
Wrapped: fmt.Errorf("%s", resp.Error),
}
}
// Extract temperature JSON string
tempRaw, ok := resp.Data["temperature"]
if !ok {
@@ -455,17 +434,9 @@ func (c *Client) RequestCleanup(host string) error {
params["host"] = host
}
resp, err := c.call("request_cleanup", params)
if err != nil {
if _, err := c.call("request_cleanup", params); err != nil {
return err
}
if !resp.Success {
if resp.Error != "" {
return fmt.Errorf("proxy error: %s", resp.Error)
}
return fmt.Errorf("proxy rejected cleanup request")
}
return nil
}

View File

@@ -0,0 +1,488 @@
package tempproxy
import (
"context"
"encoding/json"
"errors"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
type writeErrConn struct {
net.Conn
}
func (c writeErrConn) Write(p []byte) (int, error) {
return 0, errors.New("write failed")
}
func startUnixServer(t *testing.T, handler func(net.Conn)) (string, func()) {
t.Helper()
dir := t.TempDir()
socketPath := filepath.Join(dir, "proxy.sock")
ln, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
done := make(chan struct{})
go func() {
defer close(done)
for {
conn, err := ln.Accept()
if err != nil {
return
}
go handler(conn)
}
}()
cleanup := func() {
ln.Close()
<-done
}
return socketPath, cleanup
}
func startRPCServer(t *testing.T, handler func(RPCRequest) RPCResponse) (string, func()) {
t.Helper()
return startUnixServer(t, func(conn net.Conn) {
defer conn.Close()
decoder := json.NewDecoder(conn)
var req RPCRequest
if err := decoder.Decode(&req); err != nil {
return
}
resp := handler(req)
encoder := json.NewEncoder(conn)
_ = encoder.Encode(resp)
})
}
func TestNewClientSocketSelection(t *testing.T) {
origStat := statFn
t.Cleanup(func() { statFn = origStat })
t.Run("EnvOverride", func(t *testing.T) {
origEnv := os.Getenv("PULSE_SENSOR_PROXY_SOCKET")
t.Cleanup(func() {
if origEnv == "" {
os.Unsetenv("PULSE_SENSOR_PROXY_SOCKET")
} else {
os.Setenv("PULSE_SENSOR_PROXY_SOCKET", origEnv)
}
})
os.Setenv("PULSE_SENSOR_PROXY_SOCKET", "/tmp/custom.sock")
statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist }
client := NewClient()
if client.socketPath != "/tmp/custom.sock" {
t.Fatalf("expected env socket, got %q", client.socketPath)
}
})
t.Run("DefaultExists", func(t *testing.T) {
statFn = func(path string) (os.FileInfo, error) {
if path == defaultSocketPath {
return nil, nil
}
return nil, os.ErrNotExist
}
client := NewClient()
if client.socketPath != defaultSocketPath {
t.Fatalf("expected default socket, got %q", client.socketPath)
}
})
t.Run("ContainerExists", func(t *testing.T) {
statFn = func(path string) (os.FileInfo, error) {
if path == containerSocketPath {
return nil, nil
}
return nil, os.ErrNotExist
}
client := NewClient()
if client.socketPath != containerSocketPath {
t.Fatalf("expected container socket, got %q", client.socketPath)
}
})
t.Run("FallbackDefault", func(t *testing.T) {
statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist }
client := NewClient()
if client.socketPath != defaultSocketPath {
t.Fatalf("expected fallback socket, got %q", client.socketPath)
}
})
}
func TestClientIsAvailable(t *testing.T) {
origStat := statFn
t.Cleanup(func() { statFn = origStat })
client := &Client{socketPath: "/tmp/socket"}
statFn = func(string) (os.FileInfo, error) { return nil, nil }
if !client.IsAvailable() {
t.Fatalf("expected available")
}
statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist }
if client.IsAvailable() {
t.Fatalf("expected unavailable")
}
}
func TestCallOnceSuccessAndDeadline(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
if req.Method != "ping" {
t.Fatalf("unexpected method: %s", req.Method)
}
return RPCResponse{Success: true}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
resp, err := client.callOnce(context.Background(), "ping", nil)
if err != nil || resp == nil || !resp.Success {
t.Fatalf("expected success: resp=%v err=%v", resp, err)
}
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
resp, err = client.callOnce(ctx, "ping", nil)
if err != nil || resp == nil || !resp.Success {
t.Fatalf("expected success with deadline: resp=%v err=%v", resp, err)
}
}
func TestCallOnceErrors(t *testing.T) {
t.Run("ConnectError", func(t *testing.T) {
client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
if _, err := client.callOnce(context.Background(), "ping", nil); err == nil {
t.Fatalf("expected connect error")
}
})
t.Run("EncodeError", func(t *testing.T) {
socketPath, cleanup := startUnixServer(t, func(conn net.Conn) {
conn.Close()
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.callOnce(context.Background(), "ping", nil); err == nil {
t.Fatalf("expected encode error")
}
})
t.Run("EncodeErrorDial", func(t *testing.T) {
origDial := dialContextFn
t.Cleanup(func() { dialContextFn = origDial })
clientConn, serverConn := net.Pipe()
t.Cleanup(func() { _ = serverConn.Close() })
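// The stubbed dialer hands back a pipe whose writes always fail, forcing
// the JSON encode path to error without touching a real socket.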
dialContextFn = func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) {
return writeErrConn{Conn: clientConn}, nil
}
client := &Client{socketPath: "/tmp/unused.sock", timeout: 50 * time.Millisecond}
if _, err := client.callOnce(context.Background(), "ping", nil); err == nil {
t.Fatalf("expected encode error")
}
})
t.Run("DecodeError", func(t *testing.T) {
socketPath, cleanup := startUnixServer(t, func(conn net.Conn) {
defer conn.Close()
_, _ = conn.Write([]byte("{"))
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.callOnce(context.Background(), "ping", nil); err == nil {
t.Fatalf("expected decode error")
}
})
}
func TestCallWithContextBehavior(t *testing.T) {
t.Run("Success", func(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
resp, err := client.callWithContext(context.Background(), "ping", nil)
if err != nil || resp == nil || !resp.Success {
t.Fatalf("expected success: resp=%v err=%v", resp, err)
}
})
t.Run("NonRetryable", func(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: "unauthorized"}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil {
t.Fatalf("expected auth error")
}
})
t.Run("CancelledBeforeRetry", func(t *testing.T) {
client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := client.callWithContext(ctx, "ping", nil); err == nil {
t.Fatalf("expected cancelled error")
}
})
t.Run("CancelledDuringBackoff", func(t *testing.T) {
client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
if _, err := client.callWithContext(ctx, "ping", nil); err == nil {
t.Fatalf("expected backoff cancel error")
}
})
t.Run("MaxRetriesExhausted", func(t *testing.T) {
client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil {
t.Fatalf("expected retry error")
}
})
}
func TestGetStatus(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true, Data: map[string]interface{}{"ok": true}}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
data, err := client.GetStatus()
if err != nil || data["ok"] != true {
t.Fatalf("unexpected status: %v err=%v", data, err)
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: "boom"}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.GetStatus(); err == nil {
t.Fatalf("expected status error")
}
}
func TestRegisterNodes(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true, Data: map[string]interface{}{
"nodes": []interface{}{map[string]interface{}{"name": "node1"}},
}}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
nodes, err := client.RegisterNodes()
if err != nil || len(nodes) != 1 || nodes[0]["name"] != "node1" {
t.Fatalf("unexpected nodes: %v err=%v", nodes, err)
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: "unauthorized"}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.RegisterNodes(); err == nil {
t.Fatalf("expected proxy error")
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true, Data: map[string]interface{}{}}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.RegisterNodes(); err == nil {
t.Fatalf("expected missing nodes error")
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true, Data: map[string]interface{}{"nodes": "bad"}}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.RegisterNodes(); err == nil {
t.Fatalf("expected nodes type error")
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true, Data: map[string]interface{}{"nodes": []interface{}{"bad"}}}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.RegisterNodes(); err == nil {
t.Fatalf("expected node map error")
}
}
func TestGetTemperature(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true, Data: map[string]interface{}{"temperature": "42"}}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
temp, err := client.GetTemperature("node1")
if err != nil || temp != "42" {
t.Fatalf("unexpected temp: %q err=%v", temp, err)
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: "unauthorized"}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.GetTemperature("node1"); err == nil {
t.Fatalf("expected proxy error")
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: "boom"}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.GetTemperature("node1"); err == nil {
t.Fatalf("expected unknown error")
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true, Data: map[string]interface{}{}}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.GetTemperature("node1"); err == nil {
t.Fatalf("expected missing temperature error")
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true, Data: map[string]interface{}{"temperature": 12}}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.GetTemperature("node1"); err == nil {
t.Fatalf("expected type error")
}
}
func TestRequestCleanup(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
if req.Method != "request_cleanup" {
t.Fatalf("unexpected method: %s", req.Method)
}
if req.Params["host"] != "node1" {
t.Fatalf("unexpected params: %v", req.Params)
}
return RPCResponse{Success: true}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if err := client.RequestCleanup("node1"); err != nil {
t.Fatalf("unexpected cleanup error: %v", err)
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: "boom"}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if err := client.RequestCleanup(""); err == nil {
t.Fatalf("expected proxy error")
}
socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: ""}
})
t.Cleanup(cleanup)
client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if err := client.RequestCleanup(""); err == nil {
t.Fatalf("expected rejected error")
}
client = &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
if err := client.RequestCleanup(""); err == nil {
t.Fatalf("expected call error")
}
}
func TestCallWrapper(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: true}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
resp, err := client.call("ping", nil)
if err != nil || resp == nil || !resp.Success {
t.Fatalf("expected call success: resp=%v err=%v", resp, err)
}
}
func TestCallWithContextRetryableResponseError(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: "ssh timeout"}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil {
t.Fatalf("expected retry exhaustion")
}
}
func TestCallWithContextSuccessOnRetry(t *testing.T) {
count := 0
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
count++
if count == 1 {
return RPCResponse{Success: false, Error: "ssh timeout"}
}
return RPCResponse{Success: true}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
resp, err := client.callWithContext(context.Background(), "ping", nil)
if err != nil || resp == nil || !resp.Success {
t.Fatalf("expected retry success: resp=%v err=%v", resp, err)
}
}
func TestCallWithContextRespErrorString(t *testing.T) {
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
return RPCResponse{Success: false, Error: strings.Repeat("x", 1)}
})
t.Cleanup(cleanup)
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil {
t.Fatalf("expected error")
}
}

View File

@@ -200,6 +200,12 @@ func TestContains(t *testing.T) {
substrs: []string{"ssh error"},
expected: true,
},
{
name: "case insensitive upper substr",
s: "connection refused",
substrs: []string{"REFUSED"},
expected: true,
},
{
name: "empty string",
s: "",

View File

@@ -80,7 +80,11 @@ func (c *HTTPClient) GetTemperature(nodeHost string) (string, error) {
req.Header.Set("Accept", "application/json")
// Execute request with retries
var lastErr error
var lastErr error = &ProxyError{
Type: ErrorTypeUnknown,
Message: "all retry attempts failed",
Retryable: false,
}
for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 {
backoff := calculateBackoff(attempt)
@@ -167,15 +171,7 @@ func (c *HTTPClient) GetTemperature(nodeHost string) (string, error) {
}
// All retries exhausted
if lastErr != nil {
return "", lastErr
}
return "", &ProxyError{
Type: ErrorTypeUnknown,
Message: "all retry attempts failed",
Retryable: false,
}
return "", lastErr
}
// HealthCheck calls the proxy /health endpoint to verify connectivity.

View File

@@ -2,6 +2,8 @@ package tempproxy
import (
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
@@ -360,6 +362,55 @@ func TestHTTPClient_GetTemperature_RetryOnTransportError(t *testing.T) {
}
}
func TestHTTPClient_GetTemperature_RequestBuildError(t *testing.T) {
client := NewHTTPClient("http://[::1", "token")
_, err := client.GetTemperature("node1")
if err == nil {
t.Fatal("Expected error for request creation")
}
proxyErr, ok := err.(*ProxyError)
if !ok {
t.Fatalf("Expected *ProxyError, got %T", err)
}
if proxyErr.Type != ErrorTypeTransport {
t.Errorf("Type = %v, want ErrorTypeTransport", proxyErr.Type)
}
if !strings.Contains(proxyErr.Message, "failed to create HTTP request") {
t.Errorf("Message = %q, want request creation failure", proxyErr.Message)
}
}
func TestHTTPClient_GetTemperature_ReadBodyError(t *testing.T) {
attempts := 0
client := NewHTTPClient("https://example.com", "token")
client.httpClient.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) {
attempts++
if attempts == 1 {
return &http.Response{
StatusCode: http.StatusOK,
Body: errReadCloser{},
Header: make(http.Header),
}, nil
}
body := `{"node":"node1","temperature":"{}"}`
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}, nil
})
_, err := client.GetTemperature("node1")
if err != nil {
t.Fatalf("Expected success after retry, got %v", err)
}
if attempts < 2 {
t.Fatalf("Expected retry after read error, got %d attempts", attempts)
}
}
func TestHTTPClient_HealthCheck_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/health" {
@@ -398,6 +449,26 @@ func TestHTTPClient_HealthCheck_NotConfigured(t *testing.T) {
}
}
func TestHTTPClient_HealthCheck_RequestBuildError(t *testing.T) {
client := NewHTTPClient("http://[::1", "token")
err := client.HealthCheck()
if err == nil {
t.Fatal("Expected error for request creation")
}
proxyErr, ok := err.(*ProxyError)
if !ok {
t.Fatalf("Expected *ProxyError, got %T", err)
}
if proxyErr.Type != ErrorTypeTransport {
t.Errorf("Type = %v, want ErrorTypeTransport", proxyErr.Type)
}
if !strings.Contains(proxyErr.Message, "failed to create HTTP request") {
t.Errorf("Message = %q, want request creation failure", proxyErr.Message)
}
}
func TestHTTPClient_HealthCheck_ServerError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
@@ -493,3 +564,19 @@ func TestHTTPClient_Fields(t *testing.T) {
t.Errorf("timeout = %v, want 60s", client.timeout)
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
type errReadCloser struct{}
func (errReadCloser) Read(p []byte) (int, error) {
return 0, errors.New("read failed")
}
func (errReadCloser) Close() error {
return nil
}
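The retries exercised above sleep via calculateBackoff, whose body is not in this diff. A conventional shape for a three-attempt client, assuming exponential doubling from a fixed base with no jitter (both assumptions):
package main

import (
	"fmt"
	"time"
)

// calculateBackoffSketch is illustrative only; the package's real base
// delay, any cap, and any jitter are not shown here.
func calculateBackoffSketch(attempt int) time.Duration {
	const base = 100 * time.Millisecond
	return base << uint(attempt-1) // attempt 1 -> 100ms, 2 -> 200ms
}

func main() {
	for attempt := 1; attempt < 3; attempt++ {
		fmt.Println(attempt, calculateBackoffSketch(attempt))
	}
}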

View File

@@ -0,0 +1,326 @@
package updatedetection
import (
"context"
"testing"
"time"
"github.com/rs/zerolog"
)
func TestDefaultManagerConfig(t *testing.T) {
cfg := DefaultManagerConfig()
if !cfg.Enabled {
t.Fatalf("expected enabled by default")
}
if cfg.CheckInterval != 6*time.Hour {
t.Fatalf("expected check interval 6h, got %v", cfg.CheckInterval)
}
if cfg.AlertDelayHours != 24 {
t.Fatalf("expected alert delay 24, got %d", cfg.AlertDelayHours)
}
if !cfg.EnableDockerUpdates {
t.Fatalf("expected docker updates enabled by default")
}
}
func TestManagerConfigAccessors(t *testing.T) {
cfg := ManagerConfig{
Enabled: false,
CheckInterval: time.Minute,
AlertDelayHours: 12,
EnableDockerUpdates: false,
}
mgr := NewManager(cfg, zerolog.Nop())
if mgr.Enabled() {
t.Fatalf("expected disabled manager")
}
mgr.SetEnabled(true)
if !mgr.Enabled() {
t.Fatalf("expected enabled manager after SetEnabled")
}
if mgr.AlertDelayHours() != 12 {
t.Fatalf("expected alert delay 12, got %d", mgr.AlertDelayHours())
}
}
func TestManagerProcessDockerContainerUpdate(t *testing.T) {
now := time.Now()
status := &ContainerUpdateStatus{
UpdateAvailable: true,
CurrentDigest: "sha256:old",
LatestDigest: "sha256:new",
LastChecked: now,
Error: "boom",
}
t.Run("Disabled", func(t *testing.T) {
cfg := DefaultManagerConfig()
cfg.Enabled = false
mgr := NewManager(cfg, zerolog.Nop())
mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
if mgr.store.Count() != 0 {
t.Fatalf("expected no updates when disabled")
}
})
t.Run("DockerUpdatesDisabled", func(t *testing.T) {
cfg := DefaultManagerConfig()
cfg.EnableDockerUpdates = false
mgr := NewManager(cfg, zerolog.Nop())
mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
if mgr.store.Count() != 0 {
t.Fatalf("expected no updates when docker updates disabled")
}
})
t.Run("NilStatus", func(t *testing.T) {
mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", nil)
if mgr.store.Count() != 0 {
t.Fatalf("expected no updates for nil status")
}
})
t.Run("NoUpdateDeletes", func(t *testing.T) {
mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
mgr.store.UpsertUpdate(&UpdateInfo{
ID: "docker:host-1:container-1",
ResourceID: "container-1",
HostID: "host-1",
})
mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", &ContainerUpdateStatus{
UpdateAvailable: false,
LastChecked: now,
})
if mgr.store.Count() != 0 {
t.Fatalf("expected update to be deleted when no update available")
}
})
t.Run("UpdateAvailable", func(t *testing.T) {
mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
update := mgr.store.GetUpdatesForResource("container-1")
if update == nil {
t.Fatalf("expected update to be stored")
}
if update.ID != "docker:host-1:container-1" {
t.Fatalf("unexpected update ID %q", update.ID)
}
if update.Error != "boom" {
t.Fatalf("expected error to be stored")
}
if update.CurrentVersion != "nginx:latest" {
t.Fatalf("expected current version to be stored")
}
})
}
func TestManagerCheckImageUpdate(t *testing.T) {
ctx := context.Background()
t.Run("Disabled", func(t *testing.T) {
cfg := DefaultManagerConfig()
cfg.Enabled = false
mgr := NewManager(cfg, zerolog.Nop())
info, err := mgr.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info != nil {
t.Fatalf("expected nil info when disabled")
}
})
t.Run("Cached", func(t *testing.T) {
mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
cacheKey := "registry-1.docker.io/library/nginx:latest"
mgr.registry.cacheDigest(cacheKey, "sha256:new")
info, err := mgr.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info == nil || !info.UpdateAvailable {
t.Fatalf("expected update available from cache")
}
})
}
func TestManagerGetUpdatesAndAccessors(t *testing.T) {
mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
now := time.Now()
update1 := &UpdateInfo{
ID: "update-1",
ResourceID: "res-1",
ResourceType: "docker",
ResourceName: "nginx",
HostID: "host-1",
Type: UpdateTypeDockerImage,
Severity: SeveritySecurity,
LastChecked: now,
}
update2 := &UpdateInfo{
ID: "update-2",
ResourceID: "res-2",
ResourceType: "vm",
ResourceName: "vm-1",
HostID: "host-2",
Type: UpdateTypePackage,
LastChecked: now.Add(-time.Minute),
}
mgr.store.UpsertUpdate(update1)
mgr.store.UpsertUpdate(update2)
if mgr.GetTotalCount() != 2 {
t.Fatalf("expected total count 2, got %d", mgr.GetTotalCount())
}
all := mgr.GetUpdates(UpdateFilters{})
if len(all) != 2 {
t.Fatalf("expected 2 updates, got %d", len(all))
}
filtered := mgr.GetUpdates(UpdateFilters{HostID: "host-1"})
if len(filtered) != 1 || filtered[0].ID != "update-1" {
t.Fatalf("expected one update for host-1")
}
if len(mgr.GetUpdatesForHost("host-1")) != 1 {
t.Fatalf("expected 1 update for host-1")
}
if mgr.GetUpdatesForResource("res-2") == nil {
t.Fatalf("expected update for resource res-2")
}
summary := mgr.GetSummary()
if summary["host-1"].TotalUpdates != 1 {
t.Fatalf("expected summary for host-1")
}
mgr.AddRegistryConfig(RegistryConfig{Host: "example.com", Username: "user"})
if _, ok := mgr.registry.configs["example.com"]; !ok {
t.Fatalf("expected registry config to be stored")
}
mgr.DeleteUpdatesForHost("host-1")
if len(mgr.GetUpdatesForHost("host-1")) != 0 {
t.Fatalf("expected updates removed for host-1")
}
}
func TestManagerCleanupStale(t *testing.T) {
mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
now := time.Now()
mgr.store.UpsertUpdate(&UpdateInfo{
ID: "stale",
ResourceID: "res-1",
HostID: "host-1",
LastChecked: now.Add(-2 * time.Hour),
})
mgr.store.UpsertUpdate(&UpdateInfo{
ID: "fresh",
ResourceID: "res-2",
HostID: "host-1",
LastChecked: now,
})
removed := mgr.CleanupStale(time.Hour)
if removed != 1 {
t.Fatalf("expected 1 stale update removed, got %d", removed)
}
if mgr.GetTotalCount() != 1 {
t.Fatalf("expected one update remaining")
}
}
func TestManagerGetUpdatesReadyForAlert(t *testing.T) {
mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
mgr.alertDelayHours = 1
now := time.Now()
mgr.store.UpsertUpdate(&UpdateInfo{
ID: "ready",
ResourceID: "res-1",
HostID: "host-1",
FirstDetected: now.Add(-2 * time.Hour),
LastChecked: now,
})
mgr.store.UpsertUpdate(&UpdateInfo{
ID: "recent",
ResourceID: "res-2",
HostID: "host-1",
FirstDetected: now.Add(-30 * time.Minute),
LastChecked: now,
})
mgr.store.UpsertUpdate(&UpdateInfo{
ID: "error",
ResourceID: "res-3",
HostID: "host-1",
FirstDetected: now.Add(-2 * time.Hour),
LastChecked: now,
Error: "rate limited",
})
ready := mgr.GetUpdatesReadyForAlert()
if len(ready) != 1 || ready[0].ID != "ready" {
t.Fatalf("expected only ready update, got %d", len(ready))
}
}
func TestUpdateFilters(t *testing.T) {
update := &UpdateInfo{
HostID: "host-1",
ResourceType: "docker",
Type: UpdateTypeDockerImage,
Severity: SeveritySecurity,
Error: "",
}
if !(&UpdateFilters{}).IsEmpty() {
t.Fatalf("expected empty filters to be empty")
}
if (&UpdateFilters{HostID: "host-1"}).IsEmpty() {
t.Fatalf("expected non-empty filters")
}
if (&UpdateFilters{HostID: "other"}).Matches(update) {
t.Fatalf("expected host mismatch to fail")
}
if (&UpdateFilters{ResourceType: "vm"}).Matches(update) {
t.Fatalf("expected resource type mismatch to fail")
}
if (&UpdateFilters{UpdateType: UpdateTypePackage}).Matches(update) {
t.Fatalf("expected update type mismatch to fail")
}
if (&UpdateFilters{Severity: SeverityBugfix}).Matches(update) {
t.Fatalf("expected severity mismatch to fail")
}
hasError := true
if (&UpdateFilters{HasError: &hasError}).Matches(update) {
t.Fatalf("expected error filter to fail")
}
update.Error = "boom"
hasError = false
if (&UpdateFilters{HasError: &hasError}).Matches(update) {
t.Fatalf("expected error=false filter to fail")
}
update.Error = ""
hasError = false
filters := UpdateFilters{
HostID: "host-1",
ResourceType: "docker",
UpdateType: UpdateTypeDockerImage,
Severity: SeveritySecurity,
HasError: &hasError,
}
if !filters.Matches(update) {
t.Fatalf("expected filters to match update")
}
}
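Taken together, these checks fully determine the matching semantics. A runnable restatement with local stand-in types; the field names mirror the ones the test uses, everything else is scaffolding:
package main

import "fmt"

// Minimal stand-ins so the restatement compiles on its own.
type updateInfo struct {
	HostID, ResourceType, Type, Severity, Error string
}

type updateFilters struct {
	HostID, ResourceType, UpdateType, Severity string
	HasError                                   *bool
}

// matches returns false on the first filter that is set and disagrees;
// HasError compares against whether the update carries a non-empty Error.
func (f *updateFilters) matches(u *updateInfo) bool {
	if f.HostID != "" && f.HostID != u.HostID {
		return false
	}
	if f.ResourceType != "" && f.ResourceType != u.ResourceType {
		return false
	}
	if f.UpdateType != "" && f.UpdateType != u.Type {
		return false
	}
	if f.Severity != "" && f.Severity != u.Severity {
		return false
	}
	if f.HasError != nil && *f.HasError != (u.Error != "") {
		return false
	}
	return true
}

func main() {
	hasErr := false
	f := updateFilters{HostID: "host-1", HasError: &hasErr}
	fmt.Println(f.matches(&updateInfo{HostID: "host-1"}))                // true
	fmt.Println(f.matches(&updateInfo{HostID: "host-1", Error: "boom"})) // false
}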

View File

@@ -0,0 +1,465 @@
package updatedetection
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/rs/zerolog"
)
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
type errorReader struct {
err error
}
func (e *errorReader) Read(p []byte) (int, error) {
return 0, e.err
}
func newResponse(status int, headers http.Header, body io.Reader) *http.Response {
if headers == nil {
headers = http.Header{}
}
if body == nil {
body = bytes.NewReader(nil)
}
return &http.Response{
StatusCode: status,
Status: http.StatusText(status),
Header: headers,
Body: io.NopCloser(body),
}
}
func TestRegistryCheckerConfig(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
if r.httpClient == nil || r.cache == nil || r.configs == nil {
t.Fatalf("expected registry checker to initialize")
}
cfg := RegistryConfig{Host: "example.com", Username: "user", Password: "pass", Insecure: true}
r.AddRegistryConfig(cfg)
r.mu.RLock()
stored := r.configs["example.com"]
r.mu.RUnlock()
if stored.Username != "user" || !stored.Insecure {
t.Fatalf("expected config to be stored")
}
}
func TestRegistryCache(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.cacheDigest("key-digest", "sha256:abc")
r.cacheError("key-error", "boom")
if r.CacheSize() != 2 {
t.Fatalf("expected cache size 2, got %d", r.CacheSize())
}
if entry := r.getCached("key-digest"); entry == nil || entry.digest != "sha256:abc" {
t.Fatalf("expected cached digest")
}
if entry := r.getCached("missing"); entry != nil {
t.Fatalf("expected missing cache entry to be nil")
}
r.cache.entries["expired"] = cacheEntry{
digest: "old",
expiresAt: time.Now().Add(-time.Minute),
}
if entry := r.getCached("expired"); entry != nil {
t.Fatalf("expected expired cache entry to be nil")
}
r.CleanupCache()
if r.CacheSize() != 2 {
t.Fatalf("expected expired cache entry to be removed")
}
}
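The expiry behaviour asserted here, where stale reads count as misses and CleanupCache prunes them, is the standard TTL-map pattern. A runnable miniature with local types; the field names mirror the test's cacheEntry, the rest is scaffolding:
package main

import (
	"fmt"
	"time"
)

type cacheEntry struct {
	digest    string
	expiresAt time.Time
}

type ttlCache struct{ entries map[string]cacheEntry }

// get treats absent and expired entries identically, matching getCached.
func (c *ttlCache) get(key string) *cacheEntry {
	e, ok := c.entries[key]
	if !ok || time.Now().After(e.expiresAt) {
		return nil
	}
	return &e
}

// cleanup prunes expired entries so the size reflects live entries only.
func (c *ttlCache) cleanup() {
	for k, e := range c.entries {
		if time.Now().After(e.expiresAt) {
			delete(c.entries, k)
		}
	}
}

func main() {
	c := &ttlCache{entries: map[string]cacheEntry{
		"live":    {digest: "sha256:abc", expiresAt: time.Now().Add(time.Hour)},
		"expired": {digest: "old", expiresAt: time.Now().Add(-time.Minute)},
	}}
	fmt.Println(c.get("live") != nil, c.get("expired") == nil) // true true
	c.cleanup()
	fmt.Println(len(c.entries)) // 1
}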
func TestRegistryCheckImageUpdate(t *testing.T) {
ctx := context.Background()
t.Run("DigestPinned", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
info, err := r.CheckImageUpdate(ctx, "nginx@sha256:abc", "sha256:old")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Error == "" {
t.Fatalf("expected digest-pinned error")
}
})
t.Run("CachedError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
key := "registry-1.docker.io/library/nginx:latest"
r.cache.entries[key] = cacheEntry{
err: "cached error",
expiresAt: time.Now().Add(time.Hour),
}
info, err := r.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Error != "cached error" {
t.Fatalf("expected cached error")
}
})
t.Run("CachedDigest", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
key := "registry-1.docker.io/library/nginx:latest"
r.cache.entries[key] = cacheEntry{
digest: "sha256:new",
expiresAt: time.Now().Add(time.Hour),
}
info, err := r.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !info.UpdateAvailable || info.LatestDigest != "sha256:new" {
t.Fatalf("expected update available from cache")
}
})
t.Run("FetchErrorCaches", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return newResponse(http.StatusInternalServerError, nil, nil), nil
}),
}
info, err := r.CheckImageUpdate(ctx, "example.com/repo:tag", "sha256:old")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Error == "" {
t.Fatalf("expected error from fetch")
}
if cached := r.getCached("example.com/repo:tag"); cached == nil || cached.err == "" {
t.Fatalf("expected error to be cached")
}
})
t.Run("FetchSuccessCaches", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
headers := http.Header{"Docker-Content-Digest": []string{"sha256:new"}}
return newResponse(http.StatusOK, headers, nil), nil
}),
}
info, err := r.CheckImageUpdate(ctx, "example.com/repo:tag", "sha256:old")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !info.UpdateAvailable {
t.Fatalf("expected update available")
}
if cached := r.getCached("example.com/repo:tag"); cached == nil || cached.digest == "" {
t.Fatalf("expected digest to be cached")
}
})
}
func TestFetchDigest(t *testing.T) {
ctx := context.Background()
t.Run("AuthError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("auth fail")
}),
}
if _, err := r.fetchDigest(ctx, "registry-1.docker.io", "library/nginx", "latest"); err == nil {
t.Fatalf("expected auth error")
}
})
t.Run("InsecureScheme", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.AddRegistryConfig(RegistryConfig{Host: "insecure.local", Insecure: true})
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Scheme != "http" {
t.Fatalf("expected http scheme, got %q", req.URL.Scheme)
}
headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}}
return newResponse(http.StatusOK, headers, nil), nil
}),
}
if _, err := r.fetchDigest(ctx, "insecure.local", "repo", "latest"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
})
t.Run("RequestError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
if _, err := r.fetchDigest(ctx, "bad host", "repo", "latest"); err == nil {
t.Fatalf("expected request creation error")
}
})
t.Run("TokenHeader", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.AddRegistryConfig(RegistryConfig{Host: "ghcr.io", Password: "pat"})
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if got := req.Header.Get("Authorization"); got != "Bearer pat" {
t.Fatalf("expected bearer token, got %q", got)
}
headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}}
return newResponse(http.StatusOK, headers, nil), nil
}),
}
if _, err := r.fetchDigest(ctx, "ghcr.io", "owner/repo", "latest"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
})
t.Run("DoError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("network")
}),
}
if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
t.Fatalf("expected request error")
}
})
t.Run("StatusUnauthorized", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return newResponse(http.StatusUnauthorized, nil, nil), nil
}),
}
if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
t.Fatalf("expected unauthorized error")
}
})
t.Run("StatusNotFound", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return newResponse(http.StatusNotFound, nil, nil), nil
}),
}
if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
t.Fatalf("expected not found error")
}
})
t.Run("StatusRateLimited", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return newResponse(http.StatusTooManyRequests, nil, nil), nil
}),
}
if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
t.Fatalf("expected rate limit error")
}
})
t.Run("StatusOtherError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return newResponse(http.StatusInternalServerError, nil, nil), nil
}),
}
if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
t.Fatalf("expected registry error")
}
})
t.Run("DigestHeader", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}}
return newResponse(http.StatusOK, headers, nil), nil
}),
}
digest, err := r.fetchDigest(ctx, "example.com", "repo", "latest")
if err != nil || digest != "sha256:abc" {
t.Fatalf("expected digest, got %q err %v", digest, err)
}
})
t.Run("EtagHeader", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
headers := http.Header{"Etag": []string{`"sha256:etag"`}}
return newResponse(http.StatusOK, headers, nil), nil
}),
}
digest, err := r.fetchDigest(ctx, "example.com", "repo", "latest")
if err != nil || digest != "sha256:etag" {
t.Fatalf("expected etag digest, got %q err %v", digest, err)
}
})
t.Run("NoDigest", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return newResponse(http.StatusOK, nil, nil), nil
}),
}
if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
t.Fatalf("expected missing digest error")
}
})
}
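All of the stubbed transports above answer what is, in the registry v2 API, a manifest request. For reference, the request fetchDigest evidently issues looks roughly like this; the HEAD method and Accept value are assumptions from the spec, while the Docker-Content-Digest and ETag fallbacks come from the tests:
package main

import (
	"context"
	"fmt"
	"net/http"
	"strings"
)

// fetchDigestSketch is illustrative; the real fetchDigest also handles auth
// tokens, insecure registries, and the status codes exercised above.
func fetchDigestSketch(ctx context.Context, host, repo, tag string) (string, error) {
	url := fmt.Sprintf("https://%s/v2/%s/manifests/%s", host, repo, tag)
	req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("registry returned %d", resp.StatusCode)
	}
	// Prefer the explicit digest header; fall back to a digest-shaped ETag.
	if d := resp.Header.Get("Docker-Content-Digest"); d != "" {
		return d, nil
	}
	if etag := strings.Trim(resp.Header.Get("Etag"), `"`); strings.HasPrefix(etag, "sha256:") {
		return etag, nil
	}
	return "", fmt.Errorf("no digest in response")
}

func main() { _ = fetchDigestSketch }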
func TestGetAuthToken(t *testing.T) {
ctx := context.Background()
t.Run("DockerHubSuccess", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.AddRegistryConfig(RegistryConfig{Host: "registry-1.docker.io", Username: "user", Password: "pass"})
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if !strings.HasPrefix(req.Header.Get("Authorization"), "Basic ") {
t.Fatalf("expected basic auth header")
}
body := `{"token":"tok"}`
return newResponse(http.StatusOK, nil, strings.NewReader(body)), nil
}),
}
token, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx")
if err != nil || token != "tok" {
t.Fatalf("expected token, got %q err %v", token, err)
}
})
t.Run("DockerHubRequestError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
		// The control character in the repository name makes the token URL
		// invalid, so request construction fails before any network I/O.
		if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "bad\nrepo"); err == nil {
t.Fatalf("expected request error")
}
})
t.Run("DockerHubDoError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("network")
}),
}
if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil {
t.Fatalf("expected network error")
}
})
t.Run("DockerHubStatusError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
return newResponse(http.StatusInternalServerError, nil, nil), nil
}),
}
if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil {
t.Fatalf("expected status error")
}
})
t.Run("DockerHubReadError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
body := &errorReader{err: errors.New("read")}
return newResponse(http.StatusOK, nil, body), nil
}),
}
if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil {
t.Fatalf("expected read error")
}
})
t.Run("DockerHubJSONError", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.httpClient = &http.Client{
Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
body := "{"
return newResponse(http.StatusOK, nil, strings.NewReader(body)), nil
}),
}
if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil {
t.Fatalf("expected json error")
}
})
t.Run("GHCRToken", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.AddRegistryConfig(RegistryConfig{Host: "ghcr.io", Password: "pat"})
token, err := r.getAuthToken(ctx, "ghcr.io", "owner/repo")
if err != nil || token != "pat" {
t.Fatalf("expected ghcr token")
}
})
t.Run("GHCRAnonymous", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
token, err := r.getAuthToken(ctx, "ghcr.io", "owner/repo")
if err != nil || token != "" {
t.Fatalf("expected empty ghcr token")
}
})
t.Run("OtherRegistryBasicAuth", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
r.AddRegistryConfig(RegistryConfig{Host: "example.com", Username: "user", Password: "pass"})
token, err := r.getAuthToken(ctx, "example.com", "repo")
if err != nil || token != "" {
t.Fatalf("expected empty token for basic auth registry")
}
})
t.Run("OtherRegistryNoAuth", func(t *testing.T) {
r := NewRegistryChecker(zerolog.Nop())
token, err := r.getAuthToken(ctx, "example.com", "repo")
if err != nil || token != "" {
t.Fatalf("expected empty token for registry without auth")
}
})
}
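
// Taken together, these subtests pin down getAuthToken's per-registry
// behavior as the assertions see it: Docker Hub exchanges basic credentials
// for a bearer token at its token endpoint, ghcr.io passes a configured PAT
// through directly (or an empty token for anonymous pulls), and other
// registries get an empty token because their credentials ride on the
// manifest request itself. This reading is inferred from the tests above,
// not from the implementation.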

View File

@@ -0,0 +1,43 @@
package updatedetection

import "testing"

func TestStore_DeleteUpdateNotFound(t *testing.T) {
store := NewStore()
store.DeleteUpdate("missing")
if store.Count() != 0 {
t.Fatalf("expected empty store")
}
}

func TestStore_DeleteUpdatesForResourceMissing(t *testing.T) {
store := NewStore()
store.DeleteUpdatesForResource("missing")
if store.Count() != 0 {
t.Fatalf("expected empty store")
}
}

func TestStore_DeleteUpdatesForResourceNilUpdate(t *testing.T) {
store := NewStore()
store.byResource["res-1"] = "update-1"
store.updates["update-1"] = nil
store.DeleteUpdatesForResource("res-1")
if _, ok := store.byResource["res-1"]; ok {
t.Fatalf("expected byResource entry to be removed")
}
}

func TestStore_CountForHost(t *testing.T) {
store := NewStore()
store.UpsertUpdate(&UpdateInfo{ID: "update-1", ResourceID: "res-1", HostID: "host-1"})
store.UpsertUpdate(&UpdateInfo{ID: "update-2", ResourceID: "res-2", HostID: "host-1"})
if store.CountForHost("host-1") != 2 {
t.Fatalf("expected count 2 for host-1")
}
if store.CountForHost("missing") != 0 {
t.Fatalf("expected count 0 for missing host")
}
}
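
// The store tests above assume an API shaped roughly like this sketch (field
// and method names come from the tests; the concrete types and bodies are
// assumptions):
//
//	type Store struct {
//		updates    map[string]*UpdateInfo // update ID -> update
//		byResource map[string]string      // resource ID -> update ID
//	}
//
//	func NewStore() *Store {
//		return &Store{
//			updates:    map[string]*UpdateInfo{},
//			byResource: map[string]string{},
//		}
//	}
//
// Count reports the number of stored updates, CountForHost filters them by
// UpdateInfo.HostID, and the Delete* methods are expected to tolerate missing
// keys and nil entries, which is exactly what these tests exercise.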