Commit to rcourtman/Pulse (https://github.com/rcourtman/Pulse.git): Improve internal package test coverage
@@ -51,6 +51,25 @@ var requiredHostAgentBinaries = []HostAgentBinary{
 var downloadMu sync.Mutex
 
+var (
+    httpClient            = &http.Client{Timeout: 2 * time.Minute}
+    downloadURLForVersion = func(version string) string {
+        return fmt.Sprintf("https://github.com/rcourtman/Pulse/releases/download/%[1]s/pulse-%[1]s.tar.gz", version)
+    }
+    downloadAndInstallHostAgentBinariesFn = DownloadAndInstallHostAgentBinaries
+    findMissingHostAgentBinariesFn        = findMissingHostAgentBinaries
+    mkdirAllFn                            = os.MkdirAll
+    createTempFn                          = os.CreateTemp
+    removeFn                              = os.Remove
+    openFileFn                            = os.Open
+    openFileModeFn                        = os.OpenFile
+    renameFn                              = os.Rename
+    symlinkFn                             = os.Symlink
+    copyFn                                = io.Copy
+    chmodFileFn                           = func(f *os.File, mode os.FileMode) error { return f.Chmod(mode) }
+    closeFileFn                           = func(f *os.File) error { return f.Close() }
+)
+
 // HostAgentSearchPaths returns the directories to search for host agent binaries.
 func HostAgentSearchPaths() []string {
     primary := strings.TrimSpace(os.Getenv("PULSE_BIN_DIR"))
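The pattern introduced here is a set of package-level function variables ("seams") that default to the real implementations (os.MkdirAll, os.CreateTemp, io.Copy, ...) and can be reassigned in tests to inject failures. A minimal sketch of the idea, using hypothetical names rather than this package's:

package seams

import (
    "errors"
    "os"
    "testing"
)

// removeFn defaults to the real os.Remove; tests may reassign it.
var removeFn = os.Remove

func cleanup(path string) error {
    return removeFn(path)
}

// A test swaps the seam and restores it afterwards via t.Cleanup.
func TestCleanupError(t *testing.T) {
    orig := removeFn
    t.Cleanup(func() { removeFn = orig })

    removeFn = func(string) error { return errors.New("remove fail") }
    if err := cleanup("whatever"); err == nil {
        t.Fatal("expected the injected error")
    }
}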
@@ -77,7 +96,7 @@ func HostAgentSearchPaths() []string {
 // The returned map contains any binaries that remain missing after the attempt.
 func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary {
     binDirs := HostAgentSearchPaths()
-    missing := findMissingHostAgentBinaries(binDirs)
+    missing := findMissingHostAgentBinariesFn(binDirs)
     if len(missing) == 0 {
         return nil
     }
@@ -86,7 +105,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary {
     defer downloadMu.Unlock()
 
     // Re-check after acquiring the lock in case another goroutine restored them.
-    missing = findMissingHostAgentBinaries(binDirs)
+    missing = findMissingHostAgentBinariesFn(binDirs)
     if len(missing) == 0 {
         return nil
     }
@@ -101,7 +120,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary {
         Strs("missing_platforms", missingPlatforms).
         Msg("Host agent binaries missing - attempting to download bundle from GitHub release")
 
-    if err := DownloadAndInstallHostAgentBinaries(version, binDirs[0]); err != nil {
+    if err := downloadAndInstallHostAgentBinariesFn(version, binDirs[0]); err != nil {
         log.Error().
             Err(err).
             Str("target_dir", binDirs[0]).
@@ -110,7 +129,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary {
         return missing
     }
 
-    if remaining := findMissingHostAgentBinaries(binDirs); len(remaining) > 0 {
+    if remaining := findMissingHostAgentBinariesFn(binDirs); len(remaining) > 0 {
         stillMissing := make([]string, 0, len(remaining))
         for key := range remaining {
             stillMissing = append(stillMissing, key)
@@ -133,19 +152,18 @@ func DownloadAndInstallHostAgentBinaries(version string, targetDir string) error
         return fmt.Errorf("cannot download host agent bundle for non-release version %q", version)
     }
 
-    if err := os.MkdirAll(targetDir, 0o755); err != nil {
+    if err := mkdirAllFn(targetDir, 0o755); err != nil {
         return fmt.Errorf("failed to ensure bin directory %s: %w", targetDir, err)
     }
 
-    url := fmt.Sprintf("https://github.com/rcourtman/Pulse/releases/download/%[1]s/pulse-%[1]s.tar.gz", normalizedVersion)
-    tempFile, err := os.CreateTemp("", "pulse-host-agent-*.tar.gz")
+    url := downloadURLForVersion(normalizedVersion)
+    tempFile, err := createTempFn("", "pulse-host-agent-*.tar.gz")
     if err != nil {
         return fmt.Errorf("failed to create temporary archive file: %w", err)
     }
-    defer os.Remove(tempFile.Name())
+    defer removeFn(tempFile.Name())
 
-    client := &http.Client{Timeout: 2 * time.Minute}
-    resp, err := client.Get(url)
+    resp, err := httpClient.Get(url)
     if err != nil {
         return fmt.Errorf("failed to download host agent bundle from %s: %w", url, err)
     }
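The downloadURLForVersion hook keeps the original URL construction intact; note the indexed %[1]s verb, which reuses the first argument so the version tag fills both slots. A runnable illustration:

package main

import "fmt"

func main() {
    // Both %[1]s verbs refer to argument 1, so one version string appears twice.
    url := fmt.Sprintf("https://github.com/rcourtman/Pulse/releases/download/%[1]s/pulse-%[1]s.tar.gz", "v1.2.3")
    fmt.Println(url)
    // https://github.com/rcourtman/Pulse/releases/download/v1.2.3/pulse-v1.2.3.tar.gz
}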
@@ -160,7 +178,7 @@ func DownloadAndInstallHostAgentBinaries(version string, targetDir string) error
         return fmt.Errorf("failed to save host agent bundle: %w", err)
     }
 
-    if err := tempFile.Close(); err != nil {
+    if err := closeFileFn(tempFile); err != nil {
         return fmt.Errorf("failed to close temporary bundle file: %w", err)
     }
 
@@ -204,7 +222,7 @@ func normalizeVersionTag(version string) string {
 }
 
 func extractHostAgentBinaries(archivePath, targetDir string) error {
-    file, err := os.Open(archivePath)
+    file, err := openFileFn(archivePath)
     if err != nil {
         return fmt.Errorf("failed to open host agent bundle: %w", err)
     }
@@ -232,10 +250,6 @@ func extractHostAgentBinaries(archivePath, targetDir string) error {
             return fmt.Errorf("failed to read host agent bundle: %w", err)
         }
 
-        if header == nil {
-            continue
-        }
-
         if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeSymlink {
             continue
         }
@@ -265,10 +279,10 @@ func extractHostAgentBinaries(archivePath, targetDir string) error {
     }
 
     for _, link := range symlinks {
-        if err := os.Remove(link.path); err != nil && !os.IsNotExist(err) {
+        if err := removeFn(link.path); err != nil && !os.IsNotExist(err) {
             return fmt.Errorf("failed to replace existing symlink %s: %w", link.path, err)
         }
-        if err := os.Symlink(link.target, link.path); err != nil {
+        if err := symlinkFn(link.target, link.path); err != nil {
             // Fallback: copy the referenced file if symlinks are not permitted
             source := filepath.Join(targetDir, link.target)
             if err := copyHostAgentFile(source, link.path); err != nil {
@@ -281,31 +295,31 @@ func extractHostAgentBinaries(archivePath, targetDir string) error {
 }
 
 func writeHostAgentFile(destination string, reader io.Reader, mode os.FileMode) error {
-    if err := os.MkdirAll(filepath.Dir(destination), 0o755); err != nil {
+    if err := mkdirAllFn(filepath.Dir(destination), 0o755); err != nil {
         return fmt.Errorf("failed to create directory for %s: %w", destination, err)
     }
 
-    tmpFile, err := os.CreateTemp(filepath.Dir(destination), "pulse-host-agent-*")
+    tmpFile, err := createTempFn(filepath.Dir(destination), "pulse-host-agent-*")
     if err != nil {
         return fmt.Errorf("failed to create temporary file for %s: %w", destination, err)
     }
-    defer os.Remove(tmpFile.Name())
+    defer removeFn(tmpFile.Name())
 
-    if _, err := io.Copy(tmpFile, reader); err != nil {
-        tmpFile.Close()
+    if _, err := copyFn(tmpFile, reader); err != nil {
+        closeFileFn(tmpFile)
         return fmt.Errorf("failed to extract %s: %w", destination, err)
     }
 
-    if err := tmpFile.Chmod(normalizeExecutableMode(mode)); err != nil {
-        tmpFile.Close()
+    if err := chmodFileFn(tmpFile, normalizeExecutableMode(mode)); err != nil {
+        closeFileFn(tmpFile)
         return fmt.Errorf("failed to set permissions on %s: %w", destination, err)
     }
 
-    if err := tmpFile.Close(); err != nil {
+    if err := closeFileFn(tmpFile); err != nil {
         return fmt.Errorf("failed to finalize %s: %w", destination, err)
     }
 
-    if err := os.Rename(tmpFile.Name(), destination); err != nil {
+    if err := renameFn(tmpFile.Name(), destination); err != nil {
         return fmt.Errorf("failed to install %s: %w", destination, err)
     }
 
@@ -313,23 +327,23 @@ func writeHostAgentFile(destination string, reader io.Reader, mode os.FileMode)
 }
 
 func copyHostAgentFile(source, destination string) error {
-    src, err := os.Open(source)
+    src, err := openFileFn(source)
     if err != nil {
         return fmt.Errorf("failed to open %s for fallback copy: %w", source, err)
     }
     defer src.Close()
 
-    if err := os.MkdirAll(filepath.Dir(destination), 0o755); err != nil {
+    if err := mkdirAllFn(filepath.Dir(destination), 0o755); err != nil {
         return fmt.Errorf("failed to prepare directory for %s: %w", destination, err)
     }
 
-    dst, err := os.OpenFile(destination, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
+    dst, err := openFileModeFn(destination, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
     if err != nil {
         return fmt.Errorf("failed to create fallback copy %s: %w", destination, err)
     }
     defer dst.Close()
 
-    if _, err := io.Copy(dst, src); err != nil {
+    if _, err := copyFn(dst, src); err != nil {
         return fmt.Errorf("failed to copy %s to %s: %w", source, destination, err)
     }
 
internal/agentbinaries/host_agent_coverage_test.go (new file, 698 lines)
@@ -0,0 +1,698 @@
package agentbinaries

import (
    "archive/tar"
    "bytes"
    "compress/gzip"
    "errors"
    "io"
    "net/http"
    "net/http/httptest"
    "os"
    "path/filepath"
    "strings"
    "testing"
)

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
    return f(r)
}

type tarEntry struct {
    name     string
    body     []byte
    mode     int64
    typeflag byte
    linkname string
}

type errorReader struct{}

func (e *errorReader) Read([]byte) (int, error) {
    return 0, errors.New("read fail")
}

func buildTarGz(t *testing.T, entries []tarEntry) []byte {
    t.Helper()

    var buf bytes.Buffer
    gzw := gzip.NewWriter(&buf)
    tw := tar.NewWriter(gzw)

    for _, entry := range entries {
        typeflag := entry.typeflag
        if typeflag == 0 {
            typeflag = tar.TypeReg
        }
        size := int64(len(entry.body))
        if typeflag != tar.TypeReg {
            size = 0
        }
        hdr := &tar.Header{
            Name:     entry.name,
            Mode:     entry.mode,
            Size:     size,
            Typeflag: typeflag,
            Linkname: entry.linkname,
        }
        if err := tw.WriteHeader(hdr); err != nil {
            t.Fatalf("WriteHeader: %v", err)
        }
        if typeflag == tar.TypeReg {
            if _, err := tw.Write(entry.body); err != nil {
                t.Fatalf("Write: %v", err)
            }
        }
    }

    if err := tw.Close(); err != nil {
        t.Fatalf("tar close: %v", err)
    }
    if err := gzw.Close(); err != nil {
        t.Fatalf("gzip close: %v", err)
    }
    return buf.Bytes()
}
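roundTripperFunc is the standard adapter for turning a closure into an http.RoundTripper, which is what lets the tests below replace the package's httpClient without touching the network. A sketch of how it plugs in; the canned payload and helper name are illustrative, not part of the diff:

// stubClient returns an *http.Client whose transport serves a canned body.
// It can live in this test package because roundTripperFunc is defined above.
func stubClient(payload string) *http.Client {
    return &http.Client{
        Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
            return &http.Response{
                StatusCode: http.StatusOK,
                Status:     "200 OK",
                Body:       io.NopCloser(strings.NewReader(payload)),
            }, nil
        }),
    }
}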
func saveHostAgentHooks() func() {
    origRequired := requiredHostAgentBinaries
    origDownloadFn := downloadAndInstallHostAgentBinariesFn
    origFindMissing := findMissingHostAgentBinariesFn
    origURL := downloadURLForVersion
    origClient := httpClient
    origMkdirAll := mkdirAllFn
    origCreateTemp := createTempFn
    origRemove := removeFn
    origOpen := openFileFn
    origOpenMode := openFileModeFn
    origRename := renameFn
    origSymlink := symlinkFn
    origCopy := copyFn
    origChmod := chmodFileFn
    origClose := closeFileFn

    return func() {
        requiredHostAgentBinaries = origRequired
        downloadAndInstallHostAgentBinariesFn = origDownloadFn
        findMissingHostAgentBinariesFn = origFindMissing
        downloadURLForVersion = origURL
        httpClient = origClient
        mkdirAllFn = origMkdirAll
        createTempFn = origCreateTemp
        removeFn = origRemove
        openFileFn = origOpen
        openFileModeFn = origOpenMode
        renameFn = origRename
        symlinkFn = origSymlink
        copyFn = origCopy
        chmodFileFn = origChmod
        closeFileFn = origClose
    }
}
func TestEnsureHostAgentBinaries_NoMissing(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    tmpDir := t.TempDir()
    requiredHostAgentBinaries = []HostAgentBinary{
        {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
    }

    if err := os.WriteFile(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64"), []byte("bin"), 0o755); err != nil {
        t.Fatalf("write file: %v", err)
    }

    origEnv := os.Getenv("PULSE_BIN_DIR")
    t.Cleanup(func() {
        if origEnv == "" {
            os.Unsetenv("PULSE_BIN_DIR")
        } else {
            os.Setenv("PULSE_BIN_DIR", origEnv)
        }
    })
    os.Setenv("PULSE_BIN_DIR", tmpDir)

    if missing := EnsureHostAgentBinaries("v1.0.0"); missing != nil {
        t.Fatalf("expected no missing binaries, got %v", missing)
    }
}

func TestEnsureHostAgentBinaries_DownloadError(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    tmpDir := t.TempDir()
    requiredHostAgentBinaries = []HostAgentBinary{
        {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
    }
    downloadAndInstallHostAgentBinariesFn = func(string, string) error {
        return errors.New("download failed")
    }

    origEnv := os.Getenv("PULSE_BIN_DIR")
    t.Cleanup(func() {
        if origEnv == "" {
            os.Unsetenv("PULSE_BIN_DIR")
        } else {
            os.Setenv("PULSE_BIN_DIR", origEnv)
        }
    })
    os.Setenv("PULSE_BIN_DIR", tmpDir)

    missing := EnsureHostAgentBinaries("v1.0.0")
    if len(missing) != 1 {
        t.Fatalf("expected missing map, got %v", missing)
    }
}

func TestEnsureHostAgentBinaries_StillMissing(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    tmpDir := t.TempDir()
    requiredHostAgentBinaries = []HostAgentBinary{
        {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
    }
    downloadAndInstallHostAgentBinariesFn = func(string, string) error { return nil }

    origEnv := os.Getenv("PULSE_BIN_DIR")
    t.Cleanup(func() {
        if origEnv == "" {
            os.Unsetenv("PULSE_BIN_DIR")
        } else {
            os.Setenv("PULSE_BIN_DIR", origEnv)
        }
    })
    os.Setenv("PULSE_BIN_DIR", tmpDir)

    missing := EnsureHostAgentBinaries("v1.0.0")
    if len(missing) != 1 {
        t.Fatalf("expected still missing")
    }
}

func TestEnsureHostAgentBinaries_RecheckAfterLock(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    requiredHostAgentBinaries = []HostAgentBinary{
        {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
    }

    calls := 0
    findMissingHostAgentBinariesFn = func([]string) map[string]HostAgentBinary {
        calls++
        if calls == 1 {
            return map[string]HostAgentBinary{
                "linux-amd64": requiredHostAgentBinaries[0],
            }
        }
        return nil
    }

    origEnv := os.Getenv("PULSE_BIN_DIR")
    t.Cleanup(func() {
        if origEnv == "" {
            os.Unsetenv("PULSE_BIN_DIR")
        } else {
            os.Setenv("PULSE_BIN_DIR", origEnv)
        }
    })
    os.Setenv("PULSE_BIN_DIR", t.TempDir())

    if result := EnsureHostAgentBinaries("v1.0.0"); result != nil {
        t.Fatalf("expected nil after recheck, got %v", result)
    }
}

func TestEnsureHostAgentBinaries_RestoreSuccess(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    tmpDir := t.TempDir()
    requiredHostAgentBinaries = []HostAgentBinary{
        {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}},
    }
    downloadAndInstallHostAgentBinariesFn = func(string, string) error {
        return os.WriteFile(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64"), []byte("bin"), 0o755)
    }

    origEnv := os.Getenv("PULSE_BIN_DIR")
    t.Cleanup(func() {
        if origEnv == "" {
            os.Unsetenv("PULSE_BIN_DIR")
        } else {
            os.Setenv("PULSE_BIN_DIR", origEnv)
        }
    })
    os.Setenv("PULSE_BIN_DIR", tmpDir)

    if missing := EnsureHostAgentBinaries("v1.0.0"); missing != nil {
        t.Fatalf("expected restore success")
    }
}
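The PULSE_BIN_DIR save/restore block repeated in each of these tests can, on Go 1.17+, be collapsed into t.Setenv, which registers the restore automatically. A sketch of the equivalent, not part of the diff (the test name is hypothetical):

func TestEnsureHostAgentBinaries_WithSetenv(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    // t.Setenv restores the previous value on cleanup and also guards
    // against misuse inside parallel tests.
    t.Setenv("PULSE_BIN_DIR", t.TempDir())

    _ = EnsureHostAgentBinaries("v1.0.0")
}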
func TestDownloadAndInstallHostAgentBinariesErrors(t *testing.T) {
    t.Run("MkdirAllError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") }
        if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
            t.Fatalf("expected mkdir error")
        }
    })

    t.Run("CreateTempError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        createTempFn = func(string, string) (*os.File, error) { return nil, errors.New("temp fail") }
        if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
            t.Fatalf("expected temp error")
        }
    })

    t.Run("DownloadError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        httpClient = &http.Client{
            Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
                return nil, errors.New("network")
            }),
        }
        downloadURLForVersion = func(string) string { return "http://example/bundle.tar.gz" }
        if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
            t.Fatalf("expected download error")
        }
    })

    t.Run("StatusError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            w.WriteHeader(http.StatusBadRequest)
            _, _ = w.Write([]byte("bad"))
        }))
        t.Cleanup(server.Close)

        httpClient = server.Client()
        downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }

        if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
            t.Fatalf("expected status error")
        }
    })

    t.Run("CopyError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            _, _ = w.Write([]byte("data"))
        }))
        t.Cleanup(server.Close)

        httpClient = server.Client()
        downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
        copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") }

        if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
            t.Fatalf("expected copy error")
        }
    })

    t.Run("CopyReadError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        httpClient = &http.Client{
            Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
                body := io.NopCloser(&errorReader{})
                return &http.Response{
                    StatusCode: http.StatusOK,
                    Status:     "200 OK",
                    Body:       body,
                }, nil
            }),
        }
        downloadURLForVersion = func(string) string { return "http://example/bundle.tar.gz" }

        if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
            t.Fatalf("expected copy read error")
        }
    })

    t.Run("CloseError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            _, _ = w.Write([]byte("data"))
        }))
        t.Cleanup(server.Close)

        httpClient = server.Client()
        downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
        closeFileFn = func(*os.File) error { return errors.New("close fail") }

        if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
            t.Fatalf("expected close error")
        }
    })

    t.Run("ExtractError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            _, _ = w.Write([]byte("not a tarball"))
        }))
        t.Cleanup(server.Close)

        httpClient = server.Client()
        downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }

        if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
            t.Fatalf("expected extract error")
        }
    })
}

func TestDownloadAndInstallHostAgentBinariesSuccess(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    tmpDir := t.TempDir()
    payload := buildTarGz(t, []tarEntry{
        {name: "README.md", body: []byte("skip"), mode: 0o644},
        {name: "bin/", typeflag: tar.TypeDir, mode: 0o755},
        {name: "bin/not-agent.txt", body: []byte("skip"), mode: 0o644},
        {name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644},
        {name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"},
    })

    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        _, _ = w.Write(payload)
    }))
    t.Cleanup(server.Close)

    httpClient = server.Client()
    downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
    symlinkFn = func(string, string) error { return errors.New("no symlink") }

    if err := DownloadAndInstallHostAgentBinaries("v1.0.0", tmpDir); err != nil {
        t.Fatalf("expected success, got %v", err)
    }

    if _, err := os.Stat(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64")); err != nil {
        t.Fatalf("expected binary installed: %v", err)
    }
    if _, err := os.Stat(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64.exe")); err != nil {
        t.Fatalf("expected symlink fallback copy: %v", err)
    }
}
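One detail worth noticing in these subtests: httpClient is replaced with server.Client() rather than http.DefaultClient, so the two-minute production timeout never applies and the test server's transport settings are honored. The pairing in isolation (hypothetical test name, standard net/http/httptest API):

func TestStatusInjection(t *testing.T) {
    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusTeapot) // any canned status works for error paths
    }))
    t.Cleanup(server.Close)

    resp, err := server.Client().Get(server.URL + "/bundle.tar.gz")
    if err != nil {
        t.Fatalf("Get: %v", err)
    }
    defer resp.Body.Close()
    if resp.StatusCode != http.StatusTeapot {
        t.Fatalf("got %d", resp.StatusCode)
    }
}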
func TestExtractHostAgentBinariesErrors(t *testing.T) {
    t.Run("OpenError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        openFileFn = func(string) (*os.File, error) { return nil, errors.New("open fail") }
        if err := extractHostAgentBinaries("missing.tar.gz", t.TempDir()); err == nil {
            t.Fatalf("expected open error")
        }
    })

    t.Run("GzipError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        tmp := filepath.Join(t.TempDir(), "bad.gz")
        if err := os.WriteFile(tmp, []byte("bad"), 0o644); err != nil {
            t.Fatalf("write: %v", err)
        }
        if err := extractHostAgentBinaries(tmp, t.TempDir()); err == nil {
            t.Fatalf("expected gzip error")
        }
    })

    t.Run("TarReadError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        var buf bytes.Buffer
        gzw := gzip.NewWriter(&buf)
        _, _ = gzw.Write([]byte("not a tar"))
        _ = gzw.Close()

        tmp := filepath.Join(t.TempDir(), "bad.tar.gz")
        if err := os.WriteFile(tmp, buf.Bytes(), 0o644); err != nil {
            t.Fatalf("write: %v", err)
        }
        if err := extractHostAgentBinaries(tmp, t.TempDir()); err == nil {
            t.Fatalf("expected tar read error")
        }
    })
}

func TestExtractHostAgentBinariesRemoveError(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    tmpDir := t.TempDir()
    payload := buildTarGz(t, []tarEntry{
        {name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644},
        {name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"},
    })

    archive := filepath.Join(t.TempDir(), "bundle.tar.gz")
    if err := os.WriteFile(archive, payload, 0o644); err != nil {
        t.Fatalf("write: %v", err)
    }

    removeFn = func(string) error { return errors.New("remove fail") }
    if err := extractHostAgentBinaries(archive, tmpDir); err == nil {
        t.Fatalf("expected remove error")
    }
}

func TestExtractHostAgentBinariesSymlinkCopyError(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    tmpDir := t.TempDir()
    payload := buildTarGz(t, []tarEntry{
        {name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644},
        {name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"},
    })
    archive := filepath.Join(t.TempDir(), "bundle.tar.gz")
    if err := os.WriteFile(archive, payload, 0o644); err != nil {
        t.Fatalf("write: %v", err)
    }

    symlinkFn = func(string, string) error { return errors.New("no symlink") }
    openFileFn = func(path string) (*os.File, error) {
        if path == archive {
            return os.Open(path)
        }
        return nil, errors.New("open fail")
    }

    if err := extractHostAgentBinaries(archive, tmpDir); err == nil {
        t.Fatalf("expected symlink fallback error")
    }
}
func TestWriteHostAgentFileErrors(t *testing.T) {
    t.Run("MkdirAllError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") }
        if err := writeHostAgentFile("dest", strings.NewReader("data"), 0o644); err == nil {
            t.Fatalf("expected mkdir error")
        }
    })

    t.Run("CreateTempError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        createTempFn = func(string, string) (*os.File, error) { return nil, errors.New("temp fail") }
        if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
            t.Fatalf("expected temp error")
        }
    })

    t.Run("CopyError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") }
        if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
            t.Fatalf("expected copy error")
        }
    })

    t.Run("ChmodError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        chmodFileFn = func(*os.File, os.FileMode) error { return errors.New("chmod fail") }
        if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
            t.Fatalf("expected chmod error")
        }
    })

    t.Run("CloseError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        closeFileFn = func(*os.File) error { return errors.New("close fail") }
        if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
            t.Fatalf("expected close error")
        }
    })

    t.Run("RenameError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        renameFn = func(string, string) error { return errors.New("rename fail") }
        if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil {
            t.Fatalf("expected rename error")
        }
    })
}

func TestWriteHostAgentFileSuccess(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    dest := filepath.Join(t.TempDir(), "pulse-host-agent-linux-amd64")
    if err := writeHostAgentFile(dest, strings.NewReader("data"), 0o644); err != nil {
        t.Fatalf("unexpected error: %v", err)
    }
    if _, err := os.Stat(dest); err != nil {
        t.Fatalf("expected file written: %v", err)
    }
}

func TestCopyHostAgentFileErrors(t *testing.T) {
    t.Run("OpenError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        openFileFn = func(string) (*os.File, error) { return nil, errors.New("open fail") }
        if err := copyHostAgentFile("missing", filepath.Join(t.TempDir(), "dest")); err == nil {
            t.Fatalf("expected open error")
        }
    })

    t.Run("MkdirAllError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        src := filepath.Join(t.TempDir(), "src")
        if err := os.WriteFile(src, []byte("data"), 0o644); err != nil {
            t.Fatalf("write: %v", err)
        }
        mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") }
        if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil {
            t.Fatalf("expected mkdir error")
        }
    })

    t.Run("OpenFileError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        src := filepath.Join(t.TempDir(), "src")
        if err := os.WriteFile(src, []byte("data"), 0o644); err != nil {
            t.Fatalf("write: %v", err)
        }
        openFileModeFn = func(string, int, os.FileMode) (*os.File, error) {
            return nil, errors.New("create fail")
        }
        if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil {
            t.Fatalf("expected open file error")
        }
    })

    t.Run("CopyError", func(t *testing.T) {
        restore := saveHostAgentHooks()
        t.Cleanup(restore)

        src := filepath.Join(t.TempDir(), "src")
        if err := os.WriteFile(src, []byte("data"), 0o644); err != nil {
            t.Fatalf("write: %v", err)
        }
        copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") }
        if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil {
            t.Fatalf("expected copy error")
        }
    })
}

func TestCopyHostAgentFileSuccess(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    src := filepath.Join(t.TempDir(), "src")
    if err := os.WriteFile(src, []byte("data"), 0o644); err != nil {
        t.Fatalf("write: %v", err)
    }
    dest := filepath.Join(t.TempDir(), "dest")
    if err := copyHostAgentFile(src, dest); err != nil {
        t.Fatalf("unexpected error: %v", err)
    }
    if _, err := os.Stat(dest); err != nil {
        t.Fatalf("expected dest file: %v", err)
    }
}

func TestExtractHostAgentBinariesWriteError(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    tmpDir := t.TempDir()
    payload := buildTarGz(t, []tarEntry{
        {name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644},
    })
    archive := filepath.Join(t.TempDir(), "bundle.tar.gz")
    if err := os.WriteFile(archive, payload, 0o644); err != nil {
        t.Fatalf("write: %v", err)
    }

    mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") }
    if err := extractHostAgentBinaries(archive, tmpDir); err == nil {
        t.Fatalf("expected write error")
    }
}

func TestDownloadAndInstallHostAgentBinaries_Context(t *testing.T) {
    restore := saveHostAgentHooks()
    t.Cleanup(restore)

    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        if r.Context().Err() != nil {
            t.Fatalf("unexpected context error")
        }
        _, _ = w.Write([]byte("data"))
    }))
    t.Cleanup(server.Close)

    httpClient = server.Client()
    downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" }
    copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") }

    if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil {
        t.Fatalf("expected copy error")
    }
}
@@ -20,19 +20,26 @@ var upgrader = websocket.Upgrader{
     },
 }
 
+var (
+    jsonMarshal     = json.Marshal
+    pingInterval    = 5 * time.Second
+    pingWriteWait   = 5 * time.Second
+    readFileTimeout = 30 * time.Second
+)
+
 // Server manages WebSocket connections from agents
 type Server struct {
     mu            sync.RWMutex
     agents        map[string]*agentConn // agentID -> connection
     pendingReqs   map[string]chan CommandResultPayload // requestID -> response channel
     validateToken func(token string) bool
 }
 
 type agentConn struct {
     conn    *websocket.Conn
     agent   ConnectedAgent
     writeMu sync.Mutex
     done    chan struct{}
 }
 
 // NewServer creates a new agent execution server
@@ -94,7 +101,7 @@ func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
     }
 
     // Parse registration payload
-    payloadBytes, err := json.Marshal(msg.Payload)
+    payloadBytes, err := jsonMarshal(msg.Payload)
     if err != nil {
         log.Error().Err(err).Msg("Failed to marshal registration payload")
         conn.Close()
@@ -258,7 +265,7 @@ func (s *Server) readLoop(ac *agentConn) {
 }
 
 func (s *Server) pingLoop(ac *agentConn, done chan struct{}) {
-    ticker := time.NewTicker(5 * time.Second)
+    ticker := time.NewTicker(pingInterval)
     defer ticker.Stop()
 
     // Track consecutive ping failures to detect dead connections faster
@@ -273,7 +280,7 @@
             return
         case <-ticker.C:
             ac.writeMu.Lock()
-            err := ac.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
+            err := ac.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(pingWriteWait))
             ac.writeMu.Unlock()
             if err != nil {
                 consecutiveFailures++
@@ -283,14 +290,14 @@
                     Str("hostname", ac.agent.Hostname).
                     Int("consecutive_failures", consecutiveFailures).
                     Msg("Failed to send ping to agent")
 
                 if consecutiveFailures >= maxConsecutiveFailures {
                     log.Error().
                         Str("agent_id", ac.agent.AgentID).
                         Str("hostname", ac.agent.Hostname).
                         Int("failures", consecutiveFailures).
                         Msg("Agent connection appears dead after multiple ping failures, closing connection")
 
                     // Close the connection - this will cause readLoop to exit and clean up
                     ac.conn.Close()
                     return
@@ -404,7 +411,7 @@ func (s *Server) ReadFile(ctx context.Context, agentID string, req ReadFilePayload
     }
 
     // Wait for result
-    timeout := 30 * time.Second
+    timeout := readFileTimeout
     select {
     case result := <-respCh:
         return &result, nil
internal/agentexec/server_coverage_test.go (new file, 537 lines)
@@ -0,0 +1,537 @@
package agentexec

import (
    "context"
    "encoding/json"
    "errors"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"
    "time"

    "github.com/gorilla/websocket"
)

type noHijackResponseWriter struct {
    header http.Header
}

func (w *noHijackResponseWriter) Header() http.Header {
    return w.header
}

func (w *noHijackResponseWriter) Write([]byte) (int, error) {
    return 0, nil
}

func (w *noHijackResponseWriter) WriteHeader(int) {}

func newConnPair(t *testing.T) (*websocket.Conn, *websocket.Conn, func()) {
    t.Helper()

    serverConnCh := make(chan *websocket.Conn, 1)
    ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        conn, err := upgrader.Upgrade(w, r, nil)
        if err != nil {
            t.Errorf("upgrade: %v", err)
            return
        }
        serverConnCh <- conn
    }))

    clientConn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
    if err != nil {
        ts.Close()
        t.Fatalf("Dial: %v", err)
    }

    var serverConn *websocket.Conn
    select {
    case serverConn = <-serverConnCh:
    case <-time.After(2 * time.Second):
        clientConn.Close()
        ts.Close()
        t.Fatal("timed out waiting for server connection")
    }

    cleanup := func() {
        clientConn.Close()
        serverConn.Close()
        ts.Close()
    }

    return serverConn, clientConn, cleanup
}
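Several helpers used below (newWSServer, wsURLForHTTP, wsWriteMessage, wsReadRegisteredPayload, wsReadRawMessageWithTimeout, waitFor) come from the package's existing test files and are not shown in this diff. For orientation, a URL helper of this shape is the usual idiom; this sketch is an assumption, not the repository's code:

// Convert an httptest server URL (http://127.0.0.1:port) into the ws://
// form that websocket.DefaultDialer expects.
func wsURLForHTTP(httpURL string) string {
    return "ws" + strings.TrimPrefix(httpURL, "http")
}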
func TestHandleWebSocket_UpgradeFailureAndDeadlineErrors(t *testing.T) {
    s := NewServer(nil)
    req := httptest.NewRequest(http.MethodGet, "http://example/ws", nil)
    s.HandleWebSocket(&noHijackResponseWriter{header: make(http.Header)}, req)
}

func TestHandleWebSocket_RegistrationReadError(t *testing.T) {
    s := NewServer(nil)
    ts := newWSServer(t, s)
    defer ts.Close()

    conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
    if err != nil {
        t.Fatalf("Dial: %v", err)
    }
    conn.Close()
}

func TestHandleWebSocket_RegistrationMessageJSONError(t *testing.T) {
    s := NewServer(nil)
    ts := newWSServer(t, s)
    defer ts.Close()

    conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
    if err != nil {
        t.Fatalf("Dial: %v", err)
    }
    defer conn.Close()

    if err := conn.WriteMessage(websocket.TextMessage, []byte("{")); err != nil {
        t.Fatalf("WriteMessage: %v", err)
    }

    conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
    if _, _, err := conn.ReadMessage(); err == nil {
        t.Fatalf("expected server to close on invalid JSON")
    }
}

func TestHandleWebSocket_RegistrationPayloadMarshalError(t *testing.T) {
    orig := jsonMarshal
    t.Cleanup(func() { jsonMarshal = orig })
    jsonMarshal = func(any) ([]byte, error) {
        return nil, errors.New("boom")
    }

    s := NewServer(nil)
    ts := newWSServer(t, s)
    defer ts.Close()

    conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
    if err != nil {
        t.Fatalf("Dial: %v", err)
    }
    defer conn.Close()

    wsWriteMessage(t, conn, Message{
        Type:      MsgTypeAgentRegister,
        Timestamp: time.Now(),
        Payload: AgentRegisterPayload{
            AgentID:  "a1",
            Hostname: "host1",
            Token:    "any",
        },
    })

    conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
    if _, _, err := conn.ReadMessage(); err == nil {
        t.Fatalf("expected server to close on marshal error")
    }
}

func TestHandleWebSocket_RegistrationPayloadUnmarshalError(t *testing.T) {
    s := NewServer(nil)
    ts := newWSServer(t, s)
    defer ts.Close()

    conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
    if err != nil {
        t.Fatalf("Dial: %v", err)
    }
    defer conn.Close()

    if err := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"agent_register","payload":"oops"}`)); err != nil {
        t.Fatalf("WriteMessage: %v", err)
    }

    conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
    if _, _, err := conn.ReadMessage(); err == nil {
        t.Fatalf("expected server to close on invalid payload")
    }
}

func TestHandleWebSocket_PongHandler(t *testing.T) {
    s := NewServer(nil)
    ts := newWSServer(t, s)
    defer ts.Close()

    conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
    if err != nil {
        t.Fatalf("Dial: %v", err)
    }
    defer conn.Close()

    wsWriteMessage(t, conn, Message{
        Type:      MsgTypeAgentRegister,
        Timestamp: time.Now(),
        Payload: AgentRegisterPayload{
            AgentID:  "a1",
            Hostname: "host1",
            Token:    "any",
        },
    })
    _ = wsReadRegisteredPayload(t, conn)

    if err := conn.WriteControl(websocket.PongMessage, []byte("pong"), time.Now().Add(time.Second)); err != nil {
        t.Fatalf("WriteControl pong: %v", err)
    }

    conn.Close()
    waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") })
}
func TestReadLoopDone(t *testing.T) {
    s := NewServer(nil)
    serverConn, clientConn, cleanup := newConnPair(t)
    defer cleanup()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }
    close(ac.done)

    s.mu.Lock()
    s.agents["a1"] = ac
    s.mu.Unlock()

    s.readLoop(ac)

    if s.IsAgentConnected("a1") {
        t.Fatalf("expected agent to be removed")
    }
    clientConn.Close()
}

func TestReadLoopUnexpectedCloseError(t *testing.T) {
    s := NewServer(nil)
    serverConn, clientConn, cleanup := newConnPair(t)
    defer cleanup()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }

    s.mu.Lock()
    s.agents["a1"] = ac
    s.mu.Unlock()

    done := make(chan struct{})
    go func() {
        s.readLoop(ac)
        close(done)
    }()

    _ = clientConn.WriteControl(
        websocket.CloseMessage,
        websocket.FormatCloseMessage(websocket.CloseProtocolError, "bye"),
        time.Now().Add(time.Second),
    )

    select {
    case <-done:
    case <-time.After(2 * time.Second):
        t.Fatalf("readLoop did not exit")
    }
}

func TestReadLoopCommandResultBranches(t *testing.T) {
    s := NewServer(nil)
    serverConn, clientConn, cleanup := newConnPair(t)
    defer cleanup()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }

    s.mu.Lock()
    s.agents["a1"] = ac
    s.pendingReqs["req-full"] = make(chan CommandResultPayload)
    s.mu.Unlock()

    done := make(chan struct{})
    go func() {
        s.readLoop(ac)
        close(done)
    }()

    _ = clientConn.WriteMessage(websocket.TextMessage, []byte("{"))
    _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":123}}`))
    _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":"req-full","success":true}}`))
    _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":"req-missing","success":true}}`))

    clientConn.Close()

    select {
    case <-done:
    case <-time.After(2 * time.Second):
        t.Fatalf("readLoop did not exit")
    }

    s.mu.Lock()
    delete(s.pendingReqs, "req-full")
    s.mu.Unlock()
}
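The CloseProtocolError frame sent above drives readLoop through its unexpected-close branch. gorilla/websocket distinguishes expected from unexpected closures with IsUnexpectedCloseError; a sketch of the idiom, which is an assumption about readLoop's body since the diff does not show it:

// logIfAbnormal reads once and logs only when the close code is not one
// of the expected ones (normal closure or going away).
func logIfAbnormal(conn *websocket.Conn) {
    _, _, err := conn.ReadMessage()
    if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
        log.Printf("unexpected close: %v", err)
    }
}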
func TestPingLoopSuccessAndStop(t *testing.T) {
    origInterval := pingInterval
    t.Cleanup(func() { pingInterval = origInterval })
    pingInterval = 5 * time.Millisecond

    s := NewServer(nil)
    serverConn, _, cleanup := newConnPair(t)
    defer cleanup()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }

    stop := make(chan struct{})
    exited := make(chan struct{})
    go func() {
        s.pingLoop(ac, stop)
        close(exited)
    }()

    time.Sleep(2 * pingInterval)
    close(stop)

    select {
    case <-exited:
    case <-time.After(2 * time.Second):
        t.Fatalf("pingLoop did not exit")
    }
}

func TestPingLoopFailuresClose(t *testing.T) {
    origInterval := pingInterval
    t.Cleanup(func() { pingInterval = origInterval })
    pingInterval = 5 * time.Millisecond

    s := NewServer(nil)
    serverConn, _, cleanup := newConnPair(t)
    defer cleanup()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }

    serverConn.Close()

    stop := make(chan struct{})
    exited := make(chan struct{})
    go func() {
        s.pingLoop(ac, stop)
        close(exited)
    }()

    select {
    case <-exited:
    case <-time.After(2 * time.Second):
        t.Fatalf("pingLoop did not exit after failures")
    }
}
func TestSendMessageMarshalError(t *testing.T) {
    s := NewServer(nil)
    if err := s.sendMessage(nil, Message{Payload: make(chan int)}); err == nil {
        t.Fatalf("expected marshal error")
    }
}

func TestExecuteCommandSendError(t *testing.T) {
    s := NewServer(nil)
    serverConn, _, cleanup := newConnPair(t)
    defer cleanup()

    serverConn.Close()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }
    s.mu.Lock()
    s.agents["a1"] = ac
    s.mu.Unlock()

    _, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r1", Timeout: 1})
    if err == nil {
        t.Fatalf("expected send error")
    }
}

func TestExecuteCommandTimeoutAndCancel(t *testing.T) {
    s := NewServer(nil)
    serverConn, _, cleanup := newConnPair(t)
    defer cleanup()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }
    s.mu.Lock()
    s.agents["a1"] = ac
    s.mu.Unlock()

    _, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-timeout", Timeout: 1})
    if err == nil || !strings.Contains(err.Error(), "timed out") {
        t.Fatalf("expected timeout error, got %v", err)
    }

    ctx, cancel := context.WithCancel(context.Background())
    cancel()
    _, err = s.ExecuteCommand(ctx, "a1", ExecuteCommandPayload{RequestID: "r-cancel", Timeout: 1})
    if err == nil {
        t.Fatalf("expected cancel error")
    }
}

func TestExecuteCommandDefaultTimeout(t *testing.T) {
    s := NewServer(nil)
    serverConn, _, cleanup := newConnPair(t)
    defer cleanup()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }
    s.mu.Lock()
    s.agents["a1"] = ac
    s.mu.Unlock()

    go func() {
        for {
            s.mu.RLock()
            ch := s.pendingReqs["r-default"]
            s.mu.RUnlock()
            if ch != nil {
                ch <- CommandResultPayload{RequestID: "r-default", Success: true}
                return
            }
            time.Sleep(2 * time.Millisecond)
        }
    }()

    result, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-default"})
    if err != nil || result == nil || !result.Success {
        t.Fatalf("expected success, got result=%v err=%v", result, err)
    }
}
func TestReadFileRoundTrip(t *testing.T) {
    s := NewServer(nil)
    ts := newWSServer(t, s)
    defer ts.Close()

    conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
    if err != nil {
        t.Fatalf("Dial: %v", err)
    }
    defer conn.Close()

    wsWriteMessage(t, conn, Message{
        Type:      MsgTypeAgentRegister,
        Timestamp: time.Now(),
        Payload: AgentRegisterPayload{
            AgentID:  "a1",
            Hostname: "host1",
            Token:    "any",
        },
    })
    _ = wsReadRegisteredPayload(t, conn)

    agentDone := make(chan error, 1)
    go func() {
        for {
            msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second)
            if err != nil {
                agentDone <- err
                return
            }
            if msg.Type != MsgTypeReadFile || msg.Payload == nil {
                continue
            }
            var payload ReadFilePayload
            if err := json.Unmarshal(*msg.Payload, &payload); err != nil {
                agentDone <- err
                return
            }
            agentDone <- conn.WriteJSON(Message{
                Type:      MsgTypeCommandResult,
                Timestamp: time.Now(),
                Payload: CommandResultPayload{
                    RequestID: payload.RequestID,
                    Success:   true,
                    Stdout:    "data",
                    ExitCode:  0,
                },
            })
            return
        }
    }()

    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()

    result, err := s.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-1", Path: "/etc/hosts"})
    if err != nil || result == nil || result.Stdout != "data" {
        t.Fatalf("unexpected read file result=%v err=%v", result, err)
    }

    if err := <-agentDone; err != nil {
        t.Fatalf("agent error: %v", err)
    }
}

func TestReadFileTimeoutCancelAndSendError(t *testing.T) {
    origTimeout := readFileTimeout
    t.Cleanup(func() { readFileTimeout = origTimeout })
    readFileTimeout = 10 * time.Millisecond

    s := NewServer(nil)
    serverConn, _, cleanup := newConnPair(t)
    defer cleanup()

    ac := &agentConn{
        conn:  serverConn,
        agent: ConnectedAgent{AgentID: "a1"},
        done:  make(chan struct{}),
    }
    s.mu.Lock()
    s.agents["a1"] = ac
    s.mu.Unlock()

    if _, err := s.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-timeout"}); err == nil {
        t.Fatalf("expected timeout error")
    }

    ctx, cancel := context.WithCancel(context.Background())
    cancel()
    if _, err := s.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-cancel"}); err == nil {
        t.Fatalf("expected cancel error")
    }

    serverConn.Close()
    if _, err := s.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-send"}); err == nil {
        t.Fatalf("expected send error")
    }
}
@@ -42,6 +42,9 @@ func TestConnectedAgentLookups(t *testing.T) {
     if !ok || agentID != "a2" {
         t.Fatalf("expected GetAgentForHost(host2) = (a2, true), got (%q, %v)", agentID, ok)
     }
+    if _, ok := s.GetAgentForHost("missing"); ok {
+        t.Fatalf("expected missing host to return false")
+    }
 
     agents := s.GetConnectedAgents()
     if len(agents) != 2 {
internal/agentupdate/coverage_test.go (new file, 1033 lines; diff suppressed because it is too large)
@@ -8,13 +8,15 @@ import (
     "syscall"
 )
 
+var execFn = syscall.Exec
+
 // restartProcess replaces the current process with a new instance.
 // On Unix-like systems, this uses syscall.Exec for an in-place restart.
 func restartProcess(execPath string) error {
     args := os.Args
     env := os.Environ()
 
-    if err := syscall.Exec(execPath, args, env); err != nil {
+    if err := execFn(execPath, args, env); err != nil {
         return fmt.Errorf("failed to restart: %w", err)
     }
 
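With execFn in place, restartProcess becomes testable without actually replacing the test process. A sketch of the kind of test the suppressed coverage_test.go likely contains (hypothetical name and path; syscall.Exec's signature is (argv0 string, argv []string, envv []string) error):

func TestRestartProcessError(t *testing.T) {
    orig := execFn
    t.Cleanup(func() { execFn = orig })

    execFn = func(string, []string, []string) error { return errors.New("exec fail") }
    if err := restartProcess("/usr/local/bin/agent"); err == nil {
        t.Fatal("expected wrapped exec error")
    }
}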
@@ -32,6 +32,24 @@
     downloadTimeout = 5 * time.Minute
 )
 
+var (
+    maxBinarySizeBytes int64 = maxBinarySize
+    runtimeGOOS              = runtime.GOOS
+    runtimeGOARCH            = runtime.GOARCH
+    unameCommand             = func() ([]byte, error) { return exec.Command("uname", "-m").Output() }
+    unraidVersionPath        = "/etc/unraid-version"
+    unraidPersistentPathFn   = unraidPersistentPath
+    restartProcessFn         = restartProcess
+    osExecutableFn           = os.Executable
+    evalSymlinksFn           = filepath.EvalSymlinks
+    createTempFn             = os.CreateTemp
+    chmodFn                  = os.Chmod
+    renameFn                 = os.Rename
+    closeFileFn              = func(f *os.File) error { return f.Close() }
+    readFileFn               = os.ReadFile
+    writeFileFn              = os.WriteFile
+)
+
 // Config holds the configuration for the updater.
 type Config struct {
     // PulseURL is the base URL of the Pulse server (e.g., "https://pulse.example.com:7655")
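Because runtime.GOOS cannot be overridden at runtime, mirroring it into the runtimeGOOS variable is what lets tests drive verifyBinaryMagic down each platform branch. A sketch with a hypothetical test name (the non-linux branches are not shown in this diff, so the exact expectations are an assumption):

func TestVerifyBinaryMagicNonLinuxBranch(t *testing.T) {
    orig := runtimeGOOS
    t.Cleanup(func() { runtimeGOOS = orig })

    runtimeGOOS = "darwin"
    // verifyBinaryMagic now takes its darwin branch instead of the ELF check.
    _ = verifyBinaryMagic("/path/to/binary")
}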
@@ -66,6 +84,8 @@ type Updater struct {
     logger zerolog.Logger
 
     performUpdateFn func(context.Context) error
+    initialDelay    time.Duration
+    newTicker       func(time.Duration) *time.Ticker
 }
 
 // New creates a new Updater with the given configuration.
@@ -95,6 +115,8 @@ func New(cfg Config) *Updater {
         logger: logger,
     }
     u.performUpdateFn = u.performUpdate
+    u.initialDelay = 5 * time.Second
+    u.newTicker = time.NewTicker
     return u
 }
 
@@ -114,11 +136,11 @@ func (u *Updater) RunLoop(ctx context.Context) {
     select {
     case <-ctx.Done():
         return
-    case <-time.After(5 * time.Second):
+    case <-time.After(u.initialDelay):
         u.CheckAndUpdate(ctx)
     }
 
-    ticker := time.NewTicker(u.cfg.CheckInterval)
+    ticker := u.newTicker(u.cfg.CheckInterval)
     defer ticker.Stop()
 
     for {
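initialDelay and newTicker let a test run the loop in milliseconds instead of real time. A sketch of a fast RunLoop test under these hooks (hypothetical, since the 1033-line coverage_test.go diff is suppressed above):

func TestRunLoopFast(t *testing.T) {
    u := New(Config{PulseURL: "http://example.invalid", AgentName: "agent", CheckInterval: time.Hour})
    u.initialDelay = time.Millisecond
    u.newTicker = func(time.Duration) *time.Ticker {
        return time.NewTicker(time.Millisecond) // ignore the hour-long interval
    }

    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()
    u.RunLoop(ctx) // returns once the context expires
}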
@@ -228,7 +250,7 @@ func (u *Updater) getServerVersion(ctx context.Context) (string, error) {
 
 // isUnraid checks if we're running on Unraid by looking for /etc/unraid-version
 func isUnraid() bool {
-    _, err := os.Stat("/etc/unraid-version")
+    _, err := os.Stat(unraidVersionPath)
     return err == nil
 }
 
@@ -245,7 +267,7 @@ func verifyBinaryMagic(path string) error {
         return fmt.Errorf("failed to read magic bytes: %w", err)
     }
 
-    switch runtime.GOOS {
+    switch runtimeGOOS {
     case "linux":
         // ELF magic: 0x7f 'E' 'L' 'F'
         if magic[0] == 0x7f && magic[1] == 'E' && magic[2] == 'L' && magic[3] == 'F' {
@@ -286,10 +308,14 @@ func unraidPersistentPath(agentName string) string {
 
 // performUpdate downloads and installs the new agent binary.
 func (u *Updater) performUpdate(ctx context.Context) error {
-    execPath, err := os.Executable()
+    execPath, err := osExecutableFn()
     if err != nil {
         return fmt.Errorf("failed to get executable path: %w", err)
     }
+    return u.performUpdateWithExecPath(ctx, execPath)
+}
+
+func (u *Updater) performUpdateWithExecPath(ctx context.Context, execPath string) error {
     // Build download URL
     downloadBase := fmt.Sprintf("%s/download/%s", strings.TrimRight(u.cfg.PulseURL, "/"), u.cfg.AgentName)
@@ -303,7 +329,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
|
||||
candidates = append(candidates, downloadBase)
|
||||
|
||||
var resp *http.Response
|
||||
var lastErr error
|
||||
lastErr := errors.New("failed to download binary")
|
||||
|
||||
for _, url := range candidates {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
@@ -335,9 +361,6 @@ func (u *Updater) performUpdate(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
if lastErr == nil {
|
||||
lastErr = errors.New("failed to download binary")
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -346,7 +369,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
|
||||
checksumHeader := strings.TrimSpace(resp.Header.Get("X-Checksum-Sha256"))
|
||||
|
||||
// Resolve symlinks to get the real path for atomic rename
|
||||
realExecPath, err := filepath.EvalSymlinks(execPath)
|
||||
realExecPath, err := evalSymlinksFn(execPath)
|
||||
if err != nil {
|
||||
// Fall back to original path if symlink resolution fails
|
||||
realExecPath = execPath
|
||||
@@ -355,7 +378,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
|
||||
// Create temporary file in the same directory as the target binary
|
||||
// to ensure atomic rename works (os.Rename fails across filesystems)
|
||||
targetDir := filepath.Dir(realExecPath)
|
||||
tmpFile, err := os.CreateTemp(targetDir, u.cfg.AgentName+"-*.tmp")
|
||||
tmpFile, err := createTempFn(targetDir, u.cfg.AgentName+"-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
@@ -364,17 +387,17 @@ func (u *Updater) performUpdate(ctx context.Context) error {
|
||||
|
||||
// Write downloaded binary with checksum calculation and size limit
|
||||
hasher := sha256.New()
|
||||
limitedReader := io.LimitReader(resp.Body, maxBinarySize+1) // +1 to detect overflow
|
||||
limitedReader := io.LimitReader(resp.Body, maxBinarySizeBytes+1) // +1 to detect overflow
|
||||
written, err := io.Copy(tmpFile, io.TeeReader(limitedReader, hasher))
|
||||
if err != nil {
|
||||
tmpFile.Close()
|
||||
closeFileFn(tmpFile)
|
||||
return fmt.Errorf("failed to write binary: %w", err)
|
||||
}
|
||||
if written > maxBinarySize {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("downloaded binary exceeds maximum size (%d bytes)", maxBinarySize)
|
||||
if written > maxBinarySizeBytes {
|
||||
closeFileFn(tmpFile)
|
||||
return fmt.Errorf("downloaded binary exceeds maximum size (%d bytes)", maxBinarySizeBytes)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
if err := closeFileFn(tmpFile); err != nil {
|
||||
return fmt.Errorf("failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
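The download path above hashes and size-caps the stream in a single pass by stacking readers. The same pattern, isolated into a standalone helper for illustration (the repository keeps it inline; imports are io, crypto/sha256, encoding/hex, and fmt):

	// hashAndCopyCapped copies src into dst, enforcing a byte cap and
	// returning the hex SHA-256 of everything written.
	func hashAndCopyCapped(dst io.Writer, src io.Reader, maxBytes int64) (string, error) {
		hasher := sha256.New()
		limited := io.LimitReader(src, maxBytes+1) // +1 so an over-limit source is detectable
		n, err := io.Copy(dst, io.TeeReader(limited, hasher))
		if err != nil {
			return "", err
		}
		if n > maxBytes {
			return "", fmt.Errorf("input exceeds maximum size (%d bytes)", maxBytes)
		}
		return hex.EncodeToString(hasher.Sum(nil)), nil
	}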
@@ -397,19 +420,19 @@ func (u *Updater) performUpdate(ctx context.Context) error {
u.logger.Debug().Str("checksum", downloadChecksum).Msg("Checksum verified")

// Make executable
if err := os.Chmod(tmpPath, 0755); err != nil {
if err := chmodFn(tmpPath, 0755); err != nil {
return fmt.Errorf("failed to chmod: %w", err)
}

// Atomic replacement with backup (use realExecPath for rename operations)
backupPath := realExecPath + ".backup"
if err := os.Rename(realExecPath, backupPath); err != nil {
if err := renameFn(realExecPath, backupPath); err != nil {
return fmt.Errorf("failed to backup current binary: %w", err)
}

if err := os.Rename(tmpPath, realExecPath); err != nil {
if err := renameFn(tmpPath, realExecPath); err != nil {
// Restore backup on failure
os.Rename(backupPath, realExecPath)
renameFn(backupPath, realExecPath)
return fmt.Errorf("failed to replace binary: %w", err)
}
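The backup-then-rename sequence generalizes to a small helper. A sketch under the same assumptions the code relies on (both paths on one filesystem, directory writable); this helper does not exist in the repository:

	// replaceAtomically moves newPath into place at target, keeping the old
	// file as target.backup and rolling back if the final rename fails.
	func replaceAtomically(target, newPath string) error {
		backup := target + ".backup"
		if err := os.Rename(target, backup); err != nil {
			return fmt.Errorf("failed to back up %s: %w", target, err)
		}
		if err := os.Rename(newPath, target); err != nil {
			_ = os.Rename(backup, target) // best-effort rollback to the old binary
			return fmt.Errorf("failed to replace %s: %w", target, err)
		}
		return nil
	}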
@@ -418,26 +441,26 @@ func (u *Updater) performUpdate(ctx context.Context) error {

// Write previous version to a file so the agent can report "updated from X" on next start
updateInfoPath := filepath.Join(targetDir, ".pulse-update-info")
_ = os.WriteFile(updateInfoPath, []byte(u.cfg.CurrentVersion), 0644)
_ = writeFileFn(updateInfoPath, []byte(u.cfg.CurrentVersion), 0644)

// On Unraid, also update the persistent copy on the flash drive
// This ensures the update survives reboots
if isUnraid() {
persistPath := unraidPersistentPath(u.cfg.AgentName)
persistPath := unraidPersistentPathFn(u.cfg.AgentName)
if _, err := os.Stat(persistPath); err == nil {
// Persistent path exists, update it
u.logger.Debug().Str("path", persistPath).Msg("Updating Unraid persistent binary")

// Read the newly installed binary
newBinary, err := os.ReadFile(execPath)
newBinary, err := readFileFn(execPath)
if err != nil {
u.logger.Warn().Err(err).Msg("Failed to read new binary for Unraid persistence")
} else {
// Write to persistent storage (atomic via temp file)
tmpPersist := persistPath + ".tmp"
if err := os.WriteFile(tmpPersist, newBinary, 0644); err != nil {
if err := writeFileFn(tmpPersist, newBinary, 0644); err != nil {
u.logger.Warn().Err(err).Msg("Failed to write Unraid persistent binary")
} else if err := os.Rename(tmpPersist, persistPath); err != nil {
} else if err := renameFn(tmpPersist, persistPath); err != nil {
u.logger.Warn().Err(err).Msg("Failed to rename Unraid persistent binary")
os.Remove(tmpPersist)
} else {
@@ -448,19 +471,19 @@ func (u *Updater) performUpdate(ctx context.Context) error {
}

// Restart the process using platform-specific implementation
return restartProcess(execPath)
return restartProcessFn(execPath)
}

// GetUpdatedFromVersion checks if the agent was recently updated and returns the previous version.
// Returns empty string if no update info exists. Clears the info file after reading.
func GetUpdatedFromVersion() string {
execPath, err := os.Executable()
execPath, err := osExecutableFn()
if err != nil {
return ""
}

// Resolve symlinks to get the real path
realExecPath, err := filepath.EvalSymlinks(execPath)
realExecPath, err := evalSymlinksFn(execPath)
if err != nil {
realExecPath = execPath
}
@@ -479,8 +502,8 @@ func GetUpdatedFromVersion() string {

// determineArch returns the architecture string for download URLs (e.g., "linux-amd64", "darwin-arm64").
func determineArch() string {
os := runtime.GOOS
arch := runtime.GOARCH
os := runtimeGOOS
arch := runtimeGOARCH

// Normalize architecture
switch arch {
@@ -497,7 +520,7 @@ func determineArch() string {
}

// Fall back to uname for edge cases on unknown OS
out, err := exec.Command("uname", "-m").Output()
out, err := unameCommand()
if err != nil {
return ""
}
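With `runtimeGOOS` and `runtimeGOARCH` as variables, `determineArch` can be table-tested across platforms on any host. A hypothetical sketch (the expected outputs assume the "os-arch" format from the doc comment above):

	func TestDetermineArchTable(t *testing.T) {
		origOS, origArch := runtimeGOOS, runtimeGOARCH
		t.Cleanup(func() { runtimeGOOS, runtimeGOARCH = origOS, origArch })

		cases := []struct{ goos, goarch, want string }{
			{"linux", "amd64", "linux-amd64"},
			{"darwin", "arm64", "darwin-arm64"},
		}
		for _, c := range cases {
			runtimeGOOS, runtimeGOARCH = c.goos, c.goarch
			if got := determineArch(); got != c.want {
				t.Errorf("determineArch() with %s/%s = %q, want %q", c.goos, c.goarch, got, c.want)
			}
		}
	}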
@@ -40,68 +40,68 @@ const (
// HealthFactor represents a single component affecting health
type HealthFactor struct {
Name string `json:"name"`
Impact float64 `json:"impact"` // -1 to 1, negative is bad
Description string `json:"description"`
Category string `json:"category"` // finding, prediction, baseline, incident
}

// HealthScore represents the overall health of a resource or system
type HealthScore struct {
Score float64 `json:"score"` // 0-100
Grade HealthGrade `json:"grade"` // A, B, C, D, F
Trend HealthTrend `json:"trend"` // improving, stable, declining
Factors []HealthFactor `json:"factors"` // What's affecting the score
Prediction string `json:"prediction"` // Human-readable outlook
}

// ResourceIntelligence aggregates all AI knowledge about a single resource
type ResourceIntelligence struct {
ResourceID string `json:"resource_id"`
ResourceName string `json:"resource_name,omitempty"`
ResourceType string `json:"resource_type,omitempty"`
Health HealthScore `json:"health"`
ActiveFindings []*Finding `json:"active_findings,omitempty"`
Predictions []patterns.FailurePrediction `json:"predictions,omitempty"`
Dependencies []string `json:"dependencies,omitempty"` // Resources this depends on
Dependents []string `json:"dependents,omitempty"` // Resources that depend on this
Correlations []*correlation.Correlation `json:"correlations,omitempty"`
Baselines map[string]*baseline.FlatBaseline `json:"baselines,omitempty"`
Anomalies []AnomalyReport `json:"anomalies,omitempty"`
RecentIncidents []*memory.Incident `json:"recent_incidents,omitempty"`
Knowledge *knowledge.GuestKnowledge `json:"knowledge,omitempty"`
NoteCount int `json:"note_count"`
}

// AnomalyReport describes a metric that's deviating from baseline
type AnomalyReport struct {
Metric string `json:"metric"`
CurrentValue float64 `json:"current_value"`
BaselineMean float64 `json:"baseline_mean"`
ZScore float64 `json:"z_score"`
Severity baseline.AnomalySeverity `json:"severity"`
Description string `json:"description"`
}

// IntelligenceSummary provides a system-wide intelligence overview
type IntelligenceSummary struct {
Timestamp time.Time `json:"timestamp"`
OverallHealth HealthScore `json:"overall_health"`

// Findings summary
FindingsCount FindingsCounts `json:"findings_count"`
TopFindings []*Finding `json:"top_findings,omitempty"` // Most critical

// Predictions
PredictionsCount int `json:"predictions_count"`
UpcomingRisks []patterns.FailurePrediction `json:"upcoming_risks,omitempty"`

// Recent activity
RecentChangesCount int `json:"recent_changes_count"`
RecentRemediations []memory.RemediationRecord `json:"recent_remediations,omitempty"`

// Learning progress
Learning LearningStats `json:"learning"`

// Resources needing attention
ResourcesAtRisk []ResourceRiskSummary `json:"resources_at_risk,omitempty"`
}
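These are wire types, so the struct tags above define the JSON the API emits. For instance, marshaling a minimal HealthScore might look like this (values illustrative; it assumes HealthGrade and HealthTrend have string underlying types, and imports encoding/json and fmt):

	func ExampleHealthScoreJSON() {
		hs := HealthScore{Score: 92, Grade: "A", Trend: "stable"}
		b, _ := json.Marshal(hs)
		fmt.Println(string(b))
		// Output: {"score":92,"grade":"A","trend":"stable","factors":null,"prediction":""}
	}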
@@ -137,7 +137,7 @@ type ResourceRiskSummary struct {
// Intelligence orchestrates all AI subsystems into a unified system
type Intelligence struct {
mu sync.RWMutex

// Core subsystems
findings *FindingsStore
patterns *patterns.Detector
@@ -147,10 +147,13 @@ type Intelligence struct {
knowledge *knowledge.Store
changes *memory.ChangeDetector
remediations *memory.RemediationLog

// State access
stateProvider StateProvider

// Optional hook for anomaly detection (used by patrol integration/tests)
anomalyDetector func(resourceID string) []AnomalyReport

// Configuration
dataDir string
}
@@ -180,7 +183,7 @@ func (i *Intelligence) SetSubsystems(
) {
i.mu.Lock()
defer i.mu.Unlock()

i.findings = findings
i.patterns = patternsDetector
i.correlations = correlationsDetector
@@ -202,45 +205,45 @@ func (i *Intelligence) SetStateProvider(sp StateProvider) {
func (i *Intelligence) GetSummary() *IntelligenceSummary {
i.mu.RLock()
defer i.mu.RUnlock()

summary := &IntelligenceSummary{
Timestamp: time.Now(),
}

// Aggregate findings
if i.findings != nil {
all := i.findings.GetActive(FindingSeverityInfo) // Get all active findings
summary.FindingsCount = i.countFindings(all)
summary.TopFindings = i.getTopFindings(all, 5)
}

// Aggregate predictions
if i.patterns != nil {
predictions := i.patterns.GetPredictions()
summary.PredictionsCount = len(predictions)
summary.UpcomingRisks = i.getUpcomingRisks(predictions, 5)
}

// Aggregate recent activity
if i.changes != nil {
recent := i.changes.GetRecentChanges(100, time.Now().Add(-24*time.Hour))
summary.RecentChangesCount = len(recent)
}

if i.remediations != nil {
recent := i.remediations.GetRecentRemediations(5, time.Now().Add(-24*time.Hour))
summary.RecentRemediations = recent
}

// Learning stats
summary.Learning = i.getLearningStats()

// Calculate overall health
summary.OverallHealth = i.calculateOverallHealth(summary)

// Resources at risk
summary.ResourcesAtRisk = i.getResourcesAtRisk(5)

return summary
}

@@ -248,11 +251,11 @@ func (i *Intelligence) GetSummary() *IntelligenceSummary {
func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntelligence {
i.mu.RLock()
defer i.mu.RUnlock()

intel := &ResourceIntelligence{
ResourceID: resourceID,
}

// Active findings
if i.findings != nil {
intel.ActiveFindings = i.findings.GetByResource(resourceID)
@@ -261,19 +264,19 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
intel.ResourceType = intel.ActiveFindings[0].ResourceType
}
}

// Predictions
if i.patterns != nil {
intel.Predictions = i.patterns.GetPredictionsForResource(resourceID)
}

// Correlations and dependencies
if i.correlations != nil {
intel.Correlations = i.correlations.GetCorrelationsForResource(resourceID)
intel.Dependencies = i.correlations.GetDependsOn(resourceID)
intel.Dependents = i.correlations.GetDependencies(resourceID)
}

// Baselines
if i.baselines != nil {
if rb, ok := i.baselines.GetResourceBaseline(resourceID); ok {
@@ -290,12 +293,12 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
}
}
}

// Recent incidents
if i.incidents != nil {
intel.RecentIncidents = i.incidents.ListIncidentsByResource(resourceID, 5)
}

// Knowledge
if i.knowledge != nil {
if k, err := i.knowledge.GetKnowledge(resourceID); err == nil && k != nil {
@@ -309,10 +312,10 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
}
}
}

// Calculate health score
intel.Health = i.calculateResourceHealth(intel)

return intel
}

@@ -320,49 +323,49 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
func (i *Intelligence) FormatContext(resourceID string) string {
i.mu.RLock()
defer i.mu.RUnlock()

var sections []string

// Knowledge (most important - what we've learned)
if i.knowledge != nil {
if ctx := i.knowledge.FormatForContext(resourceID); ctx != "" {
sections = append(sections, ctx)
}
}

// Baselines (what's normal for this resource)
if i.baselines != nil {
if ctx := i.formatBaselinesForContext(resourceID); ctx != "" {
sections = append(sections, ctx)
}
}

// Current anomalies
if anomalies := i.detectCurrentAnomalies(resourceID); len(anomalies) > 0 {
sections = append(sections, i.formatAnomaliesForContext(anomalies))
}

// Patterns/Predictions
if i.patterns != nil {
if ctx := i.patterns.FormatForContext(resourceID); ctx != "" {
sections = append(sections, ctx)
}
}

// Correlations
if i.correlations != nil {
if ctx := i.correlations.FormatForContext(resourceID); ctx != "" {
sections = append(sections, ctx)
}
}

// Incidents
if i.incidents != nil {
if ctx := i.incidents.FormatForResource(resourceID, 5); ctx != "" {
sections = append(sections, ctx)
}
}

return strings.Join(sections, "\n")
}

@@ -370,37 +373,37 @@ func (i *Intelligence) FormatContext(resourceID string) string {
func (i *Intelligence) FormatGlobalContext() string {
i.mu.RLock()
defer i.mu.RUnlock()

var sections []string

// All saved knowledge (limited)
if i.knowledge != nil {
if ctx := i.knowledge.FormatAllForContext(); ctx != "" {
sections = append(sections, ctx)
}
}

// Recent incidents across infrastructure
if i.incidents != nil {
if ctx := i.incidents.FormatForPatrol(8); ctx != "" {
sections = append(sections, ctx)
}
}

// Top correlations
if i.correlations != nil {
if ctx := i.correlations.FormatForContext(""); ctx != "" {
sections = append(sections, ctx)
}
}

// Top predictions
if i.patterns != nil {
if ctx := i.patterns.FormatForContext(""); ctx != "" {
sections = append(sections, ctx)
}
}

return strings.Join(sections, "\n")
}

@@ -408,11 +411,11 @@ func (i *Intelligence) FormatGlobalContext() string {
func (i *Intelligence) RecordLearning(resourceID, resourceName, resourceType, title, content string) error {
i.mu.RLock()
defer i.mu.RUnlock()

if i.knowledge == nil {
return nil
}

return i.knowledge.SaveNote(resourceID, resourceName, resourceType, "learning", title, content)
}

@@ -420,11 +423,11 @@ func (i *Intelligence) RecordLearning(resourceID, resourceName, resourceType, ti
func (i *Intelligence) CheckBaselinesForResource(resourceID string, metrics map[string]float64) []AnomalyReport {
i.mu.RLock()
defer i.mu.RUnlock()

if i.baselines == nil {
return nil
}

var anomalies []AnomalyReport
for metric, value := range metrics {
severity, zScore, bl := i.baselines.CheckAnomaly(resourceID, metric, value)
@@ -439,7 +442,7 @@ func (i *Intelligence) CheckBaselinesForResource(resourceID string, metrics map[
})
}
}

return anomalies
}
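The anomaly check is driven by a z-score of the current value against the learned baseline, z = (value - mean) / stddev. A standalone sketch of that core (the store's real severity cutoffs live in the baseline package and are not shown in this diff):

	func zScore(value, mean, stddev float64) float64 {
		if stddev == 0 {
			return 0 // degenerate baseline: no meaningful deviation
		}
		return (value - mean) / stddev
	}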
@@ -452,7 +455,7 @@ func (i *Intelligence) CreatePredictionFinding(pred patterns.FailurePrediction)
if pred.Confidence > 0.8 && pred.DaysUntil < 1 {
severity = FindingSeverityCritical
}

return &Finding{
ID: fmt.Sprintf("pred-%s-%s", pred.ResourceID, pred.EventType),
Key: fmt.Sprintf("prediction:%s:%s", pred.ResourceID, pred.EventType),
@@ -493,7 +496,7 @@ func (i *Intelligence) getTopFindings(findings []*Finding, limit int) []*Finding
if len(findings) == 0 {
return nil
}

// Sort by severity (critical first) then by detection time (newest first)
sorted := make([]*Finding, len(findings))
copy(sorted, findings)
@@ -505,7 +508,7 @@ func (i *Intelligence) getTopFindings(findings []*Finding, limit int) []*Finding
}
return sorted[a].DetectedAt.After(sorted[b].DetectedAt)
})

if len(sorted) > limit {
sorted = sorted[:limit]
}
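The hunk above only shows the tail of the comparator: severity first, then recency. The full sort likely resembles this sketch, where `severityRank` is a hypothetical helper mapping severities to ordered integers (it is not shown in this diff):

	sort.Slice(sorted, func(a, b int) bool {
		ra, rb := severityRank(sorted[a].Severity), severityRank(sorted[b].Severity)
		if ra != rb {
			return ra > rb // more severe findings first
		}
		return sorted[a].DetectedAt.After(sorted[b].DetectedAt) // then newest first
	})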
@@ -531,7 +534,7 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction
if len(predictions) == 0 {
return nil
}

// Filter to next 7 days and sort by days until
var upcoming []patterns.FailurePrediction
for _, p := range predictions {
@@ -539,11 +542,11 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction
upcoming = append(upcoming, p)
}
}

sort.Slice(upcoming, func(a, b int) bool {
return upcoming[a].DaysUntil < upcoming[b].DaysUntil
})

if len(upcoming) > limit {
upcoming = upcoming[:limit]
}
@@ -552,7 +555,7 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction

func (i *Intelligence) getLearningStats() LearningStats {
stats := LearningStats{}

if i.knowledge != nil {
guests, _ := i.knowledge.ListGuests()
for _, guestID := range guests {
@@ -562,27 +565,27 @@ func (i *Intelligence) getLearningStats() LearningStats {
}
}
}

if i.baselines != nil {
stats.ResourcesWithBaselines = i.baselines.ResourceCount()
}

if i.patterns != nil {
p := i.patterns.GetPatterns()
stats.PatternsDetected = len(p)
}

if i.correlations != nil {
c := i.correlations.GetCorrelations()
stats.CorrelationsLearned = len(c)
}

if i.incidents != nil {
// Count is not available, so we skip this stat for now
// Could be added to IncidentStore if needed
stats.IncidentsTracked = 0
}

return stats
}

@@ -593,7 +596,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
Trend: HealthTrendStable,
Factors: []HealthFactor{},
}

// Deduct for findings
if summary.FindingsCount.Critical > 0 {
impact := float64(summary.FindingsCount.Critical) * 20
@@ -608,7 +611,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
Category: "finding",
})
}

if summary.FindingsCount.Warning > 0 {
impact := float64(summary.FindingsCount.Warning) * 10
if impact > 20 {
@@ -622,7 +625,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
Category: "finding",
})
}

// Deduct for imminent predictions
for _, pred := range summary.UpcomingRisks {
if pred.DaysUntil < 3 && pred.Confidence > 0.7 {
@@ -636,7 +639,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
})
}
}

// Bonus for learning progress
if summary.Learning.ResourcesWithKnowledge > 5 {
bonus := 5.0
@@ -648,7 +651,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
Category: "learning",
})
}

// Clamp score
if health.Score < 0 {
health.Score = 0
@@ -656,13 +659,13 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
if health.Score > 100 {
health.Score = 100
}

// Assign grade
health.Grade = scoreToGrade(health.Score)

// Generate prediction text
health.Prediction = i.generateHealthPrediction(health, summary)

return health
}
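Both health calculators end with the same clamp-to-[0,100] step before grading. Factored out, it is just this (a sketch; the repository keeps it inline):

	func clampScore(s float64) float64 {
		switch {
		case s < 0:
			return 0
		case s > 100:
			return 100
		default:
			return s
		}
	}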
@@ -673,7 +676,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
Trend: HealthTrendStable,
Factors: []HealthFactor{},
}

// Deduct for active findings
for _, f := range intel.ActiveFindings {
if f == nil {
@@ -698,7 +701,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
Category: "finding",
})
}

// Deduct for predictions
for _, p := range intel.Predictions {
if p.DaysUntil < 7 && p.Confidence > 0.5 {
@@ -712,7 +715,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
})
}
}

// Deduct for anomalies
for _, a := range intel.Anomalies {
var impact float64
@@ -734,7 +737,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
Category: "baseline",
})
}

// Bonus for having knowledge
if intel.NoteCount > 0 {
bonus := 2.0
@@ -746,7 +749,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
Category: "learning",
})
}

// Clamp
if health.Score < 0 {
health.Score = 0
@@ -754,9 +757,9 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
if health.Score > 100 {
health.Score = 100
}

health.Grade = scoreToGrade(health.Score)

return health
}

@@ -779,21 +782,21 @@ func (i *Intelligence) generateHealthPrediction(health HealthScore, summary *Int
if health.Grade == HealthGradeA {
return "Infrastructure is healthy with no significant issues detected."
}

if summary.FindingsCount.Critical > 0 {
return fmt.Sprintf("Immediate attention required: %d critical issues.", summary.FindingsCount.Critical)
}

if len(summary.UpcomingRisks) > 0 {
risk := summary.UpcomingRisks[0]
return fmt.Sprintf("Predicted %s event on resource within %.1f days (%.0f%% confidence).",
risk.EventType, risk.DaysUntil, risk.Confidence*100)
}

if summary.FindingsCount.Warning > 0 {
return fmt.Sprintf("%d warnings should be addressed soon to maintain stability.", summary.FindingsCount.Warning)
}

return "Infrastructure is stable with minor issues to monitor."
}

@@ -801,16 +804,13 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
if i.findings == nil {
return nil
}

// Group findings by resource
byResource := make(map[string][]*Finding)
for _, f := range i.findings.GetActive(FindingSeverityInfo) {
if f == nil {
continue
}
byResource[f.ResourceID] = append(byResource[f.ResourceID], f)
}

// Calculate risk for each resource
type resourceRisk struct {
id string
@@ -819,13 +819,9 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
score float64
top string
}

var risks []resourceRisk
for id, findings := range byResource {
if len(findings) == 0 {
continue
}

score := 0.0
var topFinding *Finding
for _, f := range findings {
@@ -843,7 +839,7 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
topFinding = f
}
}

if score > 0 && topFinding != nil {
risks = append(risks, resourceRisk{
id: id,
@@ -854,16 +850,16 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
})
}
}

// Sort by risk score descending
sort.Slice(risks, func(a, b int) bool {
return risks[a].score > risks[b].score
})

if len(risks) > limit {
risks = risks[:limit]
}

// Convert to summaries
var summaries []ResourceRiskSummary
for _, r := range risks {
@@ -879,11 +875,14 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
TopIssue: r.top,
})
}

return summaries
}
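`getResourcesAtRisk` is a group-score-sort pipeline over active findings. The scoring step, in isolation, with assumed per-severity weights (the diff truncates the exact constants):

	// riskScore sums assumed severity weights; the real weights may differ.
	func riskScore(findings []*Finding) float64 {
		score := 0.0
		for _, f := range findings {
			if f == nil {
				continue // the store may hand back nil slots
			}
			switch f.Severity {
			case FindingSeverityCritical:
				score += 20
			case FindingSeverityWarning:
				score += 10
			default:
				score += 2
			}
		}
		return score
	}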
func (i *Intelligence) detectCurrentAnomalies(resourceID string) []AnomalyReport {
if i.anomalyDetector != nil {
return i.anomalyDetector(resourceID)
}
// This would be called with current metrics from state
// For now, return empty - will be integrated with patrol
return nil
@@ -893,21 +892,21 @@ func (i *Intelligence) formatBaselinesForContext(resourceID string) string {
if i.baselines == nil {
return ""
}

rb, ok := i.baselines.GetResourceBaseline(resourceID)
if !ok || len(rb.Metrics) == 0 {
return ""
}

var lines []string
lines = append(lines, "\n## Learned Baselines")
lines = append(lines, "Normal operating ranges for this resource:")

for metric, mb := range rb.Metrics {
lines = append(lines, fmt.Sprintf("- %s: mean %.1f, stddev %.1f (samples: %d)",
metric, mb.Mean, mb.StdDev, mb.SampleCount))
}

return strings.Join(lines, "\n")
}

@@ -915,15 +914,15 @@ func (i *Intelligence) formatAnomaliesForContext(anomalies []AnomalyReport) stri
if len(anomalies) == 0 {
return ""
}

var lines []string
lines = append(lines, "\n## Current Anomalies")
lines = append(lines, "Metrics deviating from normal:")

for _, a := range anomalies {
lines = append(lines, fmt.Sprintf("- %s: %s", a.Metric, a.Description))
}

return strings.Join(lines, "\n")
}
internal/ai/intelligence_coverage_test.go (new file, 552 lines)
@@ -0,0 +1,552 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ai/baseline"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ai/correlation"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ai/knowledge"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ai/memory"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ai/patterns"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
|
||||
)
|
||||
|
||||
func TestIntelligence_formatBaselinesForContext(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
store := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
|
||||
if err := store.Learn("res-1", "vm", "cpu", []baseline.MetricPoint{{Value: 12.5}}); err != nil {
|
||||
t.Fatalf("Learn: %v", err)
|
||||
}
|
||||
intel.SetSubsystems(nil, nil, nil, store, nil, nil, nil, nil)
|
||||
|
||||
ctx := intel.formatBaselinesForContext("res-1")
|
||||
if !strings.Contains(ctx, "Learned Baselines") {
|
||||
t.Fatalf("expected baseline header, got %q", ctx)
|
||||
}
|
||||
if !strings.Contains(ctx, "cpu: mean") {
|
||||
t.Fatalf("expected cpu baseline line, got %q", ctx)
|
||||
}
|
||||
|
||||
if got := intel.formatBaselinesForContext("missing"); got != "" {
|
||||
t.Errorf("expected empty context for missing baseline, got %q", got)
|
||||
}
|
||||
|
||||
empty := NewIntelligence(IntelligenceConfig{})
|
||||
if got := empty.formatBaselinesForContext("res-1"); got != "" {
|
||||
t.Errorf("expected empty context with no baseline store, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_formatAnomaliesForContext(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
if got := intel.formatAnomaliesForContext(nil); got != "" {
|
||||
t.Errorf("expected empty anomalies context, got %q", got)
|
||||
}
|
||||
|
||||
anomalies := []AnomalyReport{{Metric: "cpu", Description: "CPU high"}}
|
||||
ctx := intel.formatAnomaliesForContext(anomalies)
|
||||
if !strings.Contains(ctx, "Current Anomalies") {
|
||||
t.Fatalf("expected anomalies header, got %q", ctx)
|
||||
}
|
||||
if !strings.Contains(ctx, "CPU high") {
|
||||
t.Fatalf("expected anomaly description, got %q", ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_formatAnomalyDescription(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
bl := &baseline.MetricBaseline{Mean: 10}
|
||||
|
||||
above := intel.formatAnomalyDescription("cpu", 20, bl, 2.5)
|
||||
if !strings.Contains(above, "above baseline") {
|
||||
t.Errorf("expected above-baseline description, got %q", above)
|
||||
}
|
||||
|
||||
below := intel.formatAnomalyDescription("cpu", 5, bl, -1.5)
|
||||
if !strings.Contains(below, "below baseline") {
|
||||
t.Errorf("expected below-baseline description, got %q", below)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_FormatContext_Anomalies(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
intel.anomalyDetector = func(resourceID string) []AnomalyReport {
|
||||
return []AnomalyReport{{Metric: "cpu", Description: "CPU high"}}
|
||||
}
|
||||
|
||||
ctx := intel.FormatContext("res-1")
|
||||
if !strings.Contains(ctx, "Current Anomalies") {
|
||||
t.Fatalf("expected anomalies context, got %q", ctx)
|
||||
}
|
||||
if !strings.Contains(ctx, "CPU high") {
|
||||
t.Fatalf("expected anomaly description, got %q", ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_getUpcomingRisks(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
predictions := []patterns.FailurePrediction{
|
||||
{ResourceID: "r1", EventType: patterns.EventHighCPU, DaysUntil: 2, Confidence: 0.6},
|
||||
{ResourceID: "r2", EventType: patterns.EventHighMemory, DaysUntil: 9, Confidence: 0.9},
|
||||
{ResourceID: "r3", EventType: patterns.EventDiskFull, DaysUntil: 4, Confidence: 0.4},
|
||||
{ResourceID: "r4", EventType: patterns.EventOOM, DaysUntil: 1, Confidence: 0.8},
|
||||
{ResourceID: "r5", EventType: patterns.EventRestart, DaysUntil: 5, Confidence: 0.7},
|
||||
}
|
||||
|
||||
risk := intel.getUpcomingRisks(predictions, 2)
|
||||
if len(risk) != 2 {
|
||||
t.Fatalf("expected 2 risks, got %d", len(risk))
|
||||
}
|
||||
if risk[0].ResourceID != "r4" || risk[1].ResourceID != "r1" {
|
||||
t.Errorf("unexpected ordering: %v", []string{risk[0].ResourceID, risk[1].ResourceID})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_countFindings(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
findings := []*Finding{
|
||||
nil,
|
||||
{Severity: FindingSeverityCritical},
|
||||
{Severity: FindingSeverityWarning},
|
||||
{Severity: FindingSeverityWatch},
|
||||
{Severity: FindingSeverityInfo},
|
||||
}
|
||||
counts := intel.countFindings(findings)
|
||||
if counts.Total != 4 {
|
||||
t.Errorf("expected total 4, got %d", counts.Total)
|
||||
}
|
||||
if counts.Critical != 1 || counts.Warning != 1 || counts.Watch != 1 || counts.Info != 1 {
|
||||
t.Errorf("unexpected counts: %+v", counts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_getTopFindings(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
base := time.Now()
|
||||
findings := []*Finding{
|
||||
{Severity: FindingSeverityWarning, DetectedAt: base.Add(-2 * time.Hour)},
|
||||
{Severity: FindingSeverityCritical, DetectedAt: base.Add(-3 * time.Hour), Title: "older critical"},
|
||||
{Severity: FindingSeverityCritical, DetectedAt: base.Add(-1 * time.Hour), Title: "newer critical"},
|
||||
{Severity: FindingSeverityWatch, DetectedAt: base.Add(-30 * time.Minute)},
|
||||
}
|
||||
|
||||
top := intel.getTopFindings(findings, 3)
|
||||
if len(top) != 3 {
|
||||
t.Fatalf("expected 3 findings, got %d", len(top))
|
||||
}
|
||||
if top[0].Title != "newer critical" || top[1].Title != "older critical" {
|
||||
t.Errorf("unexpected ordering: %q, %q", top[0].Title, top[1].Title)
|
||||
}
|
||||
if top[2].Severity != FindingSeverityWarning {
|
||||
t.Errorf("expected warning in third position, got %s", top[2].Severity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_getTopFindings_Empty(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
if got := intel.getTopFindings(nil, 5); got != nil {
|
||||
t.Errorf("expected nil for empty findings, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_getLearningStats(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
knowledgeStore, err := knowledge.NewStore(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore: %v", err)
|
||||
}
|
||||
knowledgeStore.SaveNote("vm-1", "vm-1", "vm", "general", "Note", "Content")
|
||||
knowledgeStore.SaveNote("vm-2", "vm-2", "vm", "general", "Note", "Content")
|
||||
|
||||
baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
|
||||
baselineStore.Learn("vm-1", "vm", "cpu", []baseline.MetricPoint{{Value: 10}})
|
||||
|
||||
patternDetector := patterns.NewDetector(patterns.DetectorConfig{
|
||||
MinOccurrences: 2,
|
||||
PatternWindow: 48 * time.Hour,
|
||||
PredictionLimit: 30 * 24 * time.Hour,
|
||||
})
|
||||
patternStart := time.Now().Add(-90 * time.Minute)
|
||||
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-1", EventType: patterns.EventHighCPU, Timestamp: patternStart})
|
||||
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-1", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
|
||||
|
||||
correlationDetector := correlation.NewDetector(correlation.Config{
|
||||
MinOccurrences: 1,
|
||||
CorrelationWindow: 2 * time.Hour,
|
||||
RetentionWindow: 24 * time.Hour,
|
||||
})
|
||||
corrStart := time.Now().Add(-30 * time.Minute)
|
||||
for i := 0; i < 2; i++ {
|
||||
base := corrStart.Add(time.Duration(i) * 10 * time.Minute)
|
||||
correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-a", ResourceName: "node-a", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: base})
|
||||
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-b", ResourceName: "vm-b", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: base.Add(1 * time.Minute)})
|
||||
}
|
||||
|
||||
incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 5})
|
||||
|
||||
intel.SetSubsystems(nil, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil)
|
||||
|
||||
stats := intel.getLearningStats()
|
||||
if stats.ResourcesWithKnowledge != 2 {
|
||||
t.Errorf("expected 2 resources with knowledge, got %d", stats.ResourcesWithKnowledge)
|
||||
}
|
||||
if stats.TotalNotes != 2 {
|
||||
t.Errorf("expected 2 total notes, got %d", stats.TotalNotes)
|
||||
}
|
||||
if stats.ResourcesWithBaselines != 1 {
|
||||
t.Errorf("expected 1 resource with baseline, got %d", stats.ResourcesWithBaselines)
|
||||
}
|
||||
if stats.PatternsDetected != 1 {
|
||||
t.Errorf("expected 1 pattern, got %d", stats.PatternsDetected)
|
||||
}
|
||||
if stats.CorrelationsLearned != 1 {
|
||||
t.Errorf("expected 1 correlation, got %d", stats.CorrelationsLearned)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_getResourcesAtRisk(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
findings := NewFindingsStore()
|
||||
findings.Add(&Finding{
|
||||
ID: "crit-1",
|
||||
Key: "crit-1",
|
||||
Severity: FindingSeverityCritical,
|
||||
Category: FindingCategoryReliability,
|
||||
ResourceID: "res-critical",
|
||||
ResourceName: "critical-vm",
|
||||
ResourceType: "vm",
|
||||
Title: "Critical outage",
|
||||
DetectedAt: time.Now(),
|
||||
LastSeenAt: time.Now(),
|
||||
Source: "test",
|
||||
})
|
||||
findings.Add(&Finding{
|
||||
ID: "warn-1",
|
||||
Key: "warn-1",
|
||||
Severity: FindingSeverityWarning,
|
||||
Category: FindingCategoryPerformance,
|
||||
ResourceID: "res-warning",
|
||||
ResourceName: "warning-vm",
|
||||
ResourceType: "vm",
|
||||
Title: "Warning issue",
|
||||
DetectedAt: time.Now(),
|
||||
LastSeenAt: time.Now(),
|
||||
Source: "test",
|
||||
})
|
||||
findings.Add(&Finding{
|
||||
ID: "watch-1",
|
||||
Key: "watch-1",
|
||||
Severity: FindingSeverityWatch,
|
||||
Category: FindingCategoryPerformance,
|
||||
ResourceID: "res-warning",
|
||||
ResourceName: "warning-vm",
|
||||
ResourceType: "vm",
|
||||
Title: "Watch issue",
|
||||
DetectedAt: time.Now(),
|
||||
LastSeenAt: time.Now(),
|
||||
Source: "test",
|
||||
})
|
||||
findings.Add(&Finding{
|
||||
ID: "info-1",
|
||||
Key: "info-1",
|
||||
Severity: FindingSeverityInfo,
|
||||
Category: FindingCategoryPerformance,
|
||||
ResourceID: "res-warning",
|
||||
ResourceName: "warning-vm",
|
||||
ResourceType: "vm",
|
||||
Title: "Info issue",
|
||||
DetectedAt: time.Now(),
|
||||
LastSeenAt: time.Now(),
|
||||
Source: "test",
|
||||
})
|
||||
|
||||
intel.SetSubsystems(findings, nil, nil, nil, nil, nil, nil, nil)
|
||||
risks := intel.getResourcesAtRisk(1)
|
||||
if len(risks) != 1 {
|
||||
t.Fatalf("expected 1 risk, got %d", len(risks))
|
||||
}
|
||||
if risks[0].ResourceID != "res-critical" {
|
||||
t.Errorf("expected critical resource first, got %s", risks[0].ResourceID)
|
||||
}
|
||||
if risks[0].TopIssue != "Critical outage" {
|
||||
t.Errorf("expected top issue to be critical, got %s", risks[0].TopIssue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_calculateResourceHealth_ClampAndSeverities(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
|
||||
resourceIntel := &ResourceIntelligence{
|
||||
ResourceID: "test-vm",
|
||||
ActiveFindings: []*Finding{
|
||||
nil,
|
||||
{Severity: FindingSeverityCritical, Title: "crit-1"},
|
||||
{Severity: FindingSeverityCritical, Title: "crit-2"},
|
||||
{Severity: FindingSeverityCritical, Title: "crit-3"},
|
||||
{Severity: FindingSeverityCritical, Title: "crit-4"},
|
||||
{Severity: FindingSeverityWatch, Title: "watch"},
|
||||
{Severity: FindingSeverityInfo, Title: "info"},
|
||||
},
|
||||
Anomalies: []AnomalyReport{
|
||||
{Metric: "cpu", Severity: baseline.AnomalyHigh, Description: "high"},
|
||||
{Metric: "disk", Severity: baseline.AnomalyLow, Description: "low"},
|
||||
},
|
||||
}
|
||||
|
||||
health := intel.calculateResourceHealth(resourceIntel)
|
||||
if health.Score != 0 {
|
||||
t.Errorf("expected score clamped to 0, got %f", health.Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_calculateOverallHealth_Clamps(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
var predictions []patterns.FailurePrediction
|
||||
for i := 0; i < 10; i++ {
|
||||
predictions = append(predictions, patterns.FailurePrediction{EventType: patterns.EventHighCPU, DaysUntil: 1, Confidence: 0.9})
|
||||
}
|
||||
negative := intel.calculateOverallHealth(&IntelligenceSummary{
|
||||
FindingsCount: FindingsCounts{Critical: 10, Warning: 10},
|
||||
UpcomingRisks: predictions,
|
||||
})
|
||||
if negative.Score != 0 {
|
||||
t.Errorf("expected score clamped to 0, got %f", negative.Score)
|
||||
}
|
||||
|
||||
positive := intel.calculateOverallHealth(&IntelligenceSummary{
|
||||
Learning: LearningStats{ResourcesWithKnowledge: 10},
|
||||
})
|
||||
if positive.Score != 100 {
|
||||
t.Errorf("expected score clamped to 100, got %f", positive.Score)
|
||||
}
|
||||
if len(positive.Factors) == 0 {
|
||||
t.Error("expected learning factor for positive health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_generateHealthPrediction_Branches(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
|
||||
gradeA := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeA}, &IntelligenceSummary{})
|
||||
if !strings.Contains(gradeA, "healthy") {
|
||||
t.Errorf("expected healthy prediction, got %q", gradeA)
|
||||
}
|
||||
|
||||
critical := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeB}, &IntelligenceSummary{
|
||||
FindingsCount: FindingsCounts{Critical: 2},
|
||||
})
|
||||
if !strings.Contains(critical, "Immediate attention") {
|
||||
t.Errorf("expected critical prediction, got %q", critical)
|
||||
}
|
||||
|
||||
risk := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeB}, &IntelligenceSummary{
|
||||
UpcomingRisks: []patterns.FailurePrediction{{EventType: patterns.EventHighCPU, DaysUntil: 2, Confidence: 0.8}},
|
||||
})
|
||||
if !strings.Contains(risk, "Predicted") {
|
||||
t.Errorf("expected prediction text, got %q", risk)
|
||||
}
|
||||
|
||||
stable := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeC}, &IntelligenceSummary{})
|
||||
if !strings.Contains(stable, "stable") {
|
||||
t.Errorf("expected stable prediction, got %q", stable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_FormatContext_AllSubsystems(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
knowledgeStore, _ := knowledge.NewStore(t.TempDir())
|
||||
knowledgeStore.SaveNote("vm-ctx", "vm-ctx", "vm", "general", "Note", "Content")
|
||||
|
||||
baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
|
||||
baselineStore.Learn("vm-ctx", "vm", "cpu", []baseline.MetricPoint{{Value: 10}})
|
||||
|
||||
patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
|
||||
patternStart := time.Now().Add(-90 * time.Minute)
|
||||
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-ctx", EventType: patterns.EventHighCPU, Timestamp: patternStart})
|
||||
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-ctx", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
|
||||
|
||||
correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
|
||||
corrBase := time.Now().Add(-30 * time.Minute)
|
||||
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-ctx", ResourceName: "vm-ctx", ResourceType: "vm", EventType: correlation.EventHighCPU, Timestamp: corrBase})
|
||||
correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-ctx", ResourceName: "node-ctx", ResourceType: "node", EventType: correlation.EventRestart, Timestamp: corrBase.Add(1 * time.Minute)})
|
||||
|
||||
incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10})
|
||||
incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-ctx", ResourceID: "vm-ctx", ResourceName: "vm-ctx", Type: "cpu", StartTime: time.Now()})
|
||||
|
||||
intel.SetSubsystems(nil, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil)
|
||||
ctx := intel.FormatContext("vm-ctx")
|
||||
if !strings.Contains(ctx, "Learned Baselines") {
|
||||
t.Fatalf("expected baseline context, got %q", ctx)
|
||||
}
|
||||
if !strings.Contains(ctx, "Failure Predictions") {
|
||||
t.Fatalf("expected predictions context, got %q", ctx)
|
||||
}
|
||||
if !strings.Contains(ctx, "Resource Correlations") {
|
||||
t.Fatalf("expected correlations context, got %q", ctx)
|
||||
}
|
||||
if !strings.Contains(ctx, "Incident Memory") {
|
||||
t.Fatalf("expected incident context, got %q", ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_FormatGlobalContext_AllSubsystems(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
knowledgeStore, _ := knowledge.NewStore(t.TempDir())
|
||||
knowledgeStore.SaveNote("vm-global", "vm-global", "vm", "general", "Note", "Content")
|
||||
|
||||
incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10})
|
||||
incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-global", ResourceID: "vm-global", ResourceName: "vm-global", Type: "cpu", StartTime: time.Now()})
|
||||
|
||||
correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
|
||||
corrBase := time.Now().Add(-30 * time.Minute)
|
||||
for i := 0; i < 2; i++ {
|
||||
base := corrBase.Add(time.Duration(i) * 10 * time.Minute)
|
||||
correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-a", ResourceName: "node-a", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: base})
|
||||
correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-global", ResourceName: "vm-global", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: base.Add(1 * time.Minute)})
|
||||
}
|
||||
|
||||
patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
|
||||
patternStart := time.Now().Add(-90 * time.Minute)
|
||||
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-global", EventType: patterns.EventHighCPU, Timestamp: patternStart})
|
||||
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-global", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
|
||||
|
||||
intel.SetSubsystems(nil, patternDetector, correlationDetector, nil, incidentStore, knowledgeStore, nil, nil)
|
||||
ctx := intel.FormatGlobalContext()
|
||||
if ctx == "" {
|
||||
t.Fatal("expected non-empty global context")
|
||||
}
|
||||
if !strings.Contains(ctx, "Resource Correlations") {
|
||||
t.Fatalf("expected correlations in global context, got %q", ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntelligence_GetSummary_WithSubsystems(t *testing.T) {
|
||||
intel := NewIntelligence(IntelligenceConfig{})
|
||||
findings := NewFindingsStore()
|
||||
findings.Add(&Finding{
|
||||
ID: "f1",
|
||||
Key: "f1",
|
||||
Severity: FindingSeverityWarning,
|
||||
Category: FindingCategoryPerformance,
|
||||
ResourceID: "vm-sum",
|
||||
ResourceName: "vm-sum",
|
||||
ResourceType: "vm",
|
||||
Title: "Warning",
|
||||
DetectedAt: time.Now(),
|
||||
LastSeenAt: time.Now(),
|
||||
Source: "test",
|
||||
})
|
||||
|
||||
patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
|
||||
patternStart := time.Now().Add(-90 * time.Minute)
|
||||
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-sum", EventType: patterns.EventHighCPU, Timestamp: patternStart})
|
||||
patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-sum", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
|
||||
|
||||
changes := memory.NewChangeDetector(memory.ChangeDetectorConfig{MaxChanges: 10})
|
||||
changes.DetectChanges([]memory.ResourceSnapshot{{ID: "vm-sum", Name: "vm-sum", Type: "vm", Status: "running", SnapshotTime: time.Now()}})
|
||||
|
||||
remediations := memory.NewRemediationLog(memory.RemediationLogConfig{MaxRecords: 10})
|
||||
remediations.Log(memory.RemediationRecord{ResourceID: "vm-sum", Problem: "cpu", Action: "restart", Outcome: memory.OutcomeResolved})
|
||||
|
||||
intel.SetSubsystems(findings, patternDetector, nil, nil, nil, nil, changes, remediations)
|
||||
|
||||
summary := intel.GetSummary()
|
||||
if summary.FindingsCount.Total == 0 {
|
||||
t.Error("expected findings in summary")
|
||||
}
|
||||
if summary.PredictionsCount == 0 {
|
||||
t.Error("expected predictions in summary")
|
||||
}
|
||||
if len(summary.UpcomingRisks) == 0 {
|
||||
t.Error("expected upcoming risks in summary")
|
||||
}
|
||||
if summary.RecentChangesCount == 0 {
|
||||
t.Error("expected recent changes in summary")
|
||||
}
|
||||
if len(summary.RecentRemediations) == 0 {
|
||||
t.Error("expected recent remediations in summary")
|
||||
}
|
||||
if len(summary.ResourcesAtRisk) == 0 {
|
||||
t.Error("expected resources at risk in summary")
|
||||
}
|
||||
}

func TestIntelligence_GetResourceIntelligence_WithAllSubsystems(t *testing.T) {
	intel := NewIntelligence(IntelligenceConfig{})
	findings := NewFindingsStore()
	findings.Add(&Finding{
		ID:           "f1",
		Key:          "f1",
		Severity:     FindingSeverityWarning,
		Category:     FindingCategoryPerformance,
		ResourceID:   "vm-intel",
		ResourceName: "vm-intel",
		ResourceType: "vm",
		Title:        "Warning",
		DetectedAt:   time.Now(),
		LastSeenAt:   time.Now(),
		Source:       "test",
	})

	patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
	patternStart := time.Now().Add(-90 * time.Minute)
	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-intel", EventType: patterns.EventHighCPU, Timestamp: patternStart})
	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-intel", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})

	correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
	corrBase := time.Now().Add(-30 * time.Minute)
	correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-intel", ResourceName: "node-intel", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: corrBase})
	correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-intel", ResourceName: "vm-intel", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: corrBase.Add(1 * time.Minute)})
	correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-intel", ResourceName: "vm-intel", ResourceType: "vm", EventType: correlation.EventHighCPU, Timestamp: corrBase.Add(2 * time.Minute)})
	correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-child", ResourceName: "vm-child", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: corrBase.Add(3 * time.Minute)})

	baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
	baselineStore.Learn("vm-intel", "vm", "cpu", []baseline.MetricPoint{{Value: 10}})

	incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10})
	incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-intel", ResourceID: "vm-intel", ResourceName: "vm-intel", Type: "cpu", StartTime: time.Now()})

	knowledgeStore, _ := knowledge.NewStore(t.TempDir())
	knowledgeStore.SaveNote("vm-intel", "vm-intel", "vm", "general", "Note", "Content")

	intel.SetSubsystems(findings, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil)
	res := intel.GetResourceIntelligence("vm-intel")
	if len(res.ActiveFindings) == 0 {
		t.Error("expected active findings")
	}
	if len(res.Predictions) == 0 {
		t.Error("expected predictions")
	}
	if len(res.Correlations) == 0 {
		t.Error("expected correlations")
	}
	if len(res.Dependents) == 0 || len(res.Dependencies) == 0 {
		t.Error("expected dependencies and dependents")
	}
	if len(res.Baselines) == 0 {
		t.Error("expected baselines")
	}
	if len(res.RecentIncidents) == 0 {
		t.Error("expected incidents")
	}
	if res.Knowledge == nil || res.NoteCount == 0 {
		t.Error("expected knowledge")
	}
}

func TestIntelligence_GetResourceIntelligence_KnowledgeFallback(t *testing.T) {
	intel := NewIntelligence(IntelligenceConfig{})
	knowledgeStore, _ := knowledge.NewStore(t.TempDir())
	knowledgeStore.SaveNote("vm-know", "knowledge-vm", "vm", "general", "Note", "Content")

	intel.SetSubsystems(nil, nil, nil, nil, nil, knowledgeStore, nil, nil)
	res := intel.GetResourceIntelligence("vm-know")
	if res.ResourceName != "knowledge-vm" {
		t.Errorf("expected resource name from knowledge, got %q", res.ResourceName)
	}
	if res.ResourceType != "vm" {
		t.Errorf("expected resource type from knowledge, got %q", res.ResourceType)
	}
}
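Both tests above exercise the same design seam: every argument to SetSubsystems is optional, so GetResourceIntelligence has to nil-check each subsystem before consulting it (the fallback test proves the knowledge store alone is enough to resolve a resource's name and type). A minimal sketch of that nil-tolerant aggregation style, with hypothetical field and method names:

// Sketch only; the field name and the accessor on the findings store are assumptions.
func (i *Intelligence) activeFindings(resourceID string) []*Finding {
	if i.findings == nil { // subsystem not wired up in this test or deployment
		return nil
	}
	return i.findings.ForResource(resourceID)
}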
@@ -167,9 +167,6 @@ func (s *IncidentStore) RecordAlertAcknowledged(alert *alerts.Alert, user string
	defer s.mu.Unlock()

	incident := s.ensureIncidentForAlertLocked(alert)
	if incident == nil {
		return
	}

	incident.Acknowledged = true
	if alert.AckTime != nil {
@@ -198,9 +195,6 @@ func (s *IncidentStore) RecordAlertUnacknowledged(alert *alerts.Alert, user stri
	defer s.mu.Unlock()

	incident := s.ensureIncidentForAlertLocked(alert)
	if incident == nil {
		return
	}

	incident.Acknowledged = false
	incident.AckTime = nil
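Both hunks trim these methods down so that everything funnels through ensureIncidentForAlertLocked, with the nil guard as the single escape hatch. The shape they now share could be factored as a helper like this (a sketch; the mutex and Incident type are taken from the surrounding code, the helper itself is hypothetical):

func (s *IncidentStore) withIncident(alert *alerts.Alert, mutate func(*Incident)) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// ensureIncidentForAlertLocked returns nil when no incident can be
	// created for this alert, so mutate only ever runs on a real incident.
	if incident := s.ensureIncidentForAlertLocked(alert); incident != nil {
		mutate(incident)
	}
}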
1128 internal/ai/memory/memory_coverage_test.go (new file)
File diff suppressed because it is too large
@@ -52,8 +52,8 @@ type mockThresholdProvider struct {
	storage float64
}

func (m *mockThresholdProvider) GetNodeCPUThreshold() float64     { return m.nodeCPU }
func (m *mockThresholdProvider) GetNodeMemoryThreshold() float64  { return m.nodeMemory }
func (m *mockThresholdProvider) GetGuestCPUThreshold() float64    { return 0 }
func (m *mockThresholdProvider) GetGuestMemoryThreshold() float64 { return m.guestMem }
func (m *mockThresholdProvider) GetGuestDiskThreshold() float64   { return m.guestDisk }
@@ -61,11 +61,17 @@ func (m *mockThresholdProvider) GetStorageThreshold() float64 { return m.sto

type mockResourceProvider struct {
	ResourceProvider
	getAllFunc            func() []resources.Resource
	getStatsFunc          func() resources.StoreStats
	getSummaryFunc        func() resources.ResourceSummary
	getInfrastructureFunc func() []resources.Resource
	getWorkloadsFunc      func() []resources.Resource
	getByTypeFunc         func(t resources.ResourceType) []resources.Resource
	getTopCPUFunc         func(limit int, types []resources.ResourceType) []resources.Resource
	getTopMemoryFunc      func(limit int, types []resources.ResourceType) []resources.Resource
	getTopDiskFunc        func(limit int, types []resources.ResourceType) []resources.Resource
	getRelatedFunc        func(resourceID string) map[string][]resources.Resource
	findContainerHostFunc func(containerNameOrID string) string
}

func (m *mockResourceProvider) GetAll() []resources.Resource {
@@ -99,23 +105,46 @@ func (m *mockResourceProvider) GetWorkloads() []resources.Resource {
	return nil
}
func (m *mockResourceProvider) GetType(t resources.ResourceType) []resources.Resource { return nil }
func (m *mockResourceProvider) GetByType(t resources.ResourceType) []resources.Resource {
	if m.getByTypeFunc != nil {
		return m.getByTypeFunc(t)
	}
	return nil
}
func (m *mockResourceProvider) GetTopByCPU(limit int, types []resources.ResourceType) []resources.Resource {
	if m.getTopCPUFunc != nil {
		return m.getTopCPUFunc(limit, types)
	}
	return nil
}
func (m *mockResourceProvider) GetTopByMemory(limit int, types []resources.ResourceType) []resources.Resource {
	if m.getTopMemoryFunc != nil {
		return m.getTopMemoryFunc(limit, types)
	}
	return nil
}
func (m *mockResourceProvider) GetTopByDisk(limit int, types []resources.ResourceType) []resources.Resource {
	if m.getTopDiskFunc != nil {
		return m.getTopDiskFunc(limit, types)
	}
	return nil
}
func (m *mockResourceProvider) GetRelated(resourceID string) map[string][]resources.Resource {
	if m.getRelatedFunc != nil {
		return m.getRelatedFunc(resourceID)
	}
	return nil
}
func (m *mockResourceProvider) FindContainerHost(containerNameOrID string) string { return "" }
func (m *mockResourceProvider) FindContainerHost(containerNameOrID string) string {
	if m.findContainerHostFunc != nil {
		return m.findContainerHostFunc(containerNameOrID)
	}
	return ""
}

type mockAgentServer struct {
	agents      []agentexec.ConnectedAgent
	executeFunc func(ctx context.Context, agentID string, cmd agentexec.ExecuteCommandPayload) (*agentexec.CommandResultPayload, error)
}
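The mocks above follow the function-field pattern: each interface method delegates to an optional closure and otherwise returns a zero value, so a test only wires up the calls it actually exercises. Typical per-test usage looks like:

rp := &mockResourceProvider{
	getByTypeFunc: func(rt resources.ResourceType) []resources.Resource {
		// Only type lookups matter in this test; every other method
		// silently falls back to its nil/empty default.
		return []resources.Resource{{ID: "vm-1", Type: rt}}
	},
}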

func (m *mockAgentServer) GetConnectedAgents() []agentexec.ConnectedAgent {
@@ -284,10 +284,6 @@ func (s *Service) buildUnifiedResourceContext() string {
		}
	}

	if len(sections) == 0 {
		return ""
	}

	result := "\n\n" + strings.Join(sections, "\n")

	// Limit context size
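The hunk ends right at the size cap. From the truncation test added later in this commit (the marker string, plus a final length just above 50,000), the capping step presumably resembles the sketch below; the exact constant and slicing point are assumptions:

const maxContextLen = 50000 // assumed; the test only checks len(got) > 50000

if len(result) > maxContextLen {
	// Keep the first maxContextLen bytes and flag the cut for the model.
	result = result[:maxContextLen] + "\n[... Context truncated ...]"
}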
263 internal/ai/resource_context_test.go (new file)
@@ -0,0 +1,263 @@
package ai

import (
	"strings"
	"testing"

	"github.com/rcourtman/pulse-go-rewrite/internal/agentexec"
	"github.com/rcourtman/pulse-go-rewrite/internal/resources"
)

func TestBuildUnifiedResourceContext_NilProvider(t *testing.T) {
	s := &Service{}
	if got := s.buildUnifiedResourceContext(); got != "" {
		t.Errorf("expected empty context, got %q", got)
	}
}

func TestBuildUnifiedResourceContext_FullContext(t *testing.T) {
	nodeWithAgent := resources.Resource{
		ID:           "node-1",
		Name:         "delly",
		Type:         resources.ResourceTypeNode,
		PlatformType: resources.PlatformProxmoxPVE,
		ClusterID:    "cluster-a",
		Status:       resources.StatusOnline,
		CPU:          metricValue(12.3),
		Memory:       metricValue(45.6),
	}
	nodeNoAgent := resources.Resource{
		ID:           "node-2",
		Name:         "minipc",
		Type:         resources.ResourceTypeNode,
		PlatformType: resources.PlatformProxmoxPVE,
		Status:       resources.StatusDegraded,
	}
	dockerNode := resources.Resource{
		ID:           "dock-node",
		Name:         "dock-node",
		Type:         resources.ResourceTypeNode,
		PlatformType: resources.PlatformDocker,
		Status:       resources.StatusOnline,
	}
	host := resources.Resource{
		ID:           "host-1",
		Name:         "barehost",
		Type:         resources.ResourceTypeHost,
		PlatformType: resources.PlatformHostAgent,
		Status:       resources.StatusOnline,
		Identity:     &resources.ResourceIdentity{IPs: []string{"192.168.1.10"}},
		CPU:          metricValue(5.0),
		Memory:       metricValue(10.0),
	}
	dockerHost := resources.Resource{
		ID:           "docker-1",
		Name:         "dockhost",
		Type:         resources.ResourceTypeDockerHost,
		PlatformType: resources.PlatformDocker,
		Status:       resources.StatusRunning,
	}

	vm := resources.Resource{
		ID:           "vm-100",
		Name:         "web-vm",
		Type:         resources.ResourceTypeVM,
		ParentID:     "node-1",
		PlatformID:   "100",
		Status:       resources.StatusRunning,
		Identity:     &resources.ResourceIdentity{IPs: []string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}},
		CPU:          metricValue(65.4),
		Memory:       metricValue(70.2),
	}
	vm.Alerts = []resources.ResourceAlert{
		{ID: "alert-1", Message: "CPU high", Level: "critical"},
	}

	ct := resources.Resource{
		ID:         "ct-200",
		Name:       "db-ct",
		Type:       resources.ResourceTypeContainer,
		ParentID:   "node-1",
		PlatformID: "200",
		Status:     resources.StatusStopped,
	}
	dockerContainer := resources.Resource{
		ID:       "dock-300",
		Name:     "redis",
		Type:     resources.ResourceTypeDockerContainer,
		ParentID: "docker-1",
		Status:   resources.StatusRunning,
		Disk:     metricValue(70.0),
	}
	dockerStopped := resources.Resource{
		ID:       "dock-301",
		Name:     "cache",
		Type:     resources.ResourceTypeDockerContainer,
		ParentID: "docker-1",
		Status:   resources.StatusStopped,
	}
	unknownParent := resources.Resource{
		ID:         "vm-999",
		Name:       "mystery",
		Type:       resources.ResourceTypeVM,
		ParentID:   "unknown-parent",
		PlatformID: "999",
		Status:     resources.StatusRunning,
	}
	orphan := resources.Resource{
		ID:       "orphan-1",
		Name:     "orphan",
		Type:     resources.ResourceTypeContainer,
		Status:   resources.StatusRunning,
		Identity: &resources.ResourceIdentity{IPs: []string{"172.16.0.5"}},
	}

	infrastructure := []resources.Resource{nodeWithAgent, nodeNoAgent, host, dockerHost, dockerNode}
	workloads := []resources.Resource{vm, ct, dockerContainer, dockerStopped, unknownParent, orphan}
	all := append(append([]resources.Resource{}, infrastructure...), workloads...)

	stats := resources.StoreStats{
		TotalResources: len(all),
		ByType: map[resources.ResourceType]int{
			resources.ResourceTypeNode:            3,
			resources.ResourceTypeHost:            1,
			resources.ResourceTypeDockerHost:      1,
			resources.ResourceTypeVM:              2,
			resources.ResourceTypeContainer:       2,
			resources.ResourceTypeDockerContainer: 2,
		},
	}

	summary := resources.ResourceSummary{
		TotalResources: len(all),
		Healthy:        6,
		Degraded:       1,
		Offline:        4,
		WithAlerts:     1,
		ByType: map[resources.ResourceType]resources.TypeSummary{
			resources.ResourceTypeVM:        {Count: 2, AvgCPUPercent: 40.0, AvgMemoryPercent: 55.0},
			resources.ResourceTypeContainer: {Count: 2},
		},
	}

	mockRP := &mockResourceProvider{
		getStatsFunc: func() resources.StoreStats {
			return stats
		},
		getInfrastructureFunc: func() []resources.Resource {
			return infrastructure
		},
		getWorkloadsFunc: func() []resources.Resource {
			return workloads
		},
		getAllFunc: func() []resources.Resource {
			return all
		},
		getSummaryFunc: func() resources.ResourceSummary {
			return summary
		},
		getTopCPUFunc: func(limit int, types []resources.ResourceType) []resources.Resource {
			return []resources.Resource{vm}
		},
		getTopMemoryFunc: func(limit int, types []resources.ResourceType) []resources.Resource {
			return []resources.Resource{host}
		},
		getTopDiskFunc: func(limit int, types []resources.ResourceType) []resources.Resource {
			return []resources.Resource{dockerContainer}
		},
	}

	s := &Service{resourceProvider: mockRP}
	s.agentServer = &mockAgentServer{
		agents: []agentexec.ConnectedAgent{
			{AgentID: "agent-1", Hostname: "delly"},
		},
	}

	got := s.buildUnifiedResourceContext()
	if got == "" {
		t.Fatal("expected non-empty context")
	}

	assertContains := func(substr string) {
		t.Helper()
		if !strings.Contains(got, substr) {
			t.Fatalf("expected context to contain %q", substr)
		}
	}

	assertContains("## Unified Infrastructure View")
	assertContains("Total resources: 11 (Infrastructure: 5, Workloads: 6)")
	assertContains("Proxmox VE Nodes")
	assertContains("HAS AGENT")
	assertContains("NO AGENT")
	assertContains("cluster: cluster-a")
	assertContains("Standalone Hosts")
	assertContains("192.168.1.10")
	assertContains("Docker/Podman Hosts")
	assertContains("1/2 containers running")
	assertContains("Workloads (VMs & Containers)")
	assertContains("On delly")
	assertContains("On unknown-parent")
	assertContains("Other workloads")
	assertContains("10.0.0.1, 10.0.0.2")
	assertContains("Resources with Active Alerts")
	assertContains("CPU high")
	assertContains("Infrastructure Summary")
	assertContains("Resources with alerts: 1")
	assertContains("Average utilization by type")
	assertContains("Top CPU Consumers")
	assertContains("Top Memory Consumers")
	assertContains("Top Disk Usage")
}

func TestBuildUnifiedResourceContext_TruncatesLargeContext(t *testing.T) {
	largeName := strings.Repeat("a", 60000)

	node := resources.Resource{
		ID:           "node-1",
		Name:         "node-1",
		DisplayName:  largeName,
		Type:         resources.ResourceTypeNode,
		PlatformType: resources.PlatformProxmoxPVE,
		Status:       resources.StatusOnline,
	}

	stats := resources.StoreStats{
		TotalResources: 1,
		ByType: map[resources.ResourceType]int{
			resources.ResourceTypeNode: 1,
		},
	}

	mockRP := &mockResourceProvider{
		getStatsFunc: func() resources.StoreStats {
			return stats
		},
		getInfrastructureFunc: func() []resources.Resource {
			return []resources.Resource{node}
		},
		getWorkloadsFunc: func() []resources.Resource {
			return nil
		},
		getAllFunc: func() []resources.Resource {
			return []resources.Resource{node}
		},
		getSummaryFunc: func() resources.ResourceSummary {
			return resources.ResourceSummary{TotalResources: 1}
		},
	}

	s := &Service{resourceProvider: mockRP}
	got := s.buildUnifiedResourceContext()
	if !strings.Contains(got, "[... Context truncated ...]") {
		t.Fatal("expected context to be truncated")
	}
	if len(got) <= 50000 {
		t.Fatalf("expected truncated context length > 50000, got %d", len(got))
	}
}

func metricValue(current float64) *resources.MetricValue {
	return &resources.MetricValue{Current: current}
}
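A small note on the assertContains closure in the full-context test: the t.Helper() call makes a failure report the line of the failing assertContains invocation rather than a line inside the closure, which is what keeps a twenty-plus-assertion test debuggable. The same shape works for negative assertions:

assertMissing := func(substr string) {
	t.Helper() // failures point at the caller, not this closure
	if strings.Contains(got, substr) {
		t.Fatalf("expected context to omit %q", substr)
	}
}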
@@ -1,9 +1,14 @@
package ai

import (
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/rcourtman/pulse-go-rewrite/internal/agentexec"
	"github.com/rcourtman/pulse-go-rewrite/internal/config"
	"github.com/rcourtman/pulse-go-rewrite/internal/models"
)

func TestExtractVMIDFromTargetID(t *testing.T) {
@@ -16,21 +21,21 @@ func TestExtractVMIDFromTargetID(t *testing.T) {
		{"plain vmid", "106", 106},
		{"node-vmid", "minipc-106", 106},
		{"instance-node-vmid", "delly-minipc-106", 106},

		// Edge cases with hyphenated names
		{"hyphenated-node-vmid", "pve-node-01-106", 106},
		{"hyphenated-instance-node-vmid", "my-cluster-pve-node-01-106", 106},

		// Type prefixes
		{"lxc prefix", "lxc-106", 106},
		{"vm prefix", "vm-106", 106},
		{"ct prefix", "ct-106", 106},

		// Non-numeric - should return 0
		{"non-numeric", "mycontainer", 0},
		{"no-vmid", "node-name", 0},
		{"empty", "", 0},

		// Large VMIDs (Proxmox uses up to 999999999)
		{"large vmid", "node-999999", 999999},
	}
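Every row in this table is consistent with one rule: take the last hyphen-separated segment and parse it as a non-negative integer, returning 0 on failure. A reconstruction of that rule, inferred from the cases rather than copied from the implementation (strconv would need to be imported):

func extractVMID(targetID string) int {
	parts := strings.Split(targetID, "-")
	vmid, err := strconv.Atoi(parts[len(parts)-1])
	if err != nil || vmid < 0 {
		return 0 // "mycontainer", "node-name", and "" all land here
	}
	return vmid
}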
@@ -53,18 +58,18 @@ func TestRoutingError(t *testing.T) {
			Reason:     "No agent connected to node \"minipc\"",
			Suggestion: "Install pulse-agent on minipc",
		}

		want := "No agent connected to node \"minipc\". Install pulse-agent on minipc"
		if err.Error() != want {
			t.Errorf("Error() = %q, want %q", err.Error(), want)
		}
	})

	t.Run("without suggestion", func(t *testing.T) {
		err := &RoutingError{
			Reason: "No agents connected",
		}

		want := "No agents connected"
		if err.Error() != want {
			t.Errorf("Error() = %q, want %q", err.Error(), want)
@@ -74,22 +79,22 @@

func TestRouteToAgent_NoAgents(t *testing.T) {
	s := &Service{}

	req := ExecuteRequest{
		TargetType: "container",
		TargetID:   "minipc-106",
	}

	_, err := s.routeToAgent(req, "pct exec 106 -- hostname", nil)
	if err == nil {
		t.Error("expected error for no agents, got nil")
	}

	routingErr, ok := err.(*RoutingError)
	if !ok {
		t.Fatalf("expected RoutingError, got %T", err)
	}

	if routingErr.Suggestion == "" {
		t.Error("expected suggestion in error")
	}
@@ -97,19 +102,19 @@ func TestRouteToAgent_NoAgents(t *testing.T) {

func TestRouteToAgent_ExactMatch(t *testing.T) {
	s := &Service{}

	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "delly"},
		{AgentID: "agent-2", Hostname: "minipc"},
		{AgentID: "agent-3", Hostname: "pimox"},
	}

	tests := []struct {
		name         string
		req          ExecuteRequest
		command      string
		wantAgentID  string
		wantHostname string
	}{
		{
			name: "route by context node",
@@ -151,11 +156,11 @@ func TestRouteToAgent_ExactMatch(t *testing.T) {
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.AgentID != tt.wantAgentID {
		t.Errorf("AgentID = %q, want %q", result.AgentID, tt.wantAgentID)
	}

	if result.AgentHostname != tt.wantHostname {
		t.Errorf("AgentHostname = %q, want %q", result.AgentHostname, tt.wantHostname)
	}
@@ -165,28 +170,28 @@ func TestRouteToAgent_ExactMatch(t *testing.T) {

func TestRouteToAgent_NoSubstringMatching(t *testing.T) {
	s := &Service{}

	// Agent "mini" should NOT match node "minipc"
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "mini"},
		{AgentID: "agent-2", Hostname: "pc"},
	}

	req := ExecuteRequest{
		TargetType: "container",
		Context:    map[string]interface{}{"node": "minipc"},
	}

	_, err := s.routeToAgent(req, "hostname", agents)
	if err == nil {
		t.Error("expected error when no exact match, got nil (substring matching may be occurring)")
	}

	routingErr, ok := err.(*RoutingError)
	if !ok {
		t.Fatalf("expected RoutingError, got %T", err)
	}

	if routingErr.TargetNode != "minipc" {
		t.Errorf("TargetNode = %q, want %q", routingErr.TargetNode, "minipc")
	}
@@ -194,21 +199,21 @@ func TestRouteToAgent_NoSubstringMatching(t *testing.T) {

func TestRouteToAgent_CaseInsensitive(t *testing.T) {
	s := &Service{}

	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "MiniPC"},
	}

	req := ExecuteRequest{
		TargetType: "container",
		Context:    map[string]interface{}{"node": "minipc"}, // lowercase
	}

	result, err := s.routeToAgent(req, "hostname", agents)
	if err != nil {
		t.Fatalf("expected case-insensitive match, got error: %v", err)
	}

	if result.AgentID != "agent-1" {
		t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1")
	}
@@ -216,22 +221,22 @@ func TestRouteToAgent_CaseInsensitive(t *testing.T) {

func TestRouteToAgent_HyphenatedNodeNames(t *testing.T) {
	s := &Service{}

	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "pve-node-01"},
		{AgentID: "agent-2", Hostname: "pve-node-02"},
	}

	req := ExecuteRequest{
		TargetType: "container",
		Context:    map[string]interface{}{"node": "pve-node-02"},
	}

	result, err := s.routeToAgent(req, "hostname", agents)
	if err != nil {
		t.Fatalf("unexpected error for hyphenated node names: %v", err)
	}

	if result.AgentID != "agent-2" {
		t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-2")
	}
@@ -239,38 +244,441 @@ func TestRouteToAgent_HyphenatedNodeNames(t *testing.T) {

func TestRouteToAgent_ActionableErrorMessages(t *testing.T) {
	s := &Service{}

	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "delly"},
	}

	req := ExecuteRequest{
		TargetType: "container",
		Context:    map[string]interface{}{"node": "minipc"},
	}

	_, err := s.routeToAgent(req, "hostname", agents)
	if err == nil {
		t.Fatal("expected error, got nil")
	}

	routingErr, ok := err.(*RoutingError)
	if !ok {
		t.Fatalf("expected RoutingError, got %T", err)
	}

	// Error should mention the target node
	if routingErr.TargetNode != "minipc" {
		t.Errorf("TargetNode = %q, want %q", routingErr.TargetNode, "minipc")
	}

	// Error should list available agents
	if len(routingErr.AvailableAgents) == 0 {
		t.Error("expected available agents in error")
	}

	// Error should have actionable suggestion
	if routingErr.Suggestion == "" {
		t.Error("expected suggestion in error message")
	}
}

func TestRoutingError_ForAI(t *testing.T) {
	t.Run("clarification", func(t *testing.T) {
		err := &RoutingError{
			Reason:              "Cannot determine which host should execute this command",
			AvailableAgents:     []string{"delly", "pimox"},
			AskForClarification: true,
		}

		msg := err.ForAI()
		if !strings.Contains(msg, "ROUTING_CLARIFICATION_NEEDED") {
			t.Errorf("expected clarification marker, got %q", msg)
		}
		if !strings.Contains(msg, "delly, pimox") {
			t.Errorf("expected available hosts list, got %q", msg)
		}
		if !strings.Contains(msg, err.Reason) {
			t.Errorf("expected reason to be included, got %q", msg)
		}
	})

	t.Run("fallback", func(t *testing.T) {
		err := &RoutingError{
			Reason:              "No agents connected",
			AskForClarification: true,
		}
		if err.ForAI() != err.Error() {
			t.Errorf("expected ForAI to fall back to Error, got %q", err.ForAI())
		}
	})
}
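The two subtests pin down ForAI's contract: with clarification requested and candidate agents known, the message carries a machine-readable marker plus the comma-joined host list and the original reason; with nothing to offer it degrades to the plain Error() string. A method satisfying exactly those assertions could look like this (only the asserted behavior is certain; the formatting is a guess):

func (e *RoutingError) ForAI() string {
	if e.AskForClarification && len(e.AvailableAgents) > 0 {
		return fmt.Sprintf("ROUTING_CLARIFICATION_NEEDED: %s. Available hosts: %s",
			e.Reason, strings.Join(e.AvailableAgents, ", "))
	}
	return e.Error() // no agents to choose between, so nothing to clarify
}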

func TestRouteToAgent_VMIDRoutingWithInstance(t *testing.T) {
	stateProvider := &mockStateProvider{
		state: models.StateSnapshot{
			Containers: []models.Container{
				{VMID: 106, Node: "node-b", Name: "ct-b", Instance: "cluster-b"},
				{VMID: 106, Node: "node-a", Name: "ct-a", Instance: "cluster-b"},
			},
		},
	}

	s := &Service{stateProvider: stateProvider}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-a", Hostname: "node-a"},
		{AgentID: "agent-b", Hostname: "node-b"},
	}
	req := ExecuteRequest{
		TargetType: "container",
		Context:    map[string]interface{}{"instance": "cluster-b"},
	}

	result, err := s.routeToAgent(req, "pct exec 106 -- hostname", agents)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.AgentID != "agent-b" {
		t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-b")
	}
	if result.RoutingMethod != "vmid_lookup_with_instance" {
		t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "vmid_lookup_with_instance")
	}
}

func TestRouteToAgent_VMIDCollision(t *testing.T) {
	stateProvider := &mockStateProvider{
		state: models.StateSnapshot{
			VMs: []models.VM{
				{VMID: 200, Node: "node-a", Name: "vm-a", Instance: "cluster-a"},
				{VMID: 200, Node: "node-b", Name: "vm-b", Instance: "cluster-b"},
			},
		},
	}

	s := &Service{stateProvider: stateProvider}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-a", Hostname: "node-a"},
		{AgentID: "agent-b", Hostname: "node-b"},
	}
	req := ExecuteRequest{
		TargetType: "vm",
	}

	_, err := s.routeToAgent(req, "pct exec 200 -- hostname", agents)
	if err == nil {
		t.Fatal("expected error for VMID collision, got nil")
	}
	routingErr, ok := err.(*RoutingError)
	if !ok {
		t.Fatalf("expected RoutingError, got %T", err)
	}
	if routingErr.TargetVMID != 200 {
		t.Errorf("TargetVMID = %d, want %d", routingErr.TargetVMID, 200)
	}
	if !strings.Contains(routingErr.Reason, "exists on multiple nodes") {
		t.Errorf("expected collision reason, got %q", routingErr.Reason)
	}
}
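Read together, these two tests show the disambiguation order: duplicate VMIDs are fatal only when no instance hint narrows the candidates, while an instance hint lets the router commit to the first guest in that cluster. In sketch form (helper names hypothetical, reconstructed from the assertions):

matches := guestsWithVMID(state, vmid) // every guest sharing the VMID
if instance, ok := req.Context["instance"].(string); ok && instance != "" {
	// With an instance hint, the first guest in that cluster wins even
	// if the VMID repeats inside it.
	if g := firstInInstance(matches, instance); g != nil {
		return routeToNode(g.Node)
	}
}
if len(matches) > 1 {
	// No hint to break the tie: refuse rather than guess.
	return nil, &RoutingError{
		TargetVMID: vmid,
		Reason:     fmt.Sprintf("VMID %d exists on multiple nodes", vmid),
	}
}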

func TestRouteToAgent_VMIDNotFoundFallsBackToContext(t *testing.T) {
	stateProvider := &mockStateProvider{}
	s := &Service{stateProvider: stateProvider}

	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "minipc"},
	}
	req := ExecuteRequest{
		TargetType: "container",
		Context:    map[string]interface{}{"node": "minipc"},
	}

	result, err := s.routeToAgent(req, "pct exec 106 -- hostname", agents)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.RoutingMethod != "context_node" {
		t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "context_node")
	}
	if len(result.Warnings) != 1 || !strings.Contains(result.Warnings[0], "VMID 106 not found") {
		t.Errorf("expected warning about missing VMID, got %v", result.Warnings)
	}
}

func TestRouteToAgent_ResourceProviderContexts(t *testing.T) {
	s := &Service{}
	mockRP := &mockResourceProvider{
		findContainerHostFunc: func(containerNameOrID string) string {
			if containerNameOrID == "" {
				return ""
			}
			return "rp-host"
		},
	}
	s.resourceProvider = mockRP

	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "rp-host"},
	}

	tests := []struct {
		name string
		key  string
	}{
		{name: "containerName", key: "containerName"},
		{name: "name", key: "name"},
		{name: "guestName", key: "guestName"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := ExecuteRequest{
				TargetType: "container",
				Context:    map[string]interface{}{tt.key: "workload"},
			}

			result, err := s.routeToAgent(req, "docker ps", agents)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if result.RoutingMethod != "resource_provider_lookup" {
				t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "resource_provider_lookup")
			}
		})
	}
}
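The subtest table makes the accepted context spellings explicit: containerName, name, and guestName are treated as synonyms before the router falls back to the resource provider's FindContainerHost lookup. Driving it by hand, any of the three works:

req := ExecuteRequest{
	TargetType: "container",
	// "containerName", "name", and "guestName" resolve identically here.
	Context: map[string]interface{}{"guestName": "redis"},
}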

func TestRouteToAgent_TargetIDLookup(t *testing.T) {
	stateProvider := &mockStateProvider{
		state: models.StateSnapshot{
			VMs: []models.VM{
				{VMID: 222, Node: "node-vm", Name: "vm-222", Instance: "cluster-a"},
			},
		},
	}

	s := &Service{stateProvider: stateProvider}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "node-vm"},
	}
	req := ExecuteRequest{
		TargetType: "vm",
		TargetID:   "vm-222",
		Context:    map[string]interface{}{"instance": "cluster-a"},
	}

	result, err := s.routeToAgent(req, "status", agents)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.RoutingMethod != "target_id_vmid_lookup" {
		t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "target_id_vmid_lookup")
	}
	if result.AgentID != "agent-1" {
		t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1")
	}
}

func TestRouteToAgent_VMIDLookupSingleMatch(t *testing.T) {
	stateProvider := &mockStateProvider{
		state: models.StateSnapshot{
			VMs: []models.VM{
				{VMID: 101, Node: "node-1", Name: "vm-one"},
			},
		},
	}

	s := &Service{stateProvider: stateProvider}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "node-1"},
	}
	req := ExecuteRequest{
		TargetType: "vm",
	}

	result, err := s.routeToAgent(req, "qm start 101", agents)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.RoutingMethod != "vmid_lookup" {
		t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "vmid_lookup")
	}
	if result.AgentID != "agent-1" {
		t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1")
	}
}

func TestRouteToAgent_ClusterPeer(t *testing.T) {
	tmp := t.TempDir()
	persistence := config.NewConfigPersistence(tmp)
	err := persistence.SaveNodesConfig([]config.PVEInstance{
		{
			Name:        "node-a",
			IsCluster:   true,
			ClusterName: "cluster-a",
			ClusterEndpoints: []config.ClusterEndpoint{
				{NodeName: "node-a"},
				{NodeName: "node-b"},
			},
		},
	}, nil, nil)
	if err != nil {
		t.Fatalf("SaveNodesConfig: %v", err)
	}

	s := &Service{persistence: persistence}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-b", Hostname: "node-b"},
	}
	req := ExecuteRequest{
		TargetType: "vm",
		Context:    map[string]interface{}{"node": "node-a"},
	}

	result, err := s.routeToAgent(req, "hostname", agents)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !result.ClusterPeer {
		t.Error("expected cluster peer routing")
	}
	if result.AgentID != "agent-b" {
		t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-b")
	}
}
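Combined with the findClusterPeerAgent cases below, peer routing appears to work in three steps: load the persisted PVE instances, find a cluster whose endpoint list contains the target node, then return the first connected agent whose hostname matches a sibling endpoint. A reconstruction of that walk (hasEndpoint is a hypothetical helper; only the behaviors the tests assert are certain):

func findPeer(instances []config.PVEInstance, node string, agents []agentexec.ConnectedAgent) string {
	for _, inst := range instances {
		if !inst.IsCluster || !hasEndpoint(inst, node) {
			continue
		}
		for _, ep := range inst.ClusterEndpoints {
			for _, a := range agents {
				if strings.EqualFold(a.Hostname, ep.NodeName) {
					return a.AgentID // first connected sibling wins
				}
			}
		}
	}
	return "" // no persistence, load error, standalone node, or no peer online
}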

func TestRouteToAgent_SingleAgentFallbackForHost(t *testing.T) {
	s := &Service{}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "only"},
	}
	req := ExecuteRequest{
		TargetType: "host",
	}

	result, err := s.routeToAgent(req, "uptime", agents)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.RoutingMethod != "single_agent_fallback" {
		t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "single_agent_fallback")
	}
	if len(result.Warnings) != 1 {
		t.Errorf("expected warning for single agent fallback, got %v", result.Warnings)
	}
}

func TestRouteToAgent_AsksForClarification(t *testing.T) {
	s := &Service{}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "delly"},
		{AgentID: "agent-2", Hostname: "pimox"},
	}

	req := ExecuteRequest{
		TargetType: "vm",
	}

	_, err := s.routeToAgent(req, "hostname", agents)
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	routingErr, ok := err.(*RoutingError)
	if !ok {
		t.Fatalf("expected RoutingError, got %T", err)
	}
	if !routingErr.AskForClarification {
		t.Error("expected AskForClarification to be true")
	}
	if len(routingErr.AvailableAgents) != 2 {
		t.Errorf("expected available agents, got %v", routingErr.AvailableAgents)
	}
}

func TestFindClusterPeerAgent_NoPersistence(t *testing.T) {
	s := &Service{}
	if got := s.findClusterPeerAgent("node-a", nil); got != "" {
		t.Errorf("expected empty result, got %q", got)
	}
}

func TestFindClusterPeerAgent_LoadError(t *testing.T) {
	tmp := t.TempDir()
	persistence := config.NewConfigPersistence(tmp)
	nodesPath := filepath.Join(tmp, "nodes.enc")
	if err := os.Mkdir(nodesPath, 0700); err != nil {
		t.Fatalf("Mkdir nodes.enc: %v", err)
	}

	s := &Service{persistence: persistence}
	if got := s.findClusterPeerAgent("node-a", nil); got != "" {
		t.Errorf("expected empty result, got %q", got)
	}
}

func TestFindClusterPeerAgent_NotCluster(t *testing.T) {
	tmp := t.TempDir()
	persistence := config.NewConfigPersistence(tmp)
	err := persistence.SaveNodesConfig([]config.PVEInstance{
		{Name: "node-a"},
	}, nil, nil)
	if err != nil {
		t.Fatalf("SaveNodesConfig: %v", err)
	}

	s := &Service{persistence: persistence}
	if got := s.findClusterPeerAgent("node-a", nil); got != "" {
		t.Errorf("expected empty result, got %q", got)
	}
}

func TestFindClusterPeerAgent_NoAgentMatch(t *testing.T) {
	tmp := t.TempDir()
	persistence := config.NewConfigPersistence(tmp)
	err := persistence.SaveNodesConfig([]config.PVEInstance{
		{
			Name:        "node-a",
			IsCluster:   true,
			ClusterName: "cluster-a",
			ClusterEndpoints: []config.ClusterEndpoint{
				{NodeName: "node-a"},
				{NodeName: "node-b"},
			},
		},
	}, nil, nil)
	if err != nil {
		t.Fatalf("SaveNodesConfig: %v", err)
	}

	s := &Service{persistence: persistence}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-1", Hostname: "node-c"},
	}
	if got := s.findClusterPeerAgent("node-a", agents); got != "" {
		t.Errorf("expected empty result, got %q", got)
	}
}

func TestFindClusterPeerAgent_EndpointMatch(t *testing.T) {
	tmp := t.TempDir()
	persistence := config.NewConfigPersistence(tmp)
	err := persistence.SaveNodesConfig([]config.PVEInstance{
		{
			Name:        "cluster-master",
			IsCluster:   true,
			ClusterName: "cluster-a",
			ClusterEndpoints: []config.ClusterEndpoint{
				{NodeName: "node-a"},
				{NodeName: "node-b"},
			},
		},
	}, nil, nil)
	if err != nil {
		t.Fatalf("SaveNodesConfig: %v", err)
	}

	s := &Service{persistence: persistence}
	agents := []agentexec.ConnectedAgent{
		{AgentID: "agent-b", Hostname: "node-b"},
	}
	if got := s.findClusterPeerAgent("node-a", agents); got != "agent-b" {
		t.Errorf("expected peer agent, got %q", got)
	}
}

@@ -12,24 +12,33 @@ import (
	"time"
)

var commandRunner = func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
	cmd := exec.CommandContext(ctx, name, args...)
	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err := cmd.Run()
	return stdout.Bytes(), stderr.Bytes(), err
}
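commandRunner is the package's single exec seam: because it is a plain variable, the tests later in this commit can stub subprocess execution without interfaces or build tags. The swap itself is just save/replace/restore:

orig := commandRunner
commandRunner = func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
	return []byte(`{"ok":true}`), nil, nil // canned stdout, empty stderr, no error
}
defer func() { commandRunner = orig }()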

// ClusterStatus represents the complete Ceph cluster status as collected by the agent.
type ClusterStatus struct {
	FSID        string        `json:"fsid"`
	Health      HealthStatus  `json:"health"`
	MonMap      MonitorMap    `json:"monMap,omitempty"`
	MgrMap      ManagerMap    `json:"mgrMap,omitempty"`
	OSDMap      OSDMap        `json:"osdMap"`
	PGMap       PGMap         `json:"pgMap"`
	Pools       []Pool        `json:"pools,omitempty"`
	Services    []ServiceInfo `json:"services,omitempty"`
	CollectedAt time.Time     `json:"collectedAt"`
}

// HealthStatus represents Ceph cluster health.
type HealthStatus struct {
	Status  string           `json:"status"` // HEALTH_OK, HEALTH_WARN, HEALTH_ERR
	Checks  map[string]Check `json:"checks,omitempty"`
	Summary []HealthSummary  `json:"summary,omitempty"`
}

// Check represents a health check detail.
@@ -70,28 +79,28 @@ type ManagerMap struct {

// OSDMap represents OSD status summary.
type OSDMap struct {
	Epoch   int `json:"epoch"`
	NumOSDs int `json:"numOsds"`
	NumUp   int `json:"numUp"`
	NumIn   int `json:"numIn"`
	NumDown int `json:"numDown,omitempty"`
	NumOut  int `json:"numOut,omitempty"`
}

// PGMap represents placement group statistics.
type PGMap struct {
	NumPGs           int     `json:"numPgs"`
	BytesTotal       uint64  `json:"bytesTotal"`
	BytesUsed        uint64  `json:"bytesUsed"`
	BytesAvailable   uint64  `json:"bytesAvailable"`
	DataBytes        uint64  `json:"dataBytes,omitempty"`
	UsagePercent     float64 `json:"usagePercent"`
	DegradedRatio    float64 `json:"degradedRatio,omitempty"`
	MisplacedRatio   float64 `json:"misplacedRatio,omitempty"`
	ReadBytesPerSec  uint64  `json:"readBytesPerSec,omitempty"`
	WriteBytesPerSec uint64  `json:"writeBytesPerSec,omitempty"`
	ReadOpsPerSec    uint64  `json:"readOpsPerSec,omitempty"`
	WriteOpsPerSec   uint64  `json:"writeOpsPerSec,omitempty"`
}

// Pool represents a Ceph pool.
@@ -106,16 +115,16 @@ type Pool struct {

// ServiceInfo represents a Ceph service summary.
type ServiceInfo struct {
	Type    string   `json:"type"` // mon, mgr, osd, mds, rgw
	Running int      `json:"running"`
	Total   int      `json:"total"`
	Daemons []string `json:"daemons,omitempty"`
}

// IsAvailable checks if the ceph CLI is available on the system.
func IsAvailable(ctx context.Context) bool {
	cmd := exec.CommandContext(ctx, "which", "ceph")
	return cmd.Run() == nil
	_, _, err := commandRunner(ctx, "which", "ceph")
	return err == nil
}

// Collect gathers Ceph cluster status using the ceph CLI.
@@ -160,17 +169,13 @@ func runCephCommand(ctx context.Context, args ...string) ([]byte, error) {
	cmdCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	cmd := exec.CommandContext(cmdCtx, "ceph", args...)
	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Run(); err != nil {
		return nil, fmt.Errorf("ceph %s failed: %w (stderr: %s)",
			strings.Join(args, " "), err, stderr.String())
	stdout, stderr, err := commandRunner(cmdCtx, "ceph", args...)
	if err != nil {
		return nil, fmt.Errorf("ceph %s failed: %w (stderr: %s)",
			strings.Join(args, " "), err, string(stderr))
	}

	return stdout.Bytes(), nil
	return stdout, nil
}

// parseStatus parses the output of `ceph status --format json`.
@@ -178,8 +183,8 @@ func parseStatus(data []byte) (*ClusterStatus, error) {
	var raw struct {
		FSID   string `json:"fsid"`
		Health struct {
			Status string `json:"status"`
			Checks map[string]struct {
				Severity string `json:"severity"`
				Summary  struct {
					Message string `json:"message"`
@@ -198,10 +203,10 @@ func parseStatus(data []byte) (*ClusterStatus, error) {
			} `json:"mons"`
		} `json:"monmap"`
		MgrMap struct {
			Available  bool   `json:"available"`
			NumActive  int    `json:"num_active_name,omitempty"`
			ActiveName string `json:"active_name"`
			Standbys   []struct {
				Name string `json:"name"`
			} `json:"standbys"`
		} `json:"mgrmap"`
@@ -212,17 +217,17 @@ func parseStatus(data []byte) (*ClusterStatus, error) {
			NumIn int `json:"num_in_osds"`
		} `json:"osdmap"`
		PGMap struct {
			NumPGs           int     `json:"num_pgs"`
			BytesTotal       uint64  `json:"bytes_total"`
			BytesUsed        uint64  `json:"bytes_used"`
			BytesAvail       uint64  `json:"bytes_avail"`
			DataBytes        uint64  `json:"data_bytes"`
			DegradedRatio    float64 `json:"degraded_ratio"`
			MisplacedRatio   float64 `json:"misplaced_ratio"`
			ReadBytesPerSec  uint64  `json:"read_bytes_sec"`
			WriteBytesPerSec uint64  `json:"write_bytes_sec"`
			ReadOpsPerSec    uint64  `json:"read_op_per_sec"`
			WriteOpsPerSec   uint64  `json:"write_op_per_sec"`
		} `json:"pgmap"`
	}

@@ -255,13 +260,13 @@ func parseStatus(data []byte) (*ClusterStatus, error) {
			NumOut: raw.OSDMap.NumOSD - raw.OSDMap.NumIn,
		},
		PGMap: PGMap{
			NumPGs:           raw.PGMap.NumPGs,
			BytesTotal:       raw.PGMap.BytesTotal,
			BytesUsed:        raw.PGMap.BytesUsed,
			BytesAvailable:   raw.PGMap.BytesAvail,
			DataBytes:        raw.PGMap.DataBytes,
			DegradedRatio:    raw.PGMap.DegradedRatio,
			MisplacedRatio:   raw.PGMap.MisplacedRatio,
			ReadBytesPerSec:  raw.PGMap.ReadBytesPerSec,
			WriteBytesPerSec: raw.PGMap.WriteBytesPerSec,
			ReadOpsPerSec:    raw.PGMap.ReadOpsPerSec,

@@ -1,9 +1,32 @@
package ceph

import (
	"context"
	"errors"
	"strings"
	"testing"
)

func withCommandRunner(t *testing.T, fn func(ctx context.Context, name string, args ...string) ([]byte, []byte, error)) {
	t.Helper()
	orig := commandRunner
	commandRunner = fn
	t.Cleanup(func() { commandRunner = orig })
}

func TestCommandRunner_Default(t *testing.T) {
	stdout, stderr, err := commandRunner(context.Background(), "sh", "-c", "true")
	if err != nil {
		t.Fatalf("commandRunner error: %v", err)
	}
	if len(stdout) != 0 {
		t.Fatalf("unexpected stdout: %q", string(stdout))
	}
	if len(stderr) != 0 {
		t.Fatalf("unexpected stderr: %q", string(stderr))
	}
}

func TestParseStatus(t *testing.T) {
	data := []byte(`{
	"fsid":"fsid-123",
@@ -70,6 +93,68 @@ func TestParseStatus(t *testing.T) {
	}
}

func TestIsAvailable(t *testing.T) {
	t.Run("available", func(t *testing.T) {
		withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
			if name != "which" || len(args) != 1 || args[0] != "ceph" {
				t.Fatalf("unexpected command: %s %v", name, args)
			}
			return nil, nil, nil
		})

		if !IsAvailable(context.Background()) {
			t.Fatalf("expected available")
		}
	})

	t.Run("missing", func(t *testing.T) {
		withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
			return nil, nil, errors.New("missing")
		})

		if IsAvailable(context.Background()) {
			t.Fatalf("expected unavailable")
		}
	})
}

func TestRunCephCommand(t *testing.T) {
	withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
		if name != "ceph" {
			t.Fatalf("unexpected command name: %s", name)
		}
		if len(args) < 1 || args[0] != "status" {
			t.Fatalf("unexpected args: %v", args)
		}
		return []byte(`{"ok":true}`), nil, nil
	})

	out, err := runCephCommand(context.Background(), "status", "--format", "json")
	if err != nil {
		t.Fatalf("runCephCommand error: %v", err)
	}
	if string(out) != `{"ok":true}` {
		t.Fatalf("unexpected output: %s", string(out))
	}
}

func TestRunCephCommandError(t *testing.T) {
	withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
		return nil, []byte("bad"), errors.New("boom")
	})

	_, err := runCephCommand(context.Background(), "status", "--format", "json")
	if err == nil {
		t.Fatalf("expected error")
	}
	if !strings.Contains(err.Error(), "ceph status --format json failed") {
		t.Fatalf("unexpected error message: %v", err)
	}
	if !strings.Contains(err.Error(), "stderr: bad") {
		t.Fatalf("expected stderr in error: %v", err)
	}
}

func TestParseStatusInvalidJSON(t *testing.T) {
	_, err := parseStatus([]byte(`{not-json}`))
	if err == nil {
@@ -104,6 +189,181 @@ func TestParseDFInvalidJSON(t *testing.T) {
	}
}

func TestCollect_NotAvailable(t *testing.T) {
	withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
		if name == "which" {
			return nil, nil, errors.New("missing")
		}
		t.Fatalf("unexpected command: %s %v", name, args)
		return nil, nil, nil
	})

	status, err := Collect(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if status != nil {
		t.Fatalf("expected nil status")
	}
}

func TestCollect_StatusError(t *testing.T) {
	withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
		if name == "which" {
			return nil, nil, nil
		}
		if name == "ceph" && len(args) > 0 && args[0] == "status" {
			return nil, []byte("boom"), errors.New("status failed")
		}
		t.Fatalf("unexpected command: %s %v", name, args)
		return nil, nil, nil
	})

	status, err := Collect(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if status != nil {
		t.Fatalf("expected nil status")
	}
}

func TestCollect_ParseStatusError(t *testing.T) {
	withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
		if name == "which" {
			return nil, nil, nil
		}
		if name == "ceph" && len(args) > 0 && args[0] == "status" {
			return []byte(`{not-json}`), nil, nil
		}
		t.Fatalf("unexpected command: %s %v", name, args)
		return nil, nil, nil
	})

	_, err := Collect(context.Background())
	if err == nil {
		t.Fatalf("expected parse error")
	}
}

func TestCollect_DFCommandError(t *testing.T) {
	statusJSON := []byte(`{
	"fsid":"fsid-1",
	"health":{"status":"HEALTH_OK","checks":{}},
	"monmap":{"epoch":1,"mons":[]},
	"mgrmap":{"available":false,"active_name":"","standbys":[]},
	"osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0},
	"pgmap":{"num_pgs":0,"bytes_total":100,"bytes_used":50,"bytes_avail":50,
	"data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0,
	"read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0}
	}`)
	withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
		if name == "which" {
			return nil, nil, nil
		}
		if name == "ceph" && len(args) > 0 && args[0] == "status" {
			return statusJSON, nil, nil
		}
		if name == "ceph" && len(args) > 0 && args[0] == "df" {
			return nil, []byte("df failed"), errors.New("df error")
		}
		t.Fatalf("unexpected command: %s %v", name, args)
		return nil, nil, nil
	})

	status, err := Collect(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if status == nil {
		t.Fatalf("expected status")
	}
	if status.PGMap.UsagePercent == 0 {
		t.Fatalf("expected usage percent from status")
	}
}

func TestCollect_DFParseError(t *testing.T) {
	statusJSON := []byte(`{
	"fsid":"fsid-1",
	"health":{"status":"HEALTH_OK","checks":{}},
	"monmap":{"epoch":1,"mons":[]},
	"mgrmap":{"available":false,"active_name":"","standbys":[]},
	"osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0},
	"pgmap":{"num_pgs":0,"bytes_total":0,"bytes_used":0,"bytes_avail":0,
	"data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0,
	"read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0}
	}`)
	withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
		if name == "which" {
			return nil, nil, nil
		}
		if name == "ceph" && len(args) > 0 && args[0] == "status" {
			return statusJSON, nil, nil
		}
		if name == "ceph" && len(args) > 0 && args[0] == "df" {
			return []byte(`{not-json}`), nil, nil
		}
		t.Fatalf("unexpected command: %s %v", name, args)
		return nil, nil, nil
	})

	status, err := Collect(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if status == nil {
		t.Fatalf("expected status")
	}
	if status.PGMap.UsagePercent != 0 {
		t.Fatalf("expected usage percent to remain 0, got %v", status.PGMap.UsagePercent)
	}
}

func TestCollect_UsagePercentFromDF(t *testing.T) {
	statusJSON := []byte(`{
	"fsid":"fsid-1",
	"health":{"status":"HEALTH_OK","checks":{}},
	"monmap":{"epoch":1,"mons":[]},
	"mgrmap":{"available":false,"active_name":"","standbys":[]},
	"osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0},
	"pgmap":{"num_pgs":0,"bytes_total":0,"bytes_used":0,"bytes_avail":0,
	"data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0,
	"read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0}
	}`)
	dfJSON := []byte(`{
	"stats":{"total_bytes":1000,"total_used_bytes":500,"percent_used":0.5},
	"pools":[]
	}`)
	withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) {
		if name == "which" {
			return nil, nil, nil
		}
		if name == "ceph" && len(args) > 0 && args[0] == "status" {
			return statusJSON, nil, nil
		}
		if name == "ceph" && len(args) > 0 && args[0] == "df" {
			return dfJSON, nil, nil
		}
		t.Fatalf("unexpected command: %s %v", name, args)
		return nil, nil, nil
	})

	status, err := Collect(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if status == nil {
		t.Fatalf("expected status")
	}
	if status.PGMap.UsagePercent != 50 {
		t.Fatalf("expected usage percent from df, got %v", status.PGMap.UsagePercent)
	}
	if status.CollectedAt.IsZero() {
		t.Fatalf("expected collected timestamp set")
	}
}
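One conversion detail the last test nails down: ceph df reports percent_used as a ratio (0.5 above) while the exported PGMap.UsagePercent is a percentage (50), so the collector must scale by 100 when df data is present. Roughly (the df struct's field names here are assumptions):

if df != nil && df.Stats.TotalBytes > 0 {
	status.PGMap.UsagePercent = df.Stats.PercentUsed * 100 // 0.5 -> 50
}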

func TestBoolToInt(t *testing.T) {
	if boolToInt(true) != 1 {
		t.Fatalf("expected boolToInt(true)=1")

@@ -14,6 +14,16 @@ import (
	"github.com/rs/zerolog/log"
)

var defaultDataDirFn = utils.GetDataDir

var legacyKeyPath = "/etc/pulse/.encryption.key"

var randReader = rand.Reader

var newCipher = aes.NewCipher

var newGCM = cipher.NewGCM
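Hoisting rand.Reader, aes.NewCipher, and cipher.NewGCM into package variables gives the tests three injection points for failure paths that are nearly impossible to reach otherwise (entropy exhaustion, cipher or GCM construction errors). The withNewGCM helper added in the test file below makes forcing the GCM branch a one-liner:

withNewGCM(t, func(b cipher.Block) (cipher.AEAD, error) {
	return nil, errors.New("forced GCM failure") // Encrypt/Decrypt must surface this
})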
|
||||
|
||||
// CryptoManager handles encryption/decryption of sensitive data
|
||||
type CryptoManager struct {
|
||||
key []byte
|
||||
@@ -23,7 +33,7 @@ type CryptoManager struct {
|
||||
// NewCryptoManagerAt creates a new crypto manager with an explicit data directory override.
|
||||
func NewCryptoManagerAt(dataDir string) (*CryptoManager, error) {
|
||||
if dataDir == "" {
|
||||
dataDir = utils.GetDataDir()
|
||||
dataDir = defaultDataDirFn()
|
||||
}
|
||||
keyPath := filepath.Join(dataDir, ".encryption.key")
|
||||
|
||||
@@ -41,11 +51,12 @@ func NewCryptoManagerAt(dataDir string) (*CryptoManager, error) {
|
||||
// getOrCreateKeyAt gets the encryption key or creates one if it doesn't exist
|
||||
func getOrCreateKeyAt(dataDir string) ([]byte, error) {
|
||||
if dataDir == "" {
|
||||
dataDir = utils.GetDataDir()
|
||||
dataDir = defaultDataDirFn()
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(dataDir, ".encryption.key")
|
||||
oldKeyPath := "/etc/pulse/.encryption.key"
|
||||
oldKeyPath := legacyKeyPath
|
||||
oldKeyDir := filepath.Dir(oldKeyPath)
|
||||
|
||||
log.Debug().
|
||||
Str("dataDir", dataDir).
|
||||
@@ -78,7 +89,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) {
|
||||
// Check for key in old location and migrate if found (only if paths differ)
|
||||
// CRITICAL: This code deletes the encryption key at oldKeyPath after migrating it.
|
||||
// Adding extensive logging to diagnose recurring key deletion bug.
|
||||
if dataDir != "/etc/pulse" && keyPath != oldKeyPath {
|
||||
if dataDir != oldKeyDir && keyPath != oldKeyPath {
|
||||
log.Warn().
|
||||
Str("dataDir", dataDir).
|
||||
Str("keyPath", keyPath).
|
||||
@@ -135,7 +146,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) {
|
||||
log.Debug().
|
||||
Str("dataDir", dataDir).
|
||||
Str("keyPath", keyPath).
|
||||
Bool("sameAsOldPath", dataDir == "/etc/pulse").
|
||||
Bool("sameAsOldPath", dataDir == oldKeyDir).
|
||||
Msg("Skipping key migration check (dataDir is /etc/pulse or paths match)")
|
||||
}
|
||||
|
||||
@@ -171,7 +182,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) {
|
||||
|
||||
// Generate new key (only if no encrypted data exists)
|
||||
key := make([]byte, 32) // AES-256
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
if _, err := io.ReadFull(randReader, key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
|
||||
@@ -205,18 +216,18 @@ func (c *CryptoManager) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(c.key)
|
||||
block, err := newCipher(c.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
gcm, err := newGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
if _, err := io.ReadFull(randReader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -226,12 +237,12 @@ func (c *CryptoManager) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
|
||||
// Decrypt decrypts data using AES-GCM
|
||||
func (c *CryptoManager) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(c.key)
|
||||
block, err := newCipher(c.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
gcm, err := newGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
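A note on the pattern used throughout this commit: hard dependencies such as rand.Reader, aes.NewCipher, and cipher.NewGCM are lifted into package-level variables ("seams") so tests can substitute deterministic or failing implementations and restore the original in a cleanup. A minimal, self-contained sketch of the idea; the names here are illustrative, not this package's real API:

package seam

import (
	"crypto/rand"
	"io"
	"testing"
)

// randReader is the seam: production code reads from crypto/rand,
// tests may swap in a deterministic or failing reader.
var randReader io.Reader = rand.Reader

// newToken draws n random bytes through the seam.
func newToken(n int) ([]byte, error) {
	b := make([]byte, n)
	if _, err := io.ReadFull(randReader, b); err != nil {
		return nil, err
	}
	return b, nil
}

// withRandReader swaps the seam for one test and restores it afterwards,
// mirroring the helper style used in the test file below.
func withRandReader(t *testing.T, r io.Reader) {
	t.Helper()
	orig := randReader
	randReader = r
	t.Cleanup(func() { randReader = orig })
}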
@@ -2,11 +2,51 @@ package crypto

import (
	"bytes"
	"crypto/cipher"
	"encoding/base64"
	"errors"
	"io"
	"os"
	"path/filepath"
	"testing"
)

type errReader struct {
	err error
}

func (e errReader) Read(p []byte) (int, error) {
	return 0, e.err
}

func withDefaultDataDir(t *testing.T, dir string) {
	t.Helper()
	orig := defaultDataDirFn
	defaultDataDirFn = func() string { return dir }
	t.Cleanup(func() { defaultDataDirFn = orig })
}

func withLegacyKeyPath(t *testing.T, path string) {
	t.Helper()
	orig := legacyKeyPath
	legacyKeyPath = path
	t.Cleanup(func() { legacyKeyPath = orig })
}

func withRandReader(t *testing.T, r io.Reader) {
	t.Helper()
	orig := randReader
	randReader = r
	t.Cleanup(func() { randReader = orig })
}

func withNewGCM(t *testing.T, fn func(cipher.Block) (cipher.AEAD, error)) {
	t.Helper()
	orig := newGCM
	newGCM = fn
	t.Cleanup(func() { newGCM = orig })
}

func TestEncryptDecrypt(t *testing.T) {
	// Create a temp directory for the test
	tmpDir := t.TempDir()
@@ -149,6 +189,33 @@ func TestEncryptionKeyFilePermissions(t *testing.T) {
	}
}

func TestNewCryptoManagerAt_DefaultDataDir(t *testing.T) {
	tmpDir := t.TempDir()
	withDefaultDataDir(t, tmpDir)

	cm, err := NewCryptoManagerAt("")
	if err != nil {
		t.Fatalf("NewCryptoManagerAt() error: %v", err)
	}
	if cm.keyPath != filepath.Join(tmpDir, ".encryption.key") {
		t.Fatalf("keyPath = %q, want %q", cm.keyPath, filepath.Join(tmpDir, ".encryption.key"))
	}
}

func TestNewCryptoManagerAt_KeyError(t *testing.T) {
	tmpDir := t.TempDir()
	withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
	err := os.WriteFile(filepath.Join(tmpDir, "nodes.enc"), []byte("data"), 0600)
	if err != nil {
		t.Fatalf("Failed to create encrypted file: %v", err)
	}

	_, err = NewCryptoManagerAt(tmpDir)
	if err == nil {
		t.Fatal("Expected error when encrypted data exists without a key")
	}
}

func TestDecryptInvalidData(t *testing.T) {
	tmpDir := t.TempDir()
	cm, err := NewCryptoManagerAt(tmpDir)
@@ -175,6 +242,217 @@ func TestDecryptInvalidData(t *testing.T) {
	}
}

func TestGetOrCreateKeyAt_InvalidBase64(t *testing.T) {
	tmpDir := t.TempDir()
	withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))

	keyPath := filepath.Join(tmpDir, ".encryption.key")
	if err := os.WriteFile(keyPath, []byte("not-base64"), 0600); err != nil {
		t.Fatalf("Failed to write key file: %v", err)
	}

	key, err := getOrCreateKeyAt(tmpDir)
	if err != nil {
		t.Fatalf("getOrCreateKeyAt() error: %v", err)
	}
	if len(key) != 32 {
		t.Fatalf("expected 32-byte key, got %d", len(key))
	}
}

func TestGetOrCreateKeyAt_DefaultDataDir(t *testing.T) {
	tmpDir := t.TempDir()
	withDefaultDataDir(t, tmpDir)
	withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))

	key, err := getOrCreateKeyAt("")
	if err != nil {
		t.Fatalf("getOrCreateKeyAt() error: %v", err)
	}
	if len(key) != 32 {
		t.Fatalf("expected 32-byte key, got %d", len(key))
	}
}

func TestGetOrCreateKeyAt_InvalidLength(t *testing.T) {
	tmpDir := t.TempDir()
	withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))

	shortKey := make([]byte, 16)
	for i := range shortKey {
		shortKey[i] = byte(i)
	}
	encoded := base64.StdEncoding.EncodeToString(shortKey)
	if err := os.WriteFile(filepath.Join(tmpDir, ".encryption.key"), []byte(encoded), 0600); err != nil {
		t.Fatalf("Failed to write key file: %v", err)
	}

	key, err := getOrCreateKeyAt(tmpDir)
	if err != nil {
		t.Fatalf("getOrCreateKeyAt() error: %v", err)
	}
	if len(key) != 32 {
		t.Fatalf("expected 32-byte key, got %d", len(key))
	}
}

func TestGetOrCreateKeyAt_SkipMigrationWhenPathsMatch(t *testing.T) {
	tmpDir := t.TempDir()
	withLegacyKeyPath(t, filepath.Join(tmpDir, ".encryption.key"))

	key, err := getOrCreateKeyAt(tmpDir)
	if err != nil {
		t.Fatalf("getOrCreateKeyAt() error: %v", err)
	}
	if len(key) != 32 {
		t.Fatalf("expected 32-byte key, got %d", len(key))
	}
}

func TestGetOrCreateKeyAt_MigrateSuccess(t *testing.T) {
	legacyDir := t.TempDir()
	legacyPath := filepath.Join(legacyDir, ".encryption.key")
	withLegacyKeyPath(t, legacyPath)

	oldKey := make([]byte, 32)
	for i := range oldKey {
		oldKey[i] = byte(i)
	}
	encoded := base64.StdEncoding.EncodeToString(oldKey)
	if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil {
		t.Fatalf("Failed to write legacy key: %v", err)
	}

	newDir := t.TempDir()
	key, err := getOrCreateKeyAt(newDir)
	if err != nil {
		t.Fatalf("getOrCreateKeyAt() error: %v", err)
	}
	if !bytes.Equal(key, oldKey) {
		t.Fatalf("migrated key mismatch")
	}
	contents, err := os.ReadFile(filepath.Join(newDir, ".encryption.key"))
	if err != nil {
		t.Fatalf("Failed to read migrated key: %v", err)
	}
	if string(contents) != encoded {
		t.Fatalf("migrated key contents mismatch")
	}
}

func TestGetOrCreateKeyAt_MigrateMkdirError(t *testing.T) {
	legacyDir := t.TempDir()
	legacyPath := filepath.Join(legacyDir, ".encryption.key")
	withLegacyKeyPath(t, legacyPath)

	oldKey := make([]byte, 32)
	for i := range oldKey {
		oldKey[i] = byte(i)
	}
	encoded := base64.StdEncoding.EncodeToString(oldKey)
	if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil {
		t.Fatalf("Failed to write legacy key: %v", err)
	}

	tmpDir := t.TempDir()
	dataFile := filepath.Join(tmpDir, "datafile")
	if err := os.WriteFile(dataFile, []byte("x"), 0600); err != nil {
		t.Fatalf("Failed to write data file: %v", err)
	}

	key, err := getOrCreateKeyAt(dataFile)
	if err != nil {
		t.Fatalf("getOrCreateKeyAt() error: %v", err)
	}
	if !bytes.Equal(key, oldKey) {
		t.Fatalf("expected legacy key on mkdir error")
	}
}

func TestGetOrCreateKeyAt_MigrateWriteError(t *testing.T) {
	legacyDir := t.TempDir()
	legacyPath := filepath.Join(legacyDir, ".encryption.key")
	withLegacyKeyPath(t, legacyPath)

	oldKey := make([]byte, 32)
	for i := range oldKey {
		oldKey[i] = byte(i)
	}
	encoded := base64.StdEncoding.EncodeToString(oldKey)
	if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil {
		t.Fatalf("Failed to write legacy key: %v", err)
	}

	newDir := t.TempDir()
	keyPath := filepath.Join(newDir, ".encryption.key")
	if err := os.MkdirAll(keyPath, 0700); err != nil {
		t.Fatalf("Failed to create key path dir: %v", err)
	}

	key, err := getOrCreateKeyAt(newDir)
	if err != nil {
		t.Fatalf("getOrCreateKeyAt() error: %v", err)
	}
	if !bytes.Equal(key, oldKey) {
		t.Fatalf("expected legacy key on write error")
	}
}

func TestGetOrCreateKeyAt_EncryptedDataExists(t *testing.T) {
	tmpDir := t.TempDir()
	withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))

	if err := os.WriteFile(filepath.Join(tmpDir, "nodes.enc"), []byte("data"), 0600); err != nil {
		t.Fatalf("Failed to write encrypted file: %v", err)
	}

	_, err := getOrCreateKeyAt(tmpDir)
	if err == nil {
		t.Fatal("Expected error when encrypted data exists")
	}
}

func TestGetOrCreateKeyAt_RandReaderError(t *testing.T) {
	tmpDir := t.TempDir()
	withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))
	withRandReader(t, errReader{err: errors.New("read failed")})

	_, err := getOrCreateKeyAt(tmpDir)
	if err == nil {
		t.Fatal("Expected error from rand reader")
	}
}

func TestGetOrCreateKeyAt_CreateDirError(t *testing.T) {
	tmpDir := t.TempDir()
	withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))

	dataFile := filepath.Join(tmpDir, "datafile")
	if err := os.WriteFile(dataFile, []byte("x"), 0600); err != nil {
		t.Fatalf("Failed to write data file: %v", err)
	}

	_, err := getOrCreateKeyAt(dataFile)
	if err == nil {
		t.Fatal("Expected error when creating directory")
	}
}

func TestGetOrCreateKeyAt_SaveKeyError(t *testing.T) {
	tmpDir := t.TempDir()
	withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key"))

	keyPath := filepath.Join(tmpDir, ".encryption.key")
	if err := os.MkdirAll(keyPath, 0700); err != nil {
		t.Fatalf("Failed to create key path dir: %v", err)
	}

	_, err := getOrCreateKeyAt(tmpDir)
	if err == nil {
		t.Fatal("Expected error when saving key")
	}
}

func TestDecryptStringInvalidBase64(t *testing.T) {
	tmpDir := t.TempDir()
	cm, err := NewCryptoManagerAt(tmpDir)
@@ -216,6 +494,57 @@ func TestEncryptionUniqueness(t *testing.T) {
	}
}

func TestEncryptInvalidKey(t *testing.T) {
	cm := &CryptoManager{key: []byte("short")}
	if _, err := cm.Encrypt([]byte("data")); err == nil {
		t.Fatal("Expected error for invalid key length")
	}
}

func TestDecryptInvalidKey(t *testing.T) {
	cm := &CryptoManager{key: []byte("short")}
	if _, err := cm.Decrypt([]byte("data")); err == nil {
		t.Fatal("Expected error for invalid key length")
	}
}

func TestEncryptNonceReadError(t *testing.T) {
	withRandReader(t, errReader{err: errors.New("nonce read error")})
	cm := &CryptoManager{key: make([]byte, 32)}
	if _, err := cm.Encrypt([]byte("data")); err == nil {
		t.Fatal("Expected error reading nonce")
	}
}

func TestEncryptDecryptGCMError(t *testing.T) {
	withNewGCM(t, func(cipher.Block) (cipher.AEAD, error) {
		return nil, errors.New("gcm error")
	})

	cm := &CryptoManager{key: make([]byte, 32)}
	if _, err := cm.Encrypt([]byte("data")); err == nil {
		t.Fatal("Expected Encrypt error from GCM")
	}
	if _, err := cm.Decrypt([]byte("data")); err == nil {
		t.Fatal("Expected Decrypt error from GCM")
	}
}

func TestEncryptStringError(t *testing.T) {
	cm := &CryptoManager{key: []byte("short")}
	if _, err := cm.EncryptString("data"); err == nil {
		t.Fatal("Expected EncryptString error")
	}
}

func TestDecryptStringError(t *testing.T) {
	cm := &CryptoManager{key: make([]byte, 32)}
	encoded := base64.StdEncoding.EncodeToString([]byte("short"))
	if _, err := cm.DecryptString(encoded); err == nil {
		t.Fatal("Expected DecryptString error")
	}
}

func TestNewCryptoManagerRefusesOrphanedData(t *testing.T) {
	// Skip if production key exists - migration code will always find and use it
	if _, err := os.Stat("/etc/pulse/.encryption.key"); err == nil {
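For orientation, here is a round-trip usage sketch of the manager these tests exercise. It assumes EncryptString returns the encoded ciphertext and DecryptString reverses it, as the error-path tests above imply; it is an illustration, not a test from this commit:

func TestRoundTripSketch(t *testing.T) {
	// In this package's own tests you would also call withLegacyKeyPath
	// so the legacy-migration check never touches /etc/pulse.
	cm, err := NewCryptoManagerAt(t.TempDir())
	if err != nil {
		t.Fatalf("NewCryptoManagerAt: %v", err)
	}
	ciphertext, err := cm.EncryptString("secret")
	if err != nil {
		t.Fatalf("EncryptString: %v", err)
	}
	plaintext, err := cm.DecryptString(ciphertext)
	if err != nil {
		t.Fatalf("DecryptString: %v", err)
	}
	if plaintext != "secret" {
		t.Fatalf("round trip mismatch: got %q", plaintext)
	}
}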
@@ -44,6 +44,28 @@ var (
	defaultTimeFmt = time.RFC3339
)

var (
	nowFn           = time.Now
	isTerminalFn    = term.IsTerminal
	mkdirAllFn      = os.MkdirAll
	openFileFn      = os.OpenFile
	openFn          = os.Open
	statFn          = os.Stat
	readDirFn       = os.ReadDir
	renameFn        = os.Rename
	removeFn        = os.Remove
	copyFn          = io.Copy
	gzipNewWriterFn = gzip.NewWriter
	statFileFn      = func(file *os.File) (os.FileInfo, error) { return file.Stat() }
	closeFileFn     = func(file *os.File) error { return file.Close() }
	compressFn      = compressAndRemove
)

var (
	defaultStatFileFn  = statFileFn
	defaultCloseFileFn = closeFileFn
)

func init() {
	baseLogger = zerolog.New(baseWriter).With().Timestamp().Logger()
	log.Logger = baseLogger
@@ -137,7 +159,7 @@ func isTerminal(file *os.File) bool {
	if file == nil {
		return false
	}
	return term.IsTerminal(int(file.Fd()))
	return isTerminalFn(int(file.Fd()))
}

type rollingFileWriter struct {
@@ -157,7 +179,7 @@ func newRollingFileWriter(cfg Config) (io.Writer, error) {
	}

	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0o755); err != nil {
	if err := mkdirAllFn(dir, 0o755); err != nil {
		return nil, fmt.Errorf("create log directory: %w", err)
	}

@@ -205,13 +227,13 @@ func (w *rollingFileWriter) openOrCreateLocked() error {
		return nil
	}

	file, err := os.OpenFile(w.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
	file, err := openFileFn(w.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
	if err != nil {
		return fmt.Errorf("open log file: %w", err)
	}
	w.file = file

	info, err := file.Stat()
	info, err := statFileFn(file)
	if err != nil {
		w.currentSize = 0
		return nil
@@ -225,11 +247,11 @@ func (w *rollingFileWriter) rotateLocked() error {
		return err
	}

	if _, err := os.Stat(w.path); err == nil {
		rotated := fmt.Sprintf("%s.%s", w.path, time.Now().Format("20060102-150405"))
		if err := os.Rename(w.path, rotated); err == nil {
	if _, err := statFn(w.path); err == nil {
		rotated := fmt.Sprintf("%s.%s", w.path, nowFn().Format("20060102-150405"))
		if err := renameFn(w.path, rotated); err == nil {
			if w.compress {
				go compressAndRemove(rotated)
				go compressFn(rotated)
			}
		}
	}
@@ -242,7 +264,7 @@ func (w *rollingFileWriter) closeLocked() error {
	if w.file == nil {
		return nil
	}
	err := w.file.Close()
	err := closeFileFn(w.file)
	w.file = nil
	w.currentSize = 0
	return err
@@ -256,9 +278,9 @@ func (w *rollingFileWriter) cleanupOldFiles() {
	dir := filepath.Dir(w.path)
	base := filepath.Base(w.path)
	prefix := base + "."
	cutoff := time.Now().Add(-w.maxAge)
	cutoff := nowFn().Add(-w.maxAge)

	entries, err := os.ReadDir(dir)
	entries, err := readDirFn(dir)
	if err != nil {
		return
	}
@@ -273,26 +295,26 @@ func (w *rollingFileWriter) cleanupOldFiles() {
			continue
		}
		if info.ModTime().Before(cutoff) {
			_ = os.Remove(filepath.Join(dir, name))
			_ = removeFn(filepath.Join(dir, name))
		}
	}
}

func compressAndRemove(path string) {
	in, err := os.Open(path)
	in, err := openFn(path)
	if err != nil {
		return
	}
	defer in.Close()

	outPath := path + ".gz"
	out, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
	out, err := openFileFn(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
	if err != nil {
		return
	}

	gw := gzip.NewWriter(out)
	if _, err = io.Copy(gw, in); err != nil {
	gw := gzipNewWriterFn(out)
	if _, err = copyFn(gw, in); err != nil {
		gw.Close()
		out.Close()
		return
@@ -302,5 +324,5 @@ func compressAndRemove(path string) {
		return
	}
	out.Close()
	_ = os.Remove(path)
	_ = removeFn(path)
}
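A side note on the rotated filename above: Go formats times against the reference time Mon Jan 2 15:04:05 2006, so the "20060102-150405" layout yields compact, lexicographically sortable suffixes. A small illustration (the log path is hypothetical):

package main

import (
	"fmt"
	"time"
)

func main() {
	// A rotation at 2025-01-02 03:04:05 UTC produces a sortable suffix.
	ts := time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)
	rotated := fmt.Sprintf("%s.%s", "/var/log/pulse/app.log", ts.Format("20060102-150405"))
	fmt.Println(rotated) // /var/log/pulse/app.log.20250102-030405
}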
@@ -2,13 +2,19 @@ package logging

import (
	"bytes"
	"compress/gzip"
	"errors"
	"io"
	"os"
	"path/filepath"
	"reflect"
	"sync"
	"testing"
	"time"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"golang.org/x/term"
)

func resetLoggingState() {
@@ -21,6 +27,20 @@ func resetLoggingState() {
	log.Logger = baseLogger
	zerolog.TimeFieldFormat = defaultTimeFmt
	zerolog.SetGlobalLevel(zerolog.InfoLevel)
	nowFn = time.Now
	isTerminalFn = term.IsTerminal
	mkdirAllFn = os.MkdirAll
	openFileFn = os.OpenFile
	openFn = os.Open
	statFn = os.Stat
	readDirFn = os.ReadDir
	renameFn = os.Rename
	removeFn = os.Remove
	copyFn = io.Copy
	gzipNewWriterFn = gzip.NewWriter
	statFileFn = defaultStatFileFn
	closeFileFn = defaultCloseFileFn
	compressFn = compressAndRemove
}

func TestInitJSONFormatSetsLevelAndComponent(t *testing.T) {
@@ -267,6 +287,393 @@ func TestSelectWriter(t *testing.T) {
	}
}

func TestSelectWriterAutoTerminal(t *testing.T) {
	t.Cleanup(resetLoggingState)
	isTerminalFn = func(int) bool { return true }

	w := selectWriter("auto")
	if _, ok := w.(zerolog.ConsoleWriter); !ok {
		t.Fatalf("expected console writer, got %#v", w)
	}
}

func TestSelectWriterDefault(t *testing.T) {
	t.Cleanup(resetLoggingState)

	w := selectWriter("unknown")
	if w != os.Stderr {
		t.Fatalf("expected default writer to be os.Stderr, got %#v", w)
	}
}

func TestIsTerminalNil(t *testing.T) {
	t.Cleanup(resetLoggingState)

	if isTerminal(nil) {
		t.Fatal("expected nil file to report false")
	}
}

func TestNewRollingFileWriter_EmptyPath(t *testing.T) {
	t.Cleanup(resetLoggingState)

	writer, err := newRollingFileWriter(Config{})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if writer != nil {
		t.Fatalf("expected nil writer, got %#v", writer)
	}
}

func TestNewRollingFileWriter_MkdirError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	mkdirAllFn = func(string, os.FileMode) error {
		return errors.New("mkdir failed")
	}

	_, err := newRollingFileWriter(Config{FilePath: "/tmp/logs/test.log"})
	if err == nil {
		t.Fatal("expected error from mkdir")
	}
}

func TestInitFileWriterError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	mkdirAllFn = func(string, os.FileMode) error {
		return errors.New("mkdir failed")
	}

	Init(Config{
		Format:   "json",
		FilePath: "/tmp/logs/test.log",
	})
}

func TestNewRollingFileWriter_DefaultMaxSize(t *testing.T) {
	t.Cleanup(resetLoggingState)

	dir := t.TempDir()
	writer, err := newRollingFileWriter(Config{
		FilePath:  filepath.Join(dir, "app.log"),
		MaxSizeMB: 0,
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	w, ok := writer.(*rollingFileWriter)
	if !ok {
		t.Fatalf("expected rollingFileWriter, got %#v", writer)
	}
	if w.maxBytes != 100*1024*1024 {
		t.Fatalf("expected default max bytes, got %d", w.maxBytes)
	}
	_ = w.closeLocked()
}

func TestNewRollingFileWriter_OpenError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	openFileFn = func(string, int, os.FileMode) (*os.File, error) {
		return nil, errors.New("open failed")
	}

	_, err := newRollingFileWriter(Config{FilePath: filepath.Join(t.TempDir(), "app.log")})
	if err == nil {
		t.Fatal("expected error from openOrCreateLocked")
	}
}

func TestOpenOrCreateLocked_StatError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	dir := t.TempDir()
	w := &rollingFileWriter{path: filepath.Join(dir, "app.log")}
	statFileFn = func(*os.File) (os.FileInfo, error) {
		return nil, errors.New("stat failed")
	}
	if err := w.openOrCreateLocked(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if w.currentSize != 0 {
		t.Fatalf("expected current size 0, got %d", w.currentSize)
	}
	_ = w.closeLocked()
}

func TestOpenOrCreateLocked_AlreadyOpen(t *testing.T) {
	t.Cleanup(resetLoggingState)

	file, err := os.CreateTemp(t.TempDir(), "log")
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	w := &rollingFileWriter{path: file.Name(), file: file}
	if err := w.openOrCreateLocked(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	_ = w.closeLocked()
}

func TestRollingFileWriter_WriteOpenError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	openFileFn = func(string, int, os.FileMode) (*os.File, error) {
		return nil, errors.New("open failed")
	}
	w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log")}
	if _, err := w.Write([]byte("data")); err == nil {
		t.Fatal("expected write error")
	}
}

func TestRollingFileWriter_WriteRotateError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	dir := t.TempDir()
	path := filepath.Join(dir, "app.log")
	callCount := 0
	openFileFn = func(name string, flag int, perm os.FileMode) (*os.File, error) {
		callCount++
		if callCount == 1 {
			return os.OpenFile(name, flag, perm)
		}
		return nil, errors.New("open failed")
	}
	w := &rollingFileWriter{path: path, maxBytes: 1}
	if _, err := w.Write([]byte("too big")); err == nil {
		t.Fatal("expected rotate error")
	}
}

func TestRollingFileWriter_RotateCompress(t *testing.T) {
	t.Cleanup(resetLoggingState)

	dir := t.TempDir()
	path := filepath.Join(dir, "app.log")
	if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
		t.Fatalf("failed to write file: %v", err)
	}

	w := &rollingFileWriter{path: path, maxBytes: 1, compress: true}
	if err := w.openOrCreateLocked(); err != nil {
		t.Fatalf("openOrCreateLocked error: %v", err)
	}

	ch := make(chan string, 1)
	compressFn = func(p string) { ch <- p }

	if err := w.rotateLocked(); err != nil {
		t.Fatalf("rotateLocked error: %v", err)
	}

	select {
	case <-ch:
	case <-time.After(2 * time.Second):
		t.Fatal("expected compress to be triggered")
	}
	_ = w.closeLocked()
}

func TestRotateLockedCloseError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	dir := t.TempDir()
	path := filepath.Join(dir, "app.log")
	file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		t.Fatalf("failed to open file: %v", err)
	}
	w := &rollingFileWriter{path: path, file: file}
	closeFileFn = func(*os.File) error {
		return errors.New("close failed")
	}

	if err := w.rotateLocked(); err == nil {
		t.Fatal("expected close error")
	}
	_ = file.Close()
}

func TestCloseLocked(t *testing.T) {
	t.Cleanup(resetLoggingState)

	w := &rollingFileWriter{}
	if err := w.closeLocked(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	file, err := os.CreateTemp(t.TempDir(), "log")
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	w.file = file
	if err := w.closeLocked(); err != nil {
		t.Fatalf("unexpected close error: %v", err)
	}
	if w.file != nil {
		t.Fatal("expected file to be cleared")
	}
	if w.currentSize != 0 {
		t.Fatalf("expected size reset, got %d", w.currentSize)
	}
}

func TestCleanupOldFilesNoMaxAge(t *testing.T) {
	t.Cleanup(resetLoggingState)

	w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: 0}
	w.cleanupOldFiles()
}

func TestCleanupOldFilesReadDirError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	readDirFn = func(string) ([]os.DirEntry, error) {
		return nil, errors.New("read dir failed")
	}
	w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: time.Hour}
	w.cleanupOldFiles()
}

func TestCleanupOldFilesInfoError(t *testing.T) {
	t.Cleanup(resetLoggingState)

	readDirFn = func(string) ([]os.DirEntry, error) {
		return []os.DirEntry{errDirEntry{name: "app.log.20200101"}}, nil
	}
	w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: time.Hour}
	w.cleanupOldFiles()
}

func TestCleanupOldFilesRemovesOld(t *testing.T) {
	t.Cleanup(resetLoggingState)

	dir := t.TempDir()
	path := filepath.Join(dir, "app.log")
	oldFile := filepath.Join(dir, "app.log.20200101-000000")
	newFile := filepath.Join(dir, "app.log.20250101-000000")
	otherFile := filepath.Join(dir, "other.log.20200101")

	if err := os.WriteFile(oldFile, []byte("old"), 0600); err != nil {
		t.Fatalf("failed to write old file: %v", err)
	}
	if err := os.WriteFile(newFile, []byte("new"), 0600); err != nil {
		t.Fatalf("failed to write new file: %v", err)
	}
	if err := os.WriteFile(otherFile, []byte("other"), 0600); err != nil {
		t.Fatalf("failed to write other file: %v", err)
	}

	fixedNow := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
	nowFn = func() time.Time { return fixedNow }

	if err := os.Chtimes(oldFile, fixedNow.Add(-48*time.Hour), fixedNow.Add(-48*time.Hour)); err != nil {
		t.Fatalf("failed to set old file time: %v", err)
	}
	if err := os.Chtimes(newFile, fixedNow.Add(-time.Hour), fixedNow.Add(-time.Hour)); err != nil {
		t.Fatalf("failed to set new file time: %v", err)
	}

	w := &rollingFileWriter{path: path, maxAge: 24 * time.Hour}
	w.cleanupOldFiles()

	if _, err := os.Stat(oldFile); !os.IsNotExist(err) {
		t.Fatalf("expected old file to be removed")
	}
	if _, err := os.Stat(newFile); err != nil {
		t.Fatalf("expected new file to remain: %v", err)
	}
	if _, err := os.Stat(otherFile); err != nil {
		t.Fatalf("expected other file to remain: %v", err)
	}
}

func TestStatFileFnDefault(t *testing.T) {
	t.Cleanup(resetLoggingState)

	file, err := os.CreateTemp(t.TempDir(), "log")
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	t.Cleanup(func() { _ = file.Close() })

	if _, err := statFileFn(file); err != nil {
		t.Fatalf("statFileFn error: %v", err)
	}
}

func TestCompressAndRemove(t *testing.T) {
	t.Run("OpenError", func(t *testing.T) {
		t.Cleanup(resetLoggingState)
		openFn = func(string) (*os.File, error) {
			return nil, errors.New("open failed")
		}
		compressAndRemove("/does/not/exist")
	})

	t.Run("OpenFileError", func(t *testing.T) {
		t.Cleanup(resetLoggingState)
		dir := t.TempDir()
		path := filepath.Join(dir, "app.log")
		if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
			t.Fatalf("failed to write file: %v", err)
		}
		openFileFn = func(string, int, os.FileMode) (*os.File, error) {
			return nil, errors.New("open file failed")
		}
		compressAndRemove(path)
	})

	t.Run("CopyError", func(t *testing.T) {
		t.Cleanup(resetLoggingState)
		dir := t.TempDir()
		path := filepath.Join(dir, "app.log")
		if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
			t.Fatalf("failed to write file: %v", err)
		}
		copyFn = func(io.Writer, io.Reader) (int64, error) {
			return 0, errors.New("copy failed")
		}
		compressAndRemove(path)
	})

	t.Run("CloseError", func(t *testing.T) {
		t.Cleanup(resetLoggingState)
		dir := t.TempDir()
		path := filepath.Join(dir, "app.log")
		if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
			t.Fatalf("failed to write file: %v", err)
		}
		errWriter := errWriteCloser{err: errors.New("write failed")}
		gzipNewWriterFn = func(io.Writer) *gzip.Writer {
			return gzip.NewWriter(errWriter)
		}
		copyFn = func(io.Writer, io.Reader) (int64, error) {
			return 0, nil
		}
		compressAndRemove(path)
	})

	t.Run("Success", func(t *testing.T) {
		t.Cleanup(resetLoggingState)
		dir := t.TempDir()
		path := filepath.Join(dir, "app.log")
		if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
			t.Fatalf("failed to write file: %v", err)
		}
		compressAndRemove(path)
		if _, err := os.Stat(path); !os.IsNotExist(err) {
			t.Fatal("expected original file to be removed")
		}
		if _, err := os.Stat(path + ".gz"); err != nil {
			t.Fatalf("expected gzip file to exist: %v", err)
		}
	})
}

// Test that the logging package doesn't panic under concurrent use
func TestConcurrentLogging(t *testing.T) {
	t.Cleanup(resetLoggingState)
@@ -292,3 +699,24 @@ func TestConcurrentLogging(t *testing.T) {
		t.Fatal("expected log output from concurrent logging")
	}
}

type errDirEntry struct {
	name string
}

func (e errDirEntry) Name() string { return e.name }
func (e errDirEntry) IsDir() bool  { return false }
func (e errDirEntry) Type() os.FileMode {
	return 0
}
func (e errDirEntry) Info() (os.FileInfo, error) {
	return nil, errors.New("info error")
}

type errWriteCloser struct {
	err error
}

func (e errWriteCloser) Write(p []byte) (int, error) {
	return 0, e.err
}
@@ -14,9 +14,13 @@ import (

// Pre-compiled regexes for performance (avoid recompilation on each call)
var (
	mdDeviceRe = regexp.MustCompile(`^(md\d+)\s*:`)
	slotRe     = regexp.MustCompile(`^\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(.+?)\s+(/dev/.+)$`)
	speedRe    = regexp.MustCompile(`speed=(\S+)`)
	mdDeviceRe = regexp.MustCompile(`^(md\d+)\s*:`)
	slotRe     = regexp.MustCompile(`^\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(.+?)\s+(/dev/.+)$`)
	speedRe    = regexp.MustCompile(`speed=(\S+)`)
	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		cmd := exec.CommandContext(ctx, name, args...)
		return cmd.Output()
	}
)

// CollectArrays discovers and collects status for all mdadm RAID arrays on the system.
@@ -56,8 +60,8 @@ func isMdadmAvailable(ctx context.Context) bool {
	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "mdadm", "--version")
	return cmd.Run() == nil
	_, err := runCommandOutput(ctx, "mdadm", "--version")
	return err == nil
}

// listArrayDevices scans /proc/mdstat to find all md devices
@@ -65,8 +69,7 @@ func listArrayDevices(ctx context.Context) ([]string, error) {
	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "cat", "/proc/mdstat")
	output, err := cmd.Output()
	output, err := runCommandOutput(ctx, "cat", "/proc/mdstat")
	if err != nil {
		return nil, fmt.Errorf("read /proc/mdstat: %w", err)
	}
@@ -89,8 +92,7 @@ func collectArrayDetail(ctx context.Context, device string) (host.RAIDArray, err
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "mdadm", "--detail", device)
	output, err := cmd.Output()
	output, err := runCommandOutput(ctx, "mdadm", "--detail", device)
	if err != nil {
		return host.RAIDArray{}, fmt.Errorf("mdadm --detail %s: %w", device, err)
	}
@@ -220,8 +222,7 @@ func getRebuildSpeed(device string) string {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "cat", "/proc/mdstat")
	output, err := cmd.Output()
	output, err := runCommandOutput(ctx, "cat", "/proc/mdstat")
	if err != nil {
		return ""
	}
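For reference, this is how the speedRe pattern above behaves on a typical /proc/mdstat recovery line; the sample line matches the fixtures in the tests below:

package main

import (
	"fmt"
	"regexp"
)

// Same pattern as the mdadm package: capture the token after "speed=".
var speedRe = regexp.MustCompile(`speed=(\S+)`)

func main() {
	line := "[>....................]  recovery = 12.6% (37043392/293039104) finish=127.5min speed=33440K/sec"
	if m := speedRe.FindStringSubmatch(line); m != nil {
		fmt.Println(m[1]) // 33440K/sec
	}
}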
@@ -1,11 +1,20 @@
package mdadm

import (
	"context"
	"errors"
	"testing"

	"github.com/rcourtman/pulse-go-rewrite/pkg/agents/host"
)

func withRunCommandOutput(t *testing.T, fn func(ctx context.Context, name string, args ...string) ([]byte, error)) {
	t.Helper()
	orig := runCommandOutput
	runCommandOutput = fn
	t.Cleanup(func() { runCommandOutput = orig })
}

func TestParseDetail(t *testing.T) {
	tests := []struct {
		name string
@@ -540,3 +549,239 @@ Consistency Policy : resync
		})
	}
}

func TestIsMdadmAvailable(t *testing.T) {
	t.Run("available", func(t *testing.T) {
		withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
			return []byte("mdadm"), nil
		})

		if !isMdadmAvailable(context.Background()) {
			t.Fatal("expected mdadm available")
		}
	})

	t.Run("missing", func(t *testing.T) {
		withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
			return nil, errors.New("missing")
		})

		if isMdadmAvailable(context.Background()) {
			t.Fatal("expected mdadm unavailable")
		}
	})
}

func TestListArrayDevices(t *testing.T) {
	mdstat := `Personalities : [raid1] [raid6]
md0 : active raid1 sdb1[1] sda1[0]
md1 : active raid6 sdc1[2] sdb1[1] sda1[0]
unused devices: <none>`
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return []byte(mdstat), nil
	})

	devices, err := listArrayDevices(context.Background())
	if err != nil {
		t.Fatalf("listArrayDevices error: %v", err)
	}
	if len(devices) != 2 || devices[0] != "/dev/md0" || devices[1] != "/dev/md1" {
		t.Fatalf("unexpected devices: %v", devices)
	}
}

func TestListArrayDevicesError(t *testing.T) {
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return nil, errors.New("read failed")
	})

	if _, err := listArrayDevices(context.Background()); err == nil {
		t.Fatal("expected error")
	}
}

func TestCollectArrayDetailError(t *testing.T) {
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return nil, errors.New("detail failed")
	})

	if _, err := collectArrayDetail(context.Background(), "/dev/md0"); err == nil {
		t.Fatal("expected error")
	}
}

func TestCollectArraysNotAvailable(t *testing.T) {
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return nil, errors.New("missing")
	})

	arrays, err := CollectArrays(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if arrays != nil {
		t.Fatalf("expected nil arrays, got %v", arrays)
	}
}

func TestCollectArraysListError(t *testing.T) {
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		if name == "mdadm" {
			return []byte("mdadm"), nil
		}
		return nil, errors.New("read failed")
	})

	if _, err := CollectArrays(context.Background()); err == nil {
		t.Fatal("expected error")
	}
}

func TestCollectArraysNoDevices(t *testing.T) {
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		if name == "mdadm" {
			return []byte("mdadm"), nil
		}
		return []byte("unused devices: <none>"), nil
	})

	arrays, err := CollectArrays(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if arrays != nil {
		t.Fatalf("expected nil arrays, got %v", arrays)
	}
}

func TestCollectArraysSkipsDetailError(t *testing.T) {
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		switch name {
		case "mdadm":
			if len(args) > 0 && args[0] == "--version" {
				return []byte("mdadm"), nil
			}
			return nil, errors.New("detail failed")
		case "cat":
			return []byte("md0 : active raid1 sda1[0]"), nil
		default:
			return nil, errors.New("unexpected")
		}
	})

	arrays, err := CollectArrays(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(arrays) != 0 {
		t.Fatalf("expected empty arrays, got %v", arrays)
	}
}

func TestCollectArraysSuccess(t *testing.T) {
	detail := `/dev/md0:
Raid Level : raid1
State : clean
Total Devices : 2
Active Devices : 2
Working Devices : 2
Failed Devices : 0
Spare Devices : 0

Number Major Minor RaidDevice State
0 8 1 0 active sync /dev/sda1`

	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		switch name {
		case "mdadm":
			if len(args) > 0 && args[0] == "--version" {
				return []byte("mdadm"), nil
			}
			return []byte(detail), nil
		case "cat":
			return []byte("md0 : active raid1 sda1[0]"), nil
		default:
			return nil, errors.New("unexpected")
		}
	})

	arrays, err := CollectArrays(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(arrays) != 1 || arrays[0].Device != "/dev/md0" {
		t.Fatalf("unexpected arrays: %v", arrays)
	}
}

func TestGetRebuildSpeed(t *testing.T) {
	mdstat := `md0 : active raid1 sda1[0] sdb1[1]
[>....................] recovery = 12.6% (37043392/293039104) finish=127.5min speed=33440K/sec
`
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return []byte(mdstat), nil
	})

	if speed := getRebuildSpeed("/dev/md0"); speed != "33440K/sec" {
		t.Fatalf("unexpected speed: %s", speed)
	}
}

func TestGetRebuildSpeedNoMatch(t *testing.T) {
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return []byte("md0 : active raid1 sda1[0]"), nil
	})

	if speed := getRebuildSpeed("/dev/md0"); speed != "" {
		t.Fatalf("expected empty speed, got %s", speed)
	}
}

func TestGetRebuildSpeedError(t *testing.T) {
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return nil, errors.New("read failed")
	})

	if speed := getRebuildSpeed("/dev/md0"); speed != "" {
		t.Fatalf("expected empty speed, got %s", speed)
	}
}

func TestParseDetailSetsRebuildSpeed(t *testing.T) {
	output := `/dev/md0:
Raid Level : raid1
State : clean
Rebuild Status : 12% complete

Number Major Minor RaidDevice State
0 8 1 0 active sync /dev/sda1`

	mdstat := `md0 : active raid1 sda1[0]
[>....................] recovery = 12.6% (37043392/293039104) finish=127.5min speed=1234K/sec
`
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return []byte(mdstat), nil
	})

	array, err := parseDetail("/dev/md0", output)
	if err != nil {
		t.Fatalf("parseDetail error: %v", err)
	}
	if array.RebuildSpeed != "1234K/sec" {
		t.Fatalf("expected rebuild speed, got %s", array.RebuildSpeed)
	}
}

func TestGetRebuildSpeedSectionExit(t *testing.T) {
	mdstat := `md0 : active raid1 sda1[0]
[>....................] recovery = 12.6% (37043392/293039104) finish=127.5min
md1 : active raid1 sdb1[0]
`
	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return []byte(mdstat), nil
	})

	if speed := getRebuildSpeed("/dev/md0"); speed != "" {
		t.Fatalf("expected empty speed, got %s", speed)
	}
}
internal/resources/converters_coverage_test.go (new file, 578 lines)
@@ -0,0 +1,578 @@
package resources

import (
	"strings"
	"testing"
	"time"

	"github.com/rcourtman/pulse-go-rewrite/internal/models"
)

func TestFromNodeTemperatureAndCluster(t *testing.T) {
	now := time.Now()
	node := models.Node{
		ID:              "node-1",
		Name:            "node-1",
		Instance:        "pve1",
		Status:          "online",
		CPU:             0.5,
		Memory:          models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
		Disk:            models.Disk{Total: 200, Used: 100, Free: 100, Usage: 50},
		Uptime:          10,
		LastSeen:        now,
		IsClusterMember: true,
		ClusterName:     "cluster-a",
		Temperature: &models.Temperature{
			Available:  true,
			HasCPU:     true,
			CPUPackage: 72.5,
		},
	}
	r := FromNode(node)
	if r.Temperature == nil || *r.Temperature != 72.5 {
		t.Fatalf("expected CPU package temperature")
	}
	if r.ClusterID != "pve-cluster/cluster-a" {
		t.Fatalf("expected cluster ID")
	}

	node2 := node
	node2.ID = "node-2"
	node2.ClusterName = ""
	node2.IsClusterMember = false
	node2.Temperature = &models.Temperature{
		Available: true,
		HasCPU:    true,
		Cores: []models.CoreTemp{
			{Core: 0, Temp: 60},
			{Core: 1, Temp: 80},
		},
	}
	r2 := FromNode(node2)
	if r2.Temperature == nil || *r2.Temperature != 70 {
		t.Fatalf("expected averaged core temperature")
	}
	if r2.ClusterID != "" {
		t.Fatalf("expected empty cluster ID when not a member")
	}
}

func TestFromHostAndDockerHost(t *testing.T) {
	now := time.Now()
	host := models.Host{
		ID:       "host-1",
		Hostname: "host-1",
		Status:   "degraded",
		Memory:   models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
		Disks: []models.Disk{
			{Total: 100, Used: 60, Free: 40, Usage: 60},
		},
		NetworkInterfaces: []models.HostNetworkInterface{
			{Addresses: []string{"10.0.0.1"}, RXBytes: 10, TXBytes: 20},
		},
		Sensors: models.HostSensorSummary{
			TemperatureCelsius: map[string]float64{"cpu": 55},
		},
		LastSeen: now,
	}
	hr := FromHost(host)
	if hr.Temperature == nil || *hr.Temperature != 55 {
		t.Fatalf("expected host temperature")
	}
	if hr.Network == nil || hr.Network.RXBytes != 10 || hr.Network.TXBytes != 20 {
		t.Fatalf("expected host network totals")
	}
	if hr.Disk == nil || hr.Disk.Current == 0 {
		t.Fatalf("expected host disk metrics")
	}
	if hr.Status != StatusDegraded {
		t.Fatalf("expected degraded host status")
	}

	dockerHost := models.DockerHost{
		ID:                "docker-1",
		AgentID:           "agent-1",
		Hostname:          "docker-1",
		DisplayName:       "docker-1",
		CustomDisplayName: "custom",
		Status:            "offline",
		Memory:            models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
		Disks: []models.Disk{
			{Total: 100, Used: 50, Free: 50, Usage: 50},
		},
		NetworkInterfaces: []models.HostNetworkInterface{
			{Addresses: []string{"10.0.0.2"}, RXBytes: 1, TXBytes: 2},
		},
		Swarm: &models.DockerSwarmInfo{
			ClusterID: "swarm-1",
		},
		LastSeen: now,
	}
	dr := FromDockerHost(dockerHost)
	if dr.DisplayName != "custom" {
		t.Fatalf("expected custom display name")
	}
	if dr.ClusterID != "docker-swarm/swarm-1" {
		t.Fatalf("expected docker swarm cluster ID")
	}
	if dr.Status != StatusOffline {
		t.Fatalf("expected docker host status offline")
	}
	if dr.Identity == nil || len(dr.Identity.IPs) != 1 || dr.Identity.IPs[0] != "10.0.0.2" {
		t.Fatalf("expected docker host identity IPs")
	}
}

func TestFromHostPlatformData(t *testing.T) {
	now := time.Now()
	host := models.Host{
		ID:       "host-2",
		Hostname: "host-2",
		Status:   "online",
		DiskIO: []models.DiskIO{
			{Device: "sda", ReadBytes: 100, WriteBytes: 200, ReadOps: 1, WriteOps: 2, ReadTime: 3, WriteTime: 4, IOTime: 5},
		},
		RAID: []models.HostRAIDArray{
			{
				Device:         "md0",
				Level:          "raid1",
				State:          "clean",
				TotalDevices:   2,
				ActiveDevices:  2,
				WorkingDevices: 2,
				FailedDevices:  0,
				SpareDevices:   0,
				Devices: []models.HostRAIDDevice{
					{Device: "sda1", State: "active", Slot: 0},
				},
				RebuildPercent: 0,
			},
		},
		LastSeen: now,
	}

	r := FromHost(host)
	var pd HostPlatformData
	if err := r.GetPlatformData(&pd); err != nil {
		t.Fatalf("failed to get platform data: %v", err)
	}
	if len(pd.DiskIO) != 1 || len(pd.RAID) != 1 {
		t.Fatalf("expected disk IO and RAID entries")
	}
}

func TestFromDockerContainerPodman(t *testing.T) {
	now := time.Now()
	container := models.DockerContainer{
		ID:            "container-1",
		Name:          "app",
		Image:         "app:latest",
		State:         "paused",
		Status:        "Paused",
		CPUPercent:    5,
		MemoryUsage:   50,
		MemoryLimit:   100,
		MemoryPercent: 50,
		UptimeSeconds: 10,
		Ports: []models.DockerContainerPort{
			{PrivatePort: 80, PublicPort: 8080, Protocol: "tcp", IP: "0.0.0.0"},
		},
		Networks: []models.DockerContainerNetworkLink{
			{Name: "bridge", IPv4: "172.17.0.2"},
		},
		Podman: &models.DockerPodmanContainer{
			PodName: "pod",
		},
		CreatedAt: now,
	}

	r := FromDockerContainer(container, "host-1", "host-1")
	if r.Status != StatusPaused {
		t.Fatalf("expected paused status")
	}
	if r.Memory == nil || r.Memory.Current != 50 {
		t.Fatalf("expected container memory metrics")
	}
	if r.ID != "host-1/container-1" {
		t.Fatalf("expected container resource ID to be host-1/container-1")
	}
}

func TestFromVMWithBackup(t *testing.T) {
	now := time.Now()
	vm := models.VM{
		ID:         "vm-1",
		VMID:       100,
		Name:       "vm-1",
		Node:       "node-1",
		Instance:   "pve1",
		Status:     "running",
		CPU:        0.25,
		LastBackup: now,
		LastSeen:   now,
	}

	r := FromVM(vm)
	var pd VMPlatformData
	if err := r.GetPlatformData(&pd); err != nil {
		t.Fatalf("failed to get platform data: %v", err)
	}
	if pd.LastBackup == nil || !pd.LastBackup.Equal(now) {
		t.Fatalf("expected last backup to be set")
	}
}

func TestFromContainerOCI(t *testing.T) {
	now := time.Now()
	ct := models.Container{
		ID:         "ct-1",
		VMID:       200,
		Name:       "ct-1",
		Node:       "node-1",
		Instance:   "pve1",
		Status:     "paused",
		Type:       "lxc",
		IsOCI:      true,
		CPU:        0.1,
		Memory:     models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
		Disk:       models.Disk{Total: 200, Used: 100, Free: 100, Usage: 50},
		LastBackup: now,
		LastSeen:   now,
	}

	r := FromContainer(ct)
	if r.Type != ResourceTypeOCIContainer {
		t.Fatalf("expected OCI container type")
	}
	var pd ContainerPlatformData
	if err := r.GetPlatformData(&pd); err != nil {
		t.Fatalf("failed to get platform data: %v", err)
	}
	if !pd.IsOCI || pd.LastBackup == nil {
		t.Fatalf("expected OCI platform data with backup")
	}
}

func TestKubernetesConversions(t *testing.T) {
	now := time.Now()
	cluster := models.KubernetesCluster{
		ID:                "cluster-1",
		AgentID:           "agent-1",
		CustomDisplayName: "custom",
		Status:            "online",
		LastSeen:          now,
		Nodes: []models.KubernetesNode{
			{Name: "node-1", Ready: true, Unschedulable: true},
		},
		Pods: []models.KubernetesPod{
			{Name: "pod-1", Namespace: "default", Phase: "Succeeded", NodeName: "node-1"},
		},
		Deployments: []models.KubernetesDeployment{
			{Name: "dep-1", Namespace: "default", DesiredReplicas: 1, AvailableReplicas: 1},
		},
	}

	cr := FromKubernetesCluster(cluster)
	if cr.Name != "custom" || cr.DisplayName != "custom" {
		t.Fatalf("expected custom display name to be used")
	}

	node := cluster.Nodes[0]
	nr := FromKubernetesNode(node, cluster)
	if nr.Status != StatusDegraded {
		t.Fatalf("expected unschedulable node to be degraded")
	}
	if !strings.Contains(nr.ID, "cluster-1/node/") {
		t.Fatalf("expected node ID to include cluster and node")
	}

	pod := cluster.Pods[0]
	pr := FromKubernetesPod(pod, cluster)
	if pr.Status != StatusStopped {
		t.Fatalf("expected succeeded pod to be stopped")
	}
	if !strings.Contains(pr.ParentID, "cluster-1/node/") {
		t.Fatalf("expected pod parent to be node")
	}

	dep := cluster.Deployments[0]
	dr := FromKubernetesDeployment(dep, cluster)
	if dr.Status != StatusRunning {
		t.Fatalf("expected deployment running")
	}

	emptyCluster := models.KubernetesCluster{
		ID:       "cluster-2",
		AgentID:  "agent-2",
		Status:   "offline",
		LastSeen: now,
	}
	er := FromKubernetesCluster(emptyCluster)
	if er.Name != "cluster-2" || er.DisplayName != "cluster-2" {
		t.Fatalf("expected cluster name fallback to ID")
	}
	if er.Status != StatusOffline {
		t.Fatalf("expected offline cluster status")
	}
}

func TestKubernetesNodeNotReady(t *testing.T) {
	now := time.Now()
	cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}
	node := models.KubernetesNode{
		UID:   "node-uid",
		Name:  "node-offline",
		Ready: false,
	}
	r := FromKubernetesNode(node, cluster)
	if r.Status != StatusOffline {
		t.Fatalf("expected offline node status")
	}
	if !strings.Contains(r.ID, "node-uid") {
		t.Fatalf("expected node UID in resource ID")
	}
}

func TestKubernetesPodParentCluster(t *testing.T) {
	now := time.Now()
	cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}
	pod := models.KubernetesPod{
		UID:       "pod-uid",
		Name:      "pod-1",
		Namespace: "default",
		Phase:     "Pending",
	}

	r := FromKubernetesPod(pod, cluster)
	if r.ParentID != cluster.ID {
		t.Fatalf("expected pod parent to be cluster ID")
	}
	if !strings.Contains(r.ID, "pod-uid") {
		t.Fatalf("expected pod UID in resource ID")
	}
}

func TestKubernetesPodContainers(t *testing.T) {
	now := time.Now()
	cluster := models.KubernetesCluster{ID: "cluster-3", AgentID: "agent-3", LastSeen: now}
	pod := models.KubernetesPod{
		Name:      "pod-2",
		Namespace: "default",
		NodeName:  "node-2",
		Phase:     "Running",
		Containers: []models.KubernetesPodContainer{
			{Name: "c1", Image: "busybox", Ready: true, RestartCount: 1, State: "running"},
		},
	}

	r := FromKubernetesPod(pod, cluster)
	if r.Status != StatusRunning {
		t.Fatalf("expected running pod status")
	}
	if !strings.Contains(r.ID, "cluster-3/pod/") {
		t.Fatalf("expected pod ID to include cluster")
	}
}

func TestKubernetesDeploymentStatuses(t *testing.T) {
	now := time.Now()
	cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}

	depStopped := FromKubernetesDeployment(models.KubernetesDeployment{
		Name:            "dep-stop",
		Namespace:       "default",
		DesiredReplicas: 0,
	}, cluster)
	if depStopped.Status != StatusStopped {
		t.Fatalf("expected stopped deployment")
	}

	depDegraded := FromKubernetesDeployment(models.KubernetesDeployment{
		Name:              "dep-degraded",
		Namespace:         "default",
		DesiredReplicas:   3,
		AvailableReplicas: 1,
	}, cluster)
	if depDegraded.Status != StatusDegraded {
		t.Fatalf("expected degraded deployment")
	}

	depUnknown := FromKubernetesDeployment(models.KubernetesDeployment{
		Name:              "dep-unknown",
		Namespace:         "default",
		DesiredReplicas:   2,
		AvailableReplicas: 0,
	}, cluster)
	if depUnknown.Status != StatusUnknown {
		t.Fatalf("expected unknown deployment")
	}
}

func TestFromPBSInstanceAndStorage(t *testing.T) {
	now := time.Now()
	pbs := models.PBSInstance{
		ID:               "pbs-1",
		Name:             "pbs",
		Host:             "pbs.local",
		Status:           "online",
		ConnectionHealth: "unhealthy",
		CPU:              20,
		Memory:           50,
		MemoryTotal:      100,
		MemoryUsed:       50,
		Uptime:           10,
		LastSeen:         now,
	}
	pr := FromPBSInstance(pbs)
	if pr.Status != StatusDegraded {
		t.Fatalf("expected degraded status for unhealthy pbs")
	}

	pbs2 := pbs
	pbs2.ID = "pbs-2"
	pbs2.ConnectionHealth = "healthy"
	pbs2.Status = "offline"
	pr2 := FromPBSInstance(pbs2)
	if pr2.Status != StatusOffline {
		t.Fatalf("expected offline status for pbs")
	}

	storageOnline := FromStorage(models.Storage{
		ID:       "storage-1",
		Name:     "local",
		Instance: "pve1",
		Node:     "node1",
		Total:    100,
		Used:     50,
		Free:     50,
		Usage:    50,
		Active:   true,
		Enabled:  true,
	})
	if storageOnline.Status != StatusOnline {
		t.Fatalf("expected online storage")
	}

	storageStopped := FromStorage(models.Storage{
		ID:       "storage-2",
		Name:     "local",
		Instance: "pve1",
		Node:     "node1",
		Active:   true,
		Enabled:  false,
	})
	if storageStopped.Status != StatusStopped {
		t.Fatalf("expected stopped storage")
	}

	storageOffline := FromStorage(models.Storage{
		ID:       "storage-3",
		Name:     "local",
		Instance: "pve1",
		Node:     "node1",
		Active:   false,
		Enabled:  true,
	})
	if storageOffline.Status != StatusOffline {
		t.Fatalf("expected offline storage")
	}
}

func TestStatusMappings(t *testing.T) {
	if mapGuestStatus("running") != StatusRunning {
		t.Fatalf("expected running guest status")
	}
	if mapGuestStatus("stopped") != StatusStopped {
		t.Fatalf("expected stopped guest status")
	}
	if mapGuestStatus("paused") != StatusPaused {
		t.Fatalf("expected paused guest status")
	}
	if mapGuestStatus("unknown") != StatusUnknown {
		t.Fatalf("expected unknown guest status")
	}

	if mapHostStatus("online") != StatusOnline {
		t.Fatalf("expected online host status")
	}
	if mapHostStatus("offline") != StatusOffline {
		t.Fatalf("expected offline host status")
	}
	if mapHostStatus("degraded") != StatusDegraded {
		t.Fatalf("expected degraded host status")
	}
	if mapHostStatus("unknown") != StatusUnknown {
		t.Fatalf("expected unknown host status")
	}

	if mapDockerHostStatus("online") != StatusOnline {
		t.Fatalf("expected online docker host status")
	}
	if mapDockerHostStatus("offline") != StatusOffline {
		t.Fatalf("expected offline docker host status")
	}
	if mapDockerHostStatus("unknown") != StatusUnknown {
		t.Fatalf("expected unknown docker host status")
	}

	if mapDockerContainerStatus("running") != StatusRunning {
|
||||
t.Fatalf("expected running container status")
|
||||
}
|
||||
if mapDockerContainerStatus("exited") != StatusStopped {
|
||||
t.Fatalf("expected exited container status")
|
||||
}
|
||||
if mapDockerContainerStatus("dead") != StatusStopped {
|
||||
t.Fatalf("expected dead container status")
|
||||
}
|
||||
if mapDockerContainerStatus("paused") != StatusPaused {
|
||||
t.Fatalf("expected paused container status")
|
||||
}
|
||||
if mapDockerContainerStatus("restarting") != StatusUnknown {
|
||||
t.Fatalf("expected restarting container status unknown")
|
||||
}
|
||||
if mapDockerContainerStatus("created") != StatusUnknown {
|
||||
t.Fatalf("expected created container status unknown")
|
||||
}
|
||||
if mapDockerContainerStatus("other") != StatusUnknown {
|
||||
t.Fatalf("expected unknown container status")
|
||||
}
|
||||
|
||||
if mapKubernetesClusterStatus(" online ") != StatusOnline {
|
||||
t.Fatalf("expected online k8s cluster status")
|
||||
}
|
||||
if mapKubernetesClusterStatus("offline") != StatusOffline {
|
||||
t.Fatalf("expected offline k8s cluster status")
|
||||
}
|
||||
if mapKubernetesClusterStatus("unknown") != StatusUnknown {
|
||||
t.Fatalf("expected unknown k8s cluster status")
|
||||
}
|
||||
|
||||
if mapKubernetesPodStatus("running") != StatusRunning {
|
||||
t.Fatalf("expected running pod status")
|
||||
}
|
||||
if mapKubernetesPodStatus("succeeded") != StatusStopped {
|
||||
t.Fatalf("expected succeeded pod status")
|
||||
}
|
||||
if mapKubernetesPodStatus("failed") != StatusStopped {
|
||||
t.Fatalf("expected failed pod status")
|
||||
}
|
||||
if mapKubernetesPodStatus("pending") != StatusUnknown {
|
||||
t.Fatalf("expected pending pod status")
|
||||
}
|
||||
if mapKubernetesPodStatus("unknown") != StatusUnknown {
|
||||
t.Fatalf("expected unknown pod status")
|
||||
}
|
||||
|
||||
if mapPBSStatus("online", "healthy") != StatusOnline {
|
||||
t.Fatalf("expected online PBS status")
|
||||
}
|
||||
if mapPBSStatus("offline", "healthy") != StatusOffline {
|
||||
t.Fatalf("expected offline PBS status")
|
||||
}
|
||||
if mapPBSStatus("other", "healthy") != StatusUnknown {
|
||||
t.Fatalf("expected unknown PBS status")
|
||||
}
|
||||
if mapPBSStatus("online", "unhealthy") != StatusDegraded {
|
||||
t.Fatalf("expected degraded PBS status")
|
||||
}
|
||||
}
|
||||
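A table-driven variant of the guest-status checks above would read as follows (a sketch; it assumes the status constants share one type, called Status here, which this diff does not confirm):

	func TestGuestStatusTable(t *testing.T) {
		cases := map[string]Status{
			"running": StatusRunning,
			"stopped": StatusStopped,
			"paused":  StatusPaused,
			"unknown": StatusUnknown,
		}
		for in, want := range cases {
			if got := mapGuestStatus(in); got != want {
				t.Fatalf("mapGuestStatus(%q) = %v, want %v", in, got, want)
			}
		}
	}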
834
internal/resources/coverage_test.go
Normal file
@@ -0,0 +1,834 @@
package resources

import (
	"strings"
	"testing"
	"time"

	"github.com/rcourtman/pulse-go-rewrite/internal/models"
)

func TestResourcePlatformDataAndMetrics(t *testing.T) {
	var out struct {
		Value string `json:"value"`
	}

	r := Resource{Name: "name"}
	if err := r.GetPlatformData(&out); err != nil {
		t.Fatalf("expected nil error for empty platform data, got %v", err)
	}

	if err := r.SetPlatformData(make(chan int)); err == nil {
		t.Fatal("expected error for unserializable platform data")
	}

	if err := r.SetPlatformData(struct {
		Value string `json:"value"`
	}{Value: "ok"}); err != nil {
		t.Fatalf("unexpected error setting platform data: %v", err)
	}
	if err := r.GetPlatformData(&out); err != nil {
		t.Fatalf("unexpected error getting platform data: %v", err)
	}
	if out.Value != "ok" {
		t.Fatalf("expected platform data value to be ok, got %q", out.Value)
	}

	r.PlatformData = []byte("{")
	if err := r.GetPlatformData(&out); err == nil {
		t.Fatal("expected error for invalid platform data")
	}

	r.DisplayName = "display"
	if r.EffectiveDisplayName() != "display" {
		t.Fatalf("expected display name to be used")
	}
	r.DisplayName = ""
	if r.EffectiveDisplayName() != "name" {
		t.Fatalf("expected name fallback")
	}

	if r.CPUPercent() != 0 {
		t.Fatalf("expected CPUPercent 0 for nil CPU")
	}
	r.CPU = &MetricValue{Current: 12.5}
	if r.CPUPercent() != 12.5 {
		t.Fatalf("expected CPUPercent 12.5")
	}

	if r.MemoryPercent() != 0 {
		t.Fatalf("expected MemoryPercent 0 for nil memory")
	}
	total := int64(100)
	used := int64(25)
	r.Memory = &MetricValue{Total: &total, Used: &used}
	if r.MemoryPercent() != 25 {
		t.Fatalf("expected MemoryPercent 25")
	}
	r.Memory = &MetricValue{Current: 40}
	if r.MemoryPercent() != 40 {
		t.Fatalf("expected MemoryPercent 40")
	}

	if r.DiskPercent() != 0 {
		t.Fatalf("expected DiskPercent 0 for nil disk")
	}
	totalDisk := int64(200)
	usedDisk := int64(50)
	r.Disk = &MetricValue{Total: &totalDisk, Used: &usedDisk}
	if r.DiskPercent() != 25 {
		t.Fatalf("expected DiskPercent 25")
	}
	r.Disk = &MetricValue{Current: 33}
	if r.DiskPercent() != 33 {
		t.Fatalf("expected DiskPercent 33")
	}

	r.Type = ResourceTypeNode
	if !r.IsInfrastructure() {
		t.Fatalf("expected node to be infrastructure")
	}
	if r.IsWorkload() {
		t.Fatalf("expected node to not be workload")
	}
	r.Type = ResourceTypeVM
	if r.IsInfrastructure() {
		t.Fatalf("expected vm to not be infrastructure")
	}
	if !r.IsWorkload() {
		t.Fatalf("expected vm to be workload")
	}
}

func TestStorePreferredAndSuppressed(t *testing.T) {
	store := NewStore()
	now := time.Now()

	existing := Resource{
		ID:         "agent-1",
		Type:       ResourceTypeHost,
		SourceType: SourceAgent,
		LastSeen:   now,
		Identity: &ResourceIdentity{
			Hostname: "host1",
		},
	}
	store.Upsert(existing)

	incoming := Resource{
		ID:         "api-1",
		Type:       ResourceTypeHost,
		SourceType: SourceAPI,
		LastSeen:   now.Add(1 * time.Minute),
		Identity: &ResourceIdentity{
			Hostname: "host1",
		},
	}
	preferred := store.Upsert(incoming)
	if preferred != existing.ID {
		t.Fatalf("expected preferred ID to be %s, got %s", existing.ID, preferred)
	}
	if !store.IsSuppressed(incoming.ID) {
		t.Fatalf("expected incoming to be suppressed")
	}
	if store.GetPreferredID(incoming.ID) != existing.ID {
		t.Fatalf("expected preferred ID to map to %s", existing.ID)
	}
	if store.GetPreferredID(existing.ID) != existing.ID {
		t.Fatalf("expected preferred ID to return itself")
	}

	got, ok := store.Get(incoming.ID)
	if !ok || got.ID != existing.ID {
		t.Fatalf("expected Get to return preferred resource")
	}

	if store.GetPreferredResourceFor(incoming.ID) == nil {
		t.Fatalf("expected preferred resource for suppressed ID")
	}
	if store.GetPreferredResourceFor(existing.ID) == nil {
		t.Fatalf("expected preferred resource for existing ID")
	}
	if store.GetPreferredResourceFor("missing") != nil {
		t.Fatalf("expected nil for missing resource")
	}

	if !store.IsSamePhysicalMachine(existing.ID, incoming.ID) {
		t.Fatalf("expected IDs to be same physical machine")
	}
	if !store.IsSamePhysicalMachine(incoming.ID, existing.ID) {
		t.Fatalf("expected merged ID to match preferred")
	}
	if store.IsSamePhysicalMachine(existing.ID, "other") {
		t.Fatalf("expected different IDs to not match")
	}
	if !store.IsSamePhysicalMachine(existing.ID, existing.ID) {
		t.Fatalf("expected same ID to match")
	}

	store.Remove(existing.ID)
	if store.IsSuppressed(incoming.ID) {
		t.Fatalf("expected suppression to be cleared after removal")
	}
	if _, ok := store.Get(incoming.ID); ok {
		t.Fatalf("expected Get to fail for removed preferred resource")
	}
}

func TestStoreStatsAndHelpers(t *testing.T) {
	store := NewStore()
	now := time.Now()

	store.Upsert(Resource{
		ID:         "host-1",
		Type:       ResourceTypeHost,
		Status:     StatusOffline,
		SourceType: SourceAgent,
		LastSeen:   now,
		Alerts:     []ResourceAlert{{ID: "a1"}},
		Identity: &ResourceIdentity{
			Hostname: "host-1",
		},
	})
	store.Upsert(Resource{
		ID:         "host-2",
		Type:       ResourceTypeHost,
		Status:     StatusOnline,
		SourceType: SourceAPI,
		LastSeen:   now.Add(time.Minute),
		Identity: &ResourceIdentity{
			Hostname: "host-1",
		},
	})

	stats := store.GetStats()
	if stats.SuppressedResources != 1 {
		t.Fatalf("expected 1 suppressed resource")
	}
	if stats.WithAlerts != 1 {
		t.Fatalf("expected 1 resource with alerts")
	}

	if store.sourceScore(SourceType("other")) != 0 {
		t.Fatalf("expected default source score")
	}

	a := &Resource{ID: "a", SourceType: SourceAPI, LastSeen: now}
	b := &Resource{ID: "b", SourceType: SourceAgent, LastSeen: now.Add(time.Second)}
	if store.preferredResource(a, b) != b {
		t.Fatalf("expected agent resource to be preferred")
	}
	c := &Resource{ID: "c", SourceType: SourceAPI, LastSeen: now.Add(time.Second)}
	if store.preferredResource(a, c) != c {
		t.Fatalf("expected newer resource to be preferred")
	}
	d := &Resource{ID: "d", SourceType: SourceAgent, LastSeen: now}
	e := &Resource{ID: "e", SourceType: SourceAPI, LastSeen: now}
	if store.preferredResource(d, e) != d {
		t.Fatalf("expected higher score resource to be preferred")
	}
	f := &Resource{ID: "f", SourceType: SourceAPI, LastSeen: now.Add(2 * time.Second)}
	g := &Resource{ID: "g", SourceType: SourceAPI, LastSeen: now.Add(1 * time.Second)}
	if store.preferredResource(f, g) != f {
		t.Fatalf("expected newer resource to be preferred when scores equal")
	}

	if store.findDuplicate(&Resource{}) != "" {
		t.Fatalf("expected no duplicate for nil identity")
	}
	if store.findDuplicate(&Resource{Type: ResourceTypeVM, Identity: &ResourceIdentity{Hostname: "host-1"}}) != "" {
		t.Fatalf("expected no duplicate for workload type")
	}
}
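Taken together, the preferredResource assertions pin down a two-step rule: the higher sourceScore wins (agent-backed over API-only), and equal scores fall back to the newer LastSeen. Expressed directly as a sketch inferred from the tests, not the package's actual code:

	func preferredSketch(s *Store, a, b *Resource) *Resource {
		as, bs := s.sourceScore(a.SourceType), s.sourceScore(b.SourceType)
		if as != bs {
			// Higher-scoring source wins outright.
			if as > bs {
				return a
			}
			return b
		}
		// Equal scores: the more recently seen resource wins.
		if b.LastSeen.After(a.LastSeen) {
			return b
		}
		return a
	}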

func TestStorePreferredSourceAndPolling(t *testing.T) {
	store := NewStore()
	now := time.Now()

	store.Upsert(Resource{
		ID:         "api-1",
		Type:       ResourceTypeHost,
		SourceType: SourceAPI,
		LastSeen:   now,
		Identity: &ResourceIdentity{
			Hostname: "host1",
		},
	})
	if store.HasPreferredSourceForHostname("host1") {
		t.Fatalf("expected no preferred source for API-only hostname")
	}

	store.Upsert(Resource{
		ID:         "hybrid-1",
		Type:       ResourceTypeHost,
		SourceType: SourceHybrid,
		LastSeen:   now.Add(time.Second),
		Identity: &ResourceIdentity{
			Hostname: "host1",
		},
	})
	if !store.HasPreferredSourceForHostname("HOST1") {
		t.Fatalf("expected preferred source for hostname")
	}
	if store.HasPreferredSourceForHostname("") {
		t.Fatalf("expected empty hostname to be false")
	}
	if !store.ShouldSkipAPIPolling("host1") {
		t.Fatalf("expected skip polling for preferred source")
	}
	if store.HasPreferredSourceForHostname("missing") {
		t.Fatalf("expected missing hostname to be false")
	}
}

func TestStoreAgentHostnamesAndRecommendations(t *testing.T) {
	store := NewStore()
	now := time.Now()

	store.Upsert(Resource{
		ID:         "agent-1",
		Type:       ResourceTypeHost,
		SourceType: SourceAgent,
		LastSeen:   now,
		Identity: &ResourceIdentity{
			Hostname: "host1",
		},
	})
	store.Upsert(Resource{
		ID:         "hybrid-1",
		Type:       ResourceTypeHost,
		SourceType: SourceHybrid,
		LastSeen:   now,
		Identity: &ResourceIdentity{
			Hostname: "HOST2",
		},
	})
	store.Upsert(Resource{
		ID:         "api-1",
		Type:       ResourceTypeHost,
		SourceType: SourceAPI,
		LastSeen:   now,
		Identity: &ResourceIdentity{
			Hostname: "host3",
		},
	})
	store.Upsert(Resource{
		ID:         "no-identity",
		Type:       ResourceTypeHost,
		SourceType: SourceAgent,
		LastSeen:   now,
	})
	store.Upsert(Resource{
		ID:         "node-agent",
		Type:       ResourceTypeNode,
		SourceType: SourceAgent,
		LastSeen:   now,
		Identity: &ResourceIdentity{
			Hostname: "host1",
		},
	})

	hostnames := store.GetAgentMonitoredHostnames()
	seen := make(map[string]bool)
	for _, h := range hostnames {
		seen[strings.ToLower(h)] = true
	}
	if !seen["host1"] || !seen["host2"] {
		t.Fatalf("expected host1 and host2 to be monitored, got %v", hostnames)
	}
	if len(seen) != 2 {
		t.Fatalf("expected two unique hostnames")
	}

	recs := store.GetPollingRecommendations()
	if recs["host1"] != 0 {
		t.Fatalf("expected host1 recommendation 0, got %v", recs["host1"])
	}
	if recs["host2"] != 0.5 {
		t.Fatalf("expected host2 recommendation 0.5, got %v", recs["host2"])
	}
	if _, ok := recs["host3"]; ok {
		t.Fatalf("did not expect API-only host to have recommendation")
	}
}

func TestStoreFindContainerHost(t *testing.T) {
	store := NewStore()
	now := time.Now()

	host := Resource{
		ID:         "docker-host-1",
		Type:       ResourceTypeDockerHost,
		Name:       "docker-host",
		SourceType: SourceAgent,
		LastSeen:   now,
		Identity: &ResourceIdentity{
			Hostname: "docker1",
		},
	}
	store.Upsert(host)
	store.Upsert(Resource{
		ID:         "docker-host-1/container-1",
		Type:       ResourceTypeDockerContainer,
		Name:       "web-app",
		ParentID:   "docker-host-1",
		SourceType: SourceAgent,
		LastSeen:   now,
	})

	if store.FindContainerHost("") != "" {
		t.Fatalf("expected empty search to return empty")
	}
	if store.FindContainerHost("missing") != "" {
		t.Fatalf("expected missing container to return empty")
	}
	if store.FindContainerHost("web-app") != "docker1" {
		t.Fatalf("expected host name from identity")
	}
	if store.FindContainerHost("CONTAINER-1") != "docker1" {
		t.Fatalf("expected match by ID")
	}
	if store.FindContainerHost("web") != "docker1" {
		t.Fatalf("expected match by substring")
	}

	store2 := NewStore()
	store2.Upsert(Resource{
		ID:         "container-2",
		Type:       ResourceTypeDockerContainer,
		Name:       "db",
		ParentID:   "missing-host",
		SourceType: SourceAgent,
		LastSeen:   now,
	})
	if store2.FindContainerHost("db") != "" {
		t.Fatalf("expected missing parent to return empty")
	}

	store3 := NewStore()
	store3.Upsert(Resource{
		ID:         "host-no-identity",
		Type:       ResourceTypeDockerHost,
		Name:       "host-name",
		SourceType: SourceAgent,
		LastSeen:   now,
	})
	store3.Upsert(Resource{
		ID:         "host-no-identity/container-3",
		Type:       ResourceTypeDockerContainer,
		Name:       "cache",
		ParentID:   "host-no-identity",
		SourceType: SourceAgent,
		LastSeen:   now,
	})
	if store3.FindContainerHost("cache") != "host-name" {
		t.Fatalf("expected fallback to host name")
	}
}

func TestResourceQueryFiltersAndSorting(t *testing.T) {
	store := NewStore()
	base := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)

	store.Upsert(Resource{
		ID:         "r1",
		Type:       ResourceTypeVM,
		Name:       "b",
		Status:     StatusRunning,
		ClusterID:  "c1",
		CPU:        &MetricValue{Current: 10},
		Memory:     &MetricValue{Current: 60},
		Disk:       &MetricValue{Current: 20},
		LastSeen:   base.Add(1 * time.Hour),
		Alerts:     []ResourceAlert{{ID: "a1"}},
		SourceType: SourceAPI,
	})
	store.Upsert(Resource{
		ID:         "r2",
		Type:       ResourceTypeNode,
		Name:       "a",
		Status:     StatusOnline,
		ClusterID:  "c1",
		CPU:        &MetricValue{Current: 50},
		Memory:     &MetricValue{Current: 10},
		Disk:       &MetricValue{Current: 90},
		LastSeen:   base.Add(2 * time.Hour),
		SourceType: SourceAPI,
	})
	store.Upsert(Resource{
		ID:         "r3",
		Type:       ResourceTypeVM,
		Name:       "c",
		Status:     StatusOffline,
		ClusterID:  "c2",
		CPU:        &MetricValue{Current: 5},
		Memory:     &MetricValue{Current: 30},
		Disk:       &MetricValue{Current: 10},
		LastSeen:   base.Add(3 * time.Hour),
		SourceType: SourceAPI,
	})

	clustered := store.Query().InCluster("c1").Execute()
	if len(clustered) != 2 {
		t.Fatalf("expected 2 clustered resources, got %d", len(clustered))
	}

	withAlerts := store.Query().WithAlerts().Execute()
	if len(withAlerts) != 1 || withAlerts[0].ID != "r1" {
		t.Fatalf("expected only r1 with alerts")
	}

	sorted := store.Query().SortBy("name", false).Execute()
	if len(sorted) < 2 || sorted[0].Name != "a" {
		t.Fatalf("expected sorted results by name")
	}

	limited := store.Query().SortBy("cpu", true).Offset(1).Limit(1).Execute()
	if len(limited) != 1 {
		t.Fatalf("expected limited results")
	}

	empty := store.Query().Offset(10).Execute()
	if len(empty) != 0 {
		t.Fatalf("expected empty results for large offset")
	}
}

func TestSortResourcesFields(t *testing.T) {
	base := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	resources := []Resource{
		{
			ID:       "r1",
			Type:     ResourceTypeVM,
			Name:     "b",
			Status:   StatusRunning,
			CPU:      &MetricValue{Current: 20},
			Memory:   &MetricValue{Current: 40},
			Disk:     &MetricValue{Current: 30},
			LastSeen: base.Add(1 * time.Hour),
		},
		{
			ID:       "r2",
			Type:     ResourceTypeNode,
			Name:     "a",
			Status:   StatusOffline,
			CPU:      &MetricValue{Current: 50},
			Memory:   &MetricValue{Current: 10},
			Disk:     &MetricValue{Current: 90},
			LastSeen: base.Add(2 * time.Hour),
		},
		{
			ID:       "r3",
			Type:     ResourceTypeContainer,
			Name:     "c",
			Status:   StatusDegraded,
			CPU:      &MetricValue{Current: 5},
			Memory:   &MetricValue{Current: 80},
			Disk:     &MetricValue{Current: 10},
			LastSeen: base.Add(3 * time.Hour),
		},
	}

	single := []Resource{{ID: "only"}}
	sortResources(single, "name", false)

	cases := []struct {
		field string
		desc  bool
		want  string
	}{
		{"name", false, "r2"},
		{"name", true, "r3"},
		{"type", false, "r3"},
		{"type", true, "r1"},
		{"status", false, "r3"},
		{"status", true, "r1"},
		{"cpu", true, "r2"},
		{"cpu", false, "r3"},
		{"memory", true, "r3"},
		{"memory", false, "r2"},
		{"disk", true, "r2"},
		{"disk", false, "r3"},
		{"last_seen", true, "r3"},
		{"last_seen", false, "r1"},
		{"lastseen", true, "r3"},
		{"mem", true, "r3"},
	}

	for _, tc := range cases {
		sorted := append([]Resource(nil), resources...)
		sortResources(sorted, tc.field, tc.desc)
		if sorted[0].ID != tc.want {
			t.Fatalf("sort %s desc=%v expected %s, got %s", tc.field, tc.desc, tc.want, sorted[0].ID)
		}
	}
}

func TestGetTopByMemoryAndDisk(t *testing.T) {
	store := NewStore()
	now := time.Now()

	store.Upsert(Resource{
		ID:       "vm1",
		Type:     ResourceTypeVM,
		Memory:   &MetricValue{Current: 80},
		Disk:     &MetricValue{Current: 30},
		LastSeen: now,
	})
	store.Upsert(Resource{
		ID:       "vm2",
		Type:     ResourceTypeVM,
		Memory:   &MetricValue{Current: 20},
		Disk:     &MetricValue{Current: 90},
		LastSeen: now,
	})
	store.Upsert(Resource{
		ID:       "node1",
		Type:     ResourceTypeNode,
		Memory:   &MetricValue{Current: 60},
		Disk:     &MetricValue{Current: 10},
		LastSeen: now,
	})
	store.Upsert(Resource{
		ID:       "skip-memory",
		Type:     ResourceTypeVM,
		LastSeen: now,
	})
	store.Upsert(Resource{
		ID:       "skip-disk",
		Type:     ResourceTypeVM,
		Disk:     &MetricValue{Current: 0},
		LastSeen: now,
	})

	topMem := store.GetTopByMemory(1, nil)
	if len(topMem) != 1 || topMem[0].ID != "vm1" {
		t.Fatalf("expected vm1 to be top memory")
	}
	topMemVMs := store.GetTopByMemory(10, []ResourceType{ResourceTypeVM})
	if len(topMemVMs) != 2 {
		t.Fatalf("expected 2 VM memory results, got %d", len(topMemVMs))
	}

	topDisk := store.GetTopByDisk(1, nil)
	if len(topDisk) != 1 || topDisk[0].ID != "vm2" {
		t.Fatalf("expected vm2 to be top disk")
	}
	topDiskNodes := store.GetTopByDisk(10, []ResourceType{ResourceTypeNode})
	if len(topDiskNodes) != 1 || topDiskNodes[0].ID != "node1" {
		t.Fatalf("expected node1 to be top disk node")
	}
}

func TestGetTopByCPU_SkipsZero(t *testing.T) {
	store := NewStore()
	now := time.Now()

	store.Upsert(Resource{
		ID:       "skip-cpu-1",
		Type:     ResourceTypeVM,
		LastSeen: now,
	})
	store.Upsert(Resource{
		ID:       "skip-cpu-2",
		Type:     ResourceTypeVM,
		CPU:      &MetricValue{Current: 0},
		LastSeen: now,
	})
	store.Upsert(Resource{
		ID:       "cpu-1",
		Type:     ResourceTypeVM,
		CPU:      &MetricValue{Current: 10},
		LastSeen: now,
	})

	top := store.GetTopByCPU(10, nil)
	if len(top) != 1 || top[0].ID != "cpu-1" {
		t.Fatalf("expected only cpu-1 to be returned")
	}
}

func TestGetRelatedWithChildren(t *testing.T) {
	store := NewStore()
	now := time.Now()

	store.Upsert(Resource{
		ID:       "parent",
		Type:     ResourceTypeNode,
		Name:     "parent",
		LastSeen: now,
	})
	store.Upsert(Resource{
		ID:       "child-1",
		Type:     ResourceTypeVM,
		Name:     "child-1",
		ParentID: "parent",
		LastSeen: now,
	})

	related := store.GetRelated("parent")
	if children, ok := related["children"]; !ok || len(children) != 1 {
		t.Fatalf("expected children for parent")
	}

	if len(store.GetRelated("missing")) != 0 {
		t.Fatalf("expected no related resources for missing ID")
	}
}

func TestResourceSummaryWithDegradedAndAlerts(t *testing.T) {
	store := NewStore()
	now := time.Now()

	store.Upsert(Resource{
		ID:       "healthy",
		Type:     ResourceTypeNode,
		Status:   StatusOnline,
		LastSeen: now,
		CPU:      &MetricValue{Current: 20},
		Memory:   &MetricValue{Current: 50},
	})
	store.Upsert(Resource{
		ID:       "degraded",
		Type:     ResourceTypeVM,
		Status:   StatusDegraded,
		LastSeen: now,
		Alerts:   []ResourceAlert{{ID: "a1"}},
	})
	store.Upsert(Resource{
		ID:       "unknown",
		Type:     ResourceTypeVM,
		Status:   StatusUnknown,
		LastSeen: now,
	})

	summary := store.GetResourceSummary()
	if summary.Degraded != 1 {
		t.Fatalf("expected degraded count 1")
	}
	// No resource above is StatusOffline, so this assertion implies the
	// StatusUnknown resource is counted under Offline in the summary.
	if summary.Offline != 1 {
		t.Fatalf("expected offline count 1")
	}
	if summary.WithAlerts != 1 {
		t.Fatalf("expected alerts count 1")
	}
}

func TestPopulateFromSnapshotFull(t *testing.T) {
	store := NewStore()
	store.Upsert(Resource{ID: "old-resource", Type: ResourceTypeNode, LastSeen: time.Now()})

	now := time.Now()
	cluster := models.KubernetesCluster{
		ID:       "cluster-1",
		AgentID:  "agent-1",
		Status:   "online",
		LastSeen: now,
		Nodes: []models.KubernetesNode{
			{Name: "node-1", Ready: true},
		},
		Pods: []models.KubernetesPod{
			{Name: "pod-1", Namespace: "default", Phase: "Running", NodeName: "node-1"},
		},
		Deployments: []models.KubernetesDeployment{
			{Name: "dep-1", Namespace: "default", DesiredReplicas: 1, AvailableReplicas: 1},
		},
	}
	dockerHost := models.DockerHost{
		ID:       "docker-1",
		AgentID:  "agent-docker",
		Hostname: "docker-host",
		Status:   "online",
		Memory: models.Memory{
			Total: 100,
			Used:  50,
			Free:  50,
			Usage: 50,
		},
		Disks: []models.Disk{
			{Total: 100, Used: 60, Free: 40, Usage: 60},
		},
		NetworkInterfaces: []models.HostNetworkInterface{
			{Addresses: []string{"10.0.0.1"}, RXBytes: 1, TXBytes: 2},
		},
		Containers: []models.DockerContainer{
			{
				ID:            "container-1",
				Name:          "web",
				State:         "running",
				Status:        "Up",
				CPUPercent:    5,
				MemoryUsage:   50,
				MemoryLimit:   100,
				MemoryPercent: 50,
			},
		},
		LastSeen: now,
	}
	snapshot := models.StateSnapshot{
		DockerHosts: []models.DockerHost{dockerHost},
		KubernetesClusters: []models.KubernetesCluster{
			cluster,
		},
		PBSInstances: []models.PBSInstance{
			{
				ID:               "pbs-1",
				Name:             "pbs",
				Host:             "pbs.local",
				Status:           "online",
				ConnectionHealth: "healthy",
				CPU:              20,
				Memory:           50,
				MemoryTotal:      100,
				MemoryUsed:       50,
				Uptime:           10,
				LastSeen:         now,
			},
		},
		Storage: []models.Storage{
			{
				ID:       "storage-1",
				Name:     "local",
				Instance: "pve1",
				Node:     "node1",
				Total:    100,
				Used:     50,
				Free:     50,
				Usage:    50,
				Active:   true,
				Enabled:  true,
			},
		},
	}

	store.PopulateFromSnapshot(snapshot)

	if _, ok := store.Get("old-resource"); ok {
		t.Fatalf("expected old resource to be removed")
	}

	if len(store.Query().OfType(ResourceTypeDockerHost).Execute()) != 1 {
		t.Fatalf("expected 1 docker host")
	}
	if len(store.Query().OfType(ResourceTypeDockerContainer).Execute()) != 1 {
		t.Fatalf("expected 1 docker container")
	}
	if len(store.Query().OfType(ResourceTypeK8sCluster).Execute()) != 1 {
		t.Fatalf("expected 1 k8s cluster")
	}
	if len(store.Query().OfType(ResourceTypeK8sNode).Execute()) != 1 {
		t.Fatalf("expected 1 k8s node")
	}
	if len(store.Query().OfType(ResourceTypePod).Execute()) != 1 {
		t.Fatalf("expected 1 k8s pod")
	}
	if len(store.Query().OfType(ResourceTypeK8sDeployment).Execute()) != 1 {
		t.Fatalf("expected 1 k8s deployment")
	}
	if len(store.Query().OfType(ResourceTypePBS).Execute()) != 1 {
		t.Fatalf("expected 1 PBS instance")
	}
	if len(store.Query().OfType(ResourceTypeStorage).Execute()) != 1 {
		t.Fatalf("expected 1 storage resource")
	}
}
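The Query builder used throughout these tests accumulates filters and runs them in a single Execute call; combining the methods exercised above looks like this (illustrative usage only, not part of the diff):

	topVMs := store.Query().
		OfType(ResourceTypeVM).
		InCluster("c1").
		SortBy("cpu", true).
		Limit(5).
		Execute()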
@@ -13,17 +13,25 @@ import (
	"github.com/rs/zerolog/log"
)

var (
	execLookPath     = exec.LookPath
	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return exec.CommandContext(ctx, name, args...).Output()
	}
	timeNow = time.Now
)
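These package-level variables are test seams: the coverage tests later in this diff swap them for fakes and restore the originals in a cleanup. The pattern, as a minimal sketch:

	func TestWithFakeClock(t *testing.T) {
		// Save the real clock and restore it when the test ends.
		origNow := timeNow
		t.Cleanup(func() { timeNow = origNow })
		timeNow = func() time.Time { return time.Unix(0, 0) }
		// ... exercise code whose LastUpdated stamps come from the fake clock ...
	}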

// DiskSMART represents S.M.A.R.T. data for a single disk.
type DiskSMART struct {
	Device      string    `json:"device"`            // Device path (e.g., /dev/sda)
	Model       string    `json:"model,omitempty"`   // Disk model
	Serial      string    `json:"serial,omitempty"`  // Serial number
	WWN         string    `json:"wwn,omitempty"`     // World Wide Name
	Type        string    `json:"type,omitempty"`    // Transport type: sata, sas, nvme
	Temperature int       `json:"temperature"`       // Temperature in Celsius
	Health      string    `json:"health,omitempty"`  // PASSED, FAILED, UNKNOWN
	Standby     bool      `json:"standby,omitempty"` // True if disk was in standby
	LastUpdated time.Time `json:"lastUpdated"`       // When this reading was taken
}
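A caller-side sketch of consuming these readings (assumed usage, not code from this commit; the fmt import and log setup are assumptions):

	readings, err := CollectLocal(ctx)
	if err != nil {
		log.Error().Err(err).Msg("SMART collection failed")
	}
	for _, d := range readings {
		// Each reading carries device name, health verdict, and temperature.
		fmt.Printf("%s: health=%s temp=%dC\n", d.Device, d.Health, d.Temperature)
	}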

// smartctlJSON represents the JSON output from smartctl --json.
@@ -85,8 +93,7 @@ func CollectLocal(ctx context.Context) ([]DiskSMART, error) {
// listBlockDevices returns a list of block devices suitable for SMART queries.
func listBlockDevices(ctx context.Context) ([]string, error) {
	// Use lsblk to find disks (not partitions)
	cmd := exec.CommandContext(ctx, "lsblk", "-d", "-n", "-o", "NAME,TYPE")
	output, err := cmd.Output()
	output, err := runCommandOutput(ctx, "lsblk", "-d", "-n", "-o", "NAME,TYPE")
	if err != nil {
		return nil, err
	}
@@ -114,7 +121,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
	defer cancel()

	// Check if smartctl is available
	smartctlPath, err := exec.LookPath("smartctl")
	smartctlPath, err := execLookPath("smartctl")
	if err != nil {
		return nil, err
	}
@@ -124,8 +131,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
	// -i: device info
	// -A: attributes (for temperature)
	// --json=o: output original smartctl JSON format
	cmd := exec.CommandContext(cmdCtx, smartctlPath, "-n", "standby", "-i", "-A", "-H", "--json=o", device)
	output, err := cmd.Output()
	output, err := runCommandOutput(cmdCtx, smartctlPath, "-n", "standby", "-i", "-A", "-H", "--json=o", device)

	// smartctl returns non-zero exit codes for various conditions
	// Exit code 2 means drive is in standby - that's okay
@@ -137,7 +143,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
		return &DiskSMART{
			Device:      filepath.Base(device),
			Standby:     true,
			LastUpdated: time.Now(),
			LastUpdated: timeNow(),
		}, nil
	}
	// Other exit codes might still have valid JSON output
@@ -161,7 +167,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
		Model:       smartData.ModelName,
		Serial:      smartData.SerialNumber,
		Type:        detectDiskType(smartData),
		LastUpdated: time.Now(),
		LastUpdated: timeNow(),
	}

	// Build WWN string if available

303
internal/smartctl/collector_coverage_test.go
Normal file
@@ -0,0 +1,303 @@
package smartctl

import (
	"context"
	"encoding/json"
	"errors"
	"os/exec"
	"strings"
	"testing"
	"time"
)

func TestListBlockDevices(t *testing.T) {
	origRun := runCommandOutput
	t.Cleanup(func() { runCommandOutput = origRun })

	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		if name != "lsblk" {
			return nil, errors.New("unexpected command")
		}
		out := "sda disk\nsda1 part\nsr0 rom\nnvme0n1 disk\n\n"
		return []byte(out), nil
	}

	devices, err := listBlockDevices(context.Background())
	if err != nil {
		t.Fatalf("listBlockDevices error: %v", err)
	}
	if len(devices) != 2 || devices[0] != "/dev/sda" || devices[1] != "/dev/nvme0n1" {
		t.Fatalf("unexpected devices: %#v", devices)
	}
}

func TestListBlockDevicesError(t *testing.T) {
	origRun := runCommandOutput
	t.Cleanup(func() { runCommandOutput = origRun })

	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return nil, errors.New("boom")
	}

	if _, err := listBlockDevices(context.Background()); err == nil {
		t.Fatalf("expected error")
	}
}

func TestDefaultRunCommandOutput(t *testing.T) {
	origRun := runCommandOutput
	t.Cleanup(func() { runCommandOutput = origRun })
	runCommandOutput = origRun

	if _, err := runCommandOutput(context.Background(), "sh", "-c", "printf ''"); err != nil {
		t.Fatalf("expected default runCommandOutput to work, got %v", err)
	}
}

func TestCollectLocalNoDevices(t *testing.T) {
	origRun := runCommandOutput
	t.Cleanup(func() { runCommandOutput = origRun })

	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		if name == "lsblk" {
			return []byte(""), nil
		}
		return nil, errors.New("unexpected command")
	}

	result, err := CollectLocal(context.Background())
	if err != nil {
		t.Fatalf("CollectLocal error: %v", err)
	}
	if result != nil {
		t.Fatalf("expected nil result for no devices")
	}
}

func TestCollectLocalListDevicesError(t *testing.T) {
	origRun := runCommandOutput
	t.Cleanup(func() { runCommandOutput = origRun })

	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return nil, errors.New("lsblk failed")
	}

	if _, err := CollectLocal(context.Background()); err == nil {
		t.Fatalf("expected list error")
	}
}

func TestCollectLocalSkipsErrors(t *testing.T) {
	origRun := runCommandOutput
	origLook := execLookPath
	origNow := timeNow
	t.Cleanup(func() {
		runCommandOutput = origRun
		execLookPath = origLook
		timeNow = origNow
	})

	fixed := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
	timeNow = func() time.Time { return fixed }
	execLookPath = func(string) (string, error) { return "smartctl", nil }

	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		if name == "lsblk" {
			return []byte("sda disk\nsdb disk\n"), nil
		}
		if name == "smartctl" {
			device := args[len(args)-1]
			if strings.Contains(device, "sda") {
				return nil, errors.New("read error")
			}
			payload := smartctlJSON{
				ModelName:    "Model",
				SerialNumber: "Serial",
			}
			payload.Device.Protocol = "ATA"
			payload.SmartStatus.Passed = true
			payload.Temperature.Current = 30
			out, _ := json.Marshal(payload)
			return out, nil
		}
		return nil, errors.New("unexpected command")
	}

	result, err := CollectLocal(context.Background())
	if err != nil {
		t.Fatalf("CollectLocal error: %v", err)
	}
	if len(result) != 1 {
		t.Fatalf("expected 1 result, got %d", len(result))
	}
	if result[0].Device != "sdb" || !result[0].LastUpdated.Equal(fixed) {
		t.Fatalf("unexpected result: %#v", result[0])
	}
}

func TestCollectDeviceSMARTLookPathError(t *testing.T) {
	origLook := execLookPath
	t.Cleanup(func() { execLookPath = origLook })
	execLookPath = func(string) (string, error) { return "", errors.New("missing") }

	if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
		t.Fatalf("expected lookpath error")
	}
}

func TestCollectDeviceSMARTStandby(t *testing.T) {
	if _, err := exec.LookPath("sh"); err != nil {
		t.Skip("sh not available")
	}

	origRun := runCommandOutput
	origLook := execLookPath
	origNow := timeNow
	t.Cleanup(func() {
		runCommandOutput = origRun
		execLookPath = origLook
		timeNow = origNow
	})

	fixed := time.Date(2024, 2, 3, 4, 5, 6, 0, time.UTC)
	timeNow = func() time.Time { return fixed }
	execLookPath = func(string) (string, error) { return "smartctl", nil }
	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return exec.CommandContext(ctx, "sh", "-c", "exit 2").Output()
	}

	result, err := collectDeviceSMART(context.Background(), "/dev/sda")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result == nil || !result.Standby || result.Device != "sda" || !result.LastUpdated.Equal(fixed) {
		t.Fatalf("unexpected standby result: %#v", result)
	}
}

func TestCollectDeviceSMARTExitErrorNoOutput(t *testing.T) {
	if _, err := exec.LookPath("sh"); err != nil {
		t.Skip("sh not available")
	}

	origRun := runCommandOutput
	origLook := execLookPath
	t.Cleanup(func() {
		runCommandOutput = origRun
		execLookPath = origLook
	})

	execLookPath = func(string) (string, error) { return "smartctl", nil }
	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return exec.CommandContext(ctx, "sh", "-c", "exit 1").Output()
	}

	if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
		t.Fatalf("expected error for exit code without output")
	}
}

func TestCollectDeviceSMARTExitErrorWithOutput(t *testing.T) {
	if _, err := exec.LookPath("sh"); err != nil {
		t.Skip("sh not available")
	}

	origRun := runCommandOutput
	origLook := execLookPath
	t.Cleanup(func() {
		runCommandOutput = origRun
		execLookPath = origLook
	})

	execLookPath = func(string) (string, error) { return "smartctl", nil }
	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		payload := `{"model_name":"Model","serial_number":"Serial","device":{"protocol":"ATA"},"smart_status":{"passed":false},"temperature":{"current":45}}`
		return exec.CommandContext(ctx, "sh", "-c", "echo '"+payload+"'; exit 1").Output()
	}

	result, err := collectDeviceSMART(context.Background(), "/dev/sda")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result == nil || result.Health != "FAILED" || result.Temperature != 45 {
		t.Fatalf("unexpected result: %#v", result)
	}
}

func TestCollectDeviceSMARTJSONError(t *testing.T) {
	origRun := runCommandOutput
	origLook := execLookPath
	t.Cleanup(func() {
		runCommandOutput = origRun
		execLookPath = origLook
	})

	execLookPath = func(string) (string, error) { return "smartctl", nil }
	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		return []byte("{"), nil
	}

	if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
		t.Fatalf("expected json error")
	}
}

func TestCollectDeviceSMARTNVMeTempFallback(t *testing.T) {
	origRun := runCommandOutput
	origLook := execLookPath
	origNow := timeNow
	t.Cleanup(func() {
		runCommandOutput = origRun
		execLookPath = origLook
		timeNow = origNow
	})

	fixed := time.Date(2024, 4, 5, 6, 7, 8, 0, time.UTC)
	timeNow = func() time.Time { return fixed }
	execLookPath = func(string) (string, error) { return "smartctl", nil }
	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		payload := smartctlJSON{}
		payload.Device.Protocol = "NVMe"
		payload.NVMeSmartHealthInformationLog.Temperature = 55
		payload.SmartStatus.Passed = true
		out, _ := json.Marshal(payload)
		return out, nil
	}

	result, err := collectDeviceSMART(context.Background(), "/dev/nvme0n1")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.Temperature != 55 || result.Type != "nvme" || result.Health != "PASSED" || !result.LastUpdated.Equal(fixed) {
		t.Fatalf("unexpected result: %#v", result)
	}
}

func TestCollectDeviceSMARTWWN(t *testing.T) {
	origRun := runCommandOutput
	origLook := execLookPath
	t.Cleanup(func() {
		runCommandOutput = origRun
		execLookPath = origLook
	})

	execLookPath = func(string) (string, error) { return "smartctl", nil }
	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
		payload := smartctlJSON{}
		payload.WWN.NAA = 5
		payload.WWN.OUI = 0xabc
		payload.WWN.ID = 0x1234
		payload.Device.Protocol = "SAS"
		payload.SmartStatus.Passed = true
		out, _ := json.Marshal(payload)
		return out, nil
	}

	result, err := collectDeviceSMART(context.Background(), "/dev/sda")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// The WWN fields are rendered as lowercase hex joined as NAA-OUI-ID.
	if result.WWN != "5-abc-1234" {
		t.Fatalf("unexpected WWN: %q", result.WWN)
	}
}
@@ -6,6 +6,7 @@ import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
@@ -44,12 +45,33 @@ const (
)

var (
	mkdirAllFn       = os.MkdirAll
	statFn           = os.Stat
	openFileFn       = os.OpenFile
	openFn           = os.Open
	appendOpenFileFn = func(path string) (io.WriteCloser, error) {
		return openFileFn(path, os.O_APPEND|os.O_WRONLY, 0o600)
	}
	keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
		cmd := exec.CommandContext(ctx, "ssh-keyscan", args...)
		return cmd.CombinedOutput()
	}

	// ErrNoHostKeys is returned when ssh-keyscan yields no usable entries.
	ErrNoHostKeys = errors.New("knownhosts: no host keys discovered")
	// ErrHostKeyChanged signals that a host key already exists with a different fingerprint.
	ErrHostKeyChanged = errors.New("knownhosts: host key changed")
)

var (
	defaultMkdirAllFn       = mkdirAllFn
	defaultStatFn           = statFn
	defaultOpenFileFn       = openFileFn
	defaultOpenFn           = openFn
	defaultAppendOpenFileFn = appendOpenFileFn
	defaultKeyscanCmdRunner = keyscanCmdRunner
)
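Capturing the defaults lets the test file reset every seam in one call (resetKnownHostsFns, below). Having appendOpenFileFn return io.WriteCloser rather than *os.File also permits purely in-memory fakes; the errWriteCloser used by the write-error test is not shown in this diff, but its shape is presumably something like:

	// errWriteCloser fails every write -- hypothetical shape, not from the diff.
	type errWriteCloser struct{ err error }

	func (w errWriteCloser) Write(p []byte) (int, error) { return 0, w.err }
	func (w errWriteCloser) Close() error                { return nil }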

// HostKeyChangeError describes a detected host key mismatch.
type HostKeyChangeError struct {
	Host string
@@ -214,17 +236,17 @@ func (m *manager) Path() string {

func (m *manager) ensureKnownHostsFile() error {
	dir := filepath.Dir(m.path)
	if err := os.MkdirAll(dir, 0o700); err != nil {
	if err := mkdirAllFn(dir, 0o700); err != nil {
		return fmt.Errorf("knownhosts: mkdir %s: %w", dir, err)
	}

	if _, err := os.Stat(m.path); err == nil {
	if _, err := statFn(m.path); err == nil {
		return nil
	} else if !os.IsNotExist(err) {
		return err
	}

	f, err := os.OpenFile(m.path, os.O_CREATE|os.O_WRONLY, 0o600)
	f, err := openFileFn(m.path, os.O_CREATE|os.O_WRONLY, 0o600)
	if err != nil {
		return fmt.Errorf("knownhosts: create %s: %w", m.path, err)
	}
@@ -232,7 +254,7 @@ func (m *manager) ensureKnownHostsFile() error {
}

func appendHostKey(path string, entries [][]byte) error {
	f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0o600)
	f, err := appendOpenFileFn(path)
	if err != nil {
		return fmt.Errorf("knownhosts: open %s: %w", path, err)
	}
@@ -287,7 +309,7 @@ func normalizeHostEntry(host string, entry []byte) ([]byte, string, error) {
}

func findHostKeyLine(path, host, keyType string) (string, error) {
	f, err := os.Open(path)
	f, err := openFn(path)
	if err != nil {
		if os.IsNotExist(err) {
			return "", nil
@@ -328,10 +350,6 @@ func hostLineMatches(host, line string) bool {
	}

	fields := strings.Fields(trimmed)
	if len(fields) == 0 {
		return false
	}

	return hostFieldMatches(host, fields[0])
}

@@ -396,8 +414,7 @@ func defaultKeyscan(ctx context.Context, host string, port int, timeout time.Dur
	}
	args = append(args, host)

	cmd := exec.CommandContext(scanCtx, "ssh-keyscan", args...)
	output, err := cmd.CombinedOutput()
	output, err := keyscanCmdRunner(scanCtx, args...)
	if err != nil {
		return nil, fmt.Errorf("%w (output: %s)", err, strings.TrimSpace(string(output)))
	}

@@ -1,15 +1,27 @@
package knownhosts

import (
	"bytes"
	"context"
	"errors"
	"io"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"testing"
	"time"
)

func resetKnownHostsFns() {
	mkdirAllFn = defaultMkdirAllFn
	statFn = defaultStatFn
	openFileFn = defaultOpenFileFn
	openFn = defaultOpenFn
	appendOpenFileFn = defaultAppendOpenFileFn
	keyscanCmdRunner = defaultKeyscanCmdRunner
}

func TestEnsureCreatesFileAndCaches(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")
@@ -110,6 +122,167 @@ func TestEnsureRespectsContextCancellation(t *testing.T) {
	}
}

func TestNewManagerEmptyPath(t *testing.T) {
	if _, err := NewManager(""); err == nil {
		t.Fatal("expected error for empty path")
	}
}

func TestEnsureWithPortMissingHost(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")

	mgr, err := NewManager(path)
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithPort(context.Background(), "", 22); err == nil {
		t.Fatal("expected error for missing host")
	}
}

func TestEnsureWithPortDefaultsPort(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")

	var gotPort int
	keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
		gotPort = port
		return []byte(host + " ssh-ed25519 AAAA"), nil
	}

	mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithPort(context.Background(), "example.com", 0); err != nil {
		t.Fatalf("EnsureWithPort: %v", err)
	}
	if gotPort != 22 {
		t.Fatalf("expected port 22, got %d", gotPort)
	}
}

func TestEnsureWithPortCustomPort(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")

	keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
		return []byte("[example.com]:2222 ssh-ed25519 AAAA"), nil
	}

	mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithPort(context.Background(), "example.com", 2222); err != nil {
		t.Fatalf("EnsureWithPort: %v", err)
	}

	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	if got := strings.TrimSpace(string(data)); got != "[example.com]:2222 ssh-ed25519 AAAA" {
		t.Fatalf("unexpected known_hosts contents: %s", got)
	}
}

func TestEnsureWithPortKeyscanError(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")

	keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
		return nil, errors.New("scan failed")
	}

	mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithPort(context.Background(), "example.com", 22); err == nil {
		t.Fatal("expected keyscan error")
	}
}

func TestEnsureWithEntriesMissingHost(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")

	mgr, err := NewManager(path)
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithEntries(context.Background(), "", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
		t.Fatal("expected missing host error")
	}
}

func TestEnsureWithEntriesNoEntries(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")

	mgr, err := NewManager(path)
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, nil); err == nil {
		t.Fatal("expected no entries error")
	}
}

func TestEnsureWithEntriesNormalizeError(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")

	mgr, err := NewManager(path)
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("invalid")}); err == nil {
		t.Fatal("expected normalize error")
	}
}

func TestEnsureWithEntriesEnsureKnownHostsFileError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	mkdirAllFn = func(string, os.FileMode) error {
		return errors.New("mkdir failed")
	}

	mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
		t.Fatal("expected ensureKnownHostsFile error")
	}
}

func TestEnsureWithEntriesFindHostKeyLineError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	openFn = func(string) (*os.File, error) {
		return nil, errors.New("open failed")
	}

	mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
		t.Fatal("expected open error")
	}
}

func TestHostCandidates(t *testing.T) {
	tests := []struct {
		input string
@@ -293,3 +466,312 @@ func TestHostFieldMatches(t *testing.T) {
		})
	}
}

func TestEnsureKnownHostsFileMkdirError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	mkdirAllFn = func(string, os.FileMode) error {
		return errors.New("mkdir failed")
	}

	m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
	if err := m.ensureKnownHostsFile(); err == nil {
		t.Fatal("expected mkdir error")
	}
}

func TestEnsureKnownHostsFileStatError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	statFn = func(string) (os.FileInfo, error) {
		return nil, errors.New("stat failed")
	}

	m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
	if err := m.ensureKnownHostsFile(); err == nil {
		t.Fatal("expected stat error")
	}
}

func TestEnsureKnownHostsFileCreateError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	statFn = func(string) (os.FileInfo, error) {
		return nil, os.ErrNotExist
	}
	openFileFn = func(string, int, os.FileMode) (*os.File, error) {
		return nil, errors.New("open failed")
	}

	m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
	if err := m.ensureKnownHostsFile(); err == nil {
		t.Fatal("expected create error")
	}
}

func TestAppendHostKeyOpenError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	appendOpenFileFn = func(string) (io.WriteCloser, error) {
		return nil, errors.New("open failed")
	}

	if err := appendHostKey("ignored", [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
		t.Fatal("expected open error")
	}
}

func TestAppendHostKeyWriteError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	appendOpenFileFn = func(string) (io.WriteCloser, error) {
		return errWriteCloser{err: errors.New("write failed")}, nil
	}

	if err := appendHostKey("ignored", [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
		t.Fatal("expected write error")
	}
}

func TestNormalizeHostEntryWithComment(t *testing.T) {
	entry := []byte("example.com ssh-ed25519 AAAA comment here")
	normalized, keyType, err := normalizeHostEntry("example.com", entry)
	if err != nil {
		t.Fatalf("normalizeHostEntry error: %v", err)
	}
	if keyType != "ssh-ed25519" {
		t.Fatalf("expected key type ssh-ed25519, got %s", keyType)
	}
	if string(normalized) != "example.com ssh-ed25519 AAAA comment here" {
		t.Fatalf("unexpected normalized entry: %s", string(normalized))
	}
}

func TestFindHostKeyLineNotExists(t *testing.T) {
	line, err := findHostKeyLine(filepath.Join(t.TempDir(), "missing"), "example.com", "")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if line != "" {
		t.Fatalf("expected empty line, got %q", line)
	}
}

func TestFindHostKeyLineOpenError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	openFn = func(string) (*os.File, error) {
		return nil, errors.New("open failed")
	}

	if _, err := findHostKeyLine("ignored", "example.com", ""); err == nil {
		t.Fatal("expected open error")
	}
}

func TestFindHostKeyLineScannerError(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")
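	// A single 70000-byte line overflows bufio.Scanner's default 64 KiB token
	// limit, which is presumably what findHostKeyLine scans with, forcing an error.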
	longLine := strings.Repeat("a", 70000)
	if err := os.WriteFile(path, []byte(longLine+"\n"), 0600); err != nil {
		t.Fatalf("failed to write file: %v", err)
	}

	if _, err := findHostKeyLine(path, "example.com", ""); err == nil {
		t.Fatal("expected scanner error")
	}
}

func TestHostLineMatchesSkips(t *testing.T) {
	if hostLineMatches("example.com", "") {
		t.Fatal("expected empty line to be false")
	}
	if hostLineMatches("example.com", "# comment") {
		t.Fatal("expected comment line to be false")
	}
	if hostLineMatches("example.com", "|1|hash|salt ssh-ed25519 AAAA") {
		t.Fatal("expected hashed entry to be false")
	}
}

func TestDefaultKeyscanSuccess(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
		return []byte("example.com ssh-ed25519 AAAA"), nil
	}

	out, err := defaultKeyscan(context.Background(), "example.com", 22, time.Second)
	if err != nil {
		t.Fatalf("defaultKeyscan error: %v", err)
	}
	if string(out) != "example.com ssh-ed25519 AAAA" {
		t.Fatalf("unexpected output: %s", string(out))
	}
}

func TestDefaultKeyscanError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
		return []byte("boom"), errors.New("scan failed")
	}

	if _, err := defaultKeyscan(context.Background(), "example.com", 22, time.Second); err == nil {
		t.Fatal("expected error")
	}
}

func TestEnsureWithEntriesDefaultsPort(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")

	mgr, err := NewManager(path)
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithEntries(context.Background(), "example.com", 0, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err != nil {
		t.Fatalf("EnsureWithEntries: %v", err)
	}

	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	if got := strings.TrimSpace(string(data)); got != "example.com ssh-ed25519 AAAA" {
		t.Fatalf("unexpected known_hosts contents: %s", got)
	}
}

func TestEnsureWithEntriesAppendError(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	appendOpenFileFn = func(string) (io.WriteCloser, error) {
		return nil, errors.New("open failed")
	}

	mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
	if err != nil {
		t.Fatalf("NewManager: %v", err)
	}

	if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
		t.Fatal("expected append error")
	}
}

func TestAppendHostKeySkipsEmptyEntry(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)
	buf := &bufferWriteCloser{}
	appendOpenFileFn = func(string) (io.WriteCloser, error) {
		return buf, nil
	}

	if err := appendHostKey("ignored", [][]byte{nil, []byte("example.com ssh-ed25519 AAAA")}); err != nil {
		t.Fatalf("appendHostKey error: %v", err)
	}
	if !strings.Contains(buf.String(), "example.com ssh-ed25519 AAAA") {
		t.Fatalf("expected entry to be written, got %q", buf.String())
	}
}

func TestFindHostKeyLineSkipsInvalidLines(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "known_hosts")
	contents := strings.Join([]string{
		"other.com ssh-ed25519 AAAA",
		"example.com ssh-ed25519",
		"example.com ssh-ed25519 AAAA",
	}, "\n") + "\n"
	if err := os.WriteFile(path, []byte(contents), 0600); err != nil {
		t.Fatalf("failed to write file: %v", err)
	}

	line, err := findHostKeyLine(path, "example.com", "ssh-ed25519")
	if err != nil {
		t.Fatalf("findHostKeyLine error: %v", err)
	}
	if line != "example.com ssh-ed25519 AAAA" {
		t.Fatalf("unexpected line: %q", line)
	}
}

func TestDefaultKeyscanArgs(t *testing.T) {
	t.Cleanup(resetKnownHostsFns)

	var gotArgs []string
	keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
		gotArgs = append([]string{}, args...)
		return []byte("ok"), nil
	}

	if _, err := defaultKeyscan(context.Background(), "example.com", 0, 0); err != nil {
		t.Fatalf("defaultKeyscan error: %v", err)
	}
	for _, arg := range gotArgs {
		if arg == "-p" {
			t.Fatal("did not expect -p for default port")
		}
	}
	if len(gotArgs) < 3 || gotArgs[len(gotArgs)-1] != "example.com" {
		t.Fatalf("unexpected args: %v", gotArgs)
	}

	keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
		gotArgs = append([]string{}, args...)
		return []byte("ok"), nil
	}
	if _, err := defaultKeyscan(context.Background(), "example.com", 2222, time.Second); err != nil {
		t.Fatalf("defaultKeyscan error: %v", err)
	}
	hasPort := false
	for i := 0; i < len(gotArgs)-1; i++ {
		if gotArgs[i] == "-p" && gotArgs[i+1] == "2222" {
			hasPort = true
			break
		}
	}
	if !hasPort {
		t.Fatalf("expected -p 2222 in args, got %v", gotArgs)
	}
}

func TestKeyscanCmdRunnerDefault(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("ssh-keyscan helper script requires sh")
	}
	t.Cleanup(resetKnownHostsFns)

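	// PATH shim: a fake ssh-keyscan script placed first on PATH lets the real
	// command runner be exercised without touching the network.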
	dir := t.TempDir()
	scriptPath := filepath.Join(dir, "ssh-keyscan")
	script := []byte("#!/bin/sh\necho example.com ssh-ed25519 AAAA\n")
	if err := os.WriteFile(scriptPath, script, 0700); err != nil {
		t.Fatalf("failed to write script: %v", err)
	}

	oldPath := os.Getenv("PATH")
	if err := os.Setenv("PATH", dir+string(os.PathListSeparator)+oldPath); err != nil {
		t.Fatalf("failed to set PATH: %v", err)
	}
	t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) })

	output, err := keyscanCmdRunner(context.Background(), "example.com")
	if err != nil {
		t.Fatalf("keyscanCmdRunner error: %v", err)
	}
	if strings.TrimSpace(string(output)) != "example.com ssh-ed25519 AAAA" {
		t.Fatalf("unexpected output: %s", string(output))
	}
}

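// Test doubles: errWriteCloser fails every Write with the configured error;
// bufferWriteCloser captures writes in memory and makes Close a no-op.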
type errWriteCloser struct {
	err error
}

func (e errWriteCloser) Write(p []byte) (int, error) {
	return 0, e.err
}

func (e errWriteCloser) Close() error {
	return nil
}

type bufferWriteCloser struct {
	bytes.Buffer
}

func (b *bufferWriteCloser) Close() error {
	return nil
}

@@ -5,6 +5,13 @@ import (
	"strings"
)

var (
	envGetFn   = os.Getenv
	statFn     = os.Stat
	readFileFn = os.ReadFile
	hostnameFn = os.Hostname
)

var containerMarkers = []string{
	"docker",
	"lxc",
@@ -19,24 +26,24 @@ var containerMarkers = []string{
// InContainer reports whether Pulse is running inside a containerised environment.
func InContainer() bool {
	// Allow operators to force container behaviour when automatic detection falls short.
	if isTruthy(os.Getenv("PULSE_FORCE_CONTAINER")) {
	if isTruthy(envGetFn("PULSE_FORCE_CONTAINER")) {
		return true
	}

	if _, err := os.Stat("/.dockerenv"); err == nil {
	if _, err := statFn("/.dockerenv"); err == nil {
		return true
	}
	if _, err := os.Stat("/run/.containerenv"); err == nil {
	if _, err := statFn("/run/.containerenv"); err == nil {
		return true
	}

	// Check common environment hints provided by systemd/nspawn, LXC, etc.
	if val := strings.ToLower(strings.TrimSpace(os.Getenv("container"))); val != "" && val != "host" {
	if val := strings.ToLower(strings.TrimSpace(envGetFn("container"))); val != "" && val != "host" {
		return true
	}

	// Some distros expose the container hint through PID 1's environment.
	if data, err := os.ReadFile("/proc/1/environ"); err == nil {
	if data, err := readFileFn("/proc/1/environ"); err == nil {
		lower := strings.ToLower(string(data))
		if strings.Contains(lower, "container=") && !strings.Contains(lower, "container=host") {
			return true
@@ -44,7 +51,7 @@ func InContainer() bool {
	}

	// Fall back to cgroup inspection which covers older Docker/LXC setups.
	if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
	if data, err := readFileFn("/proc/1/cgroup"); err == nil {
		content := strings.ToLower(string(data))
		for _, marker := range containerMarkers {
			if strings.Contains(content, marker) {
@@ -60,7 +67,7 @@ func InContainer() bool {
// Returns empty string if not in Docker or name cannot be determined.
func DetectDockerContainerName() string {
	// Method 1: Check hostname (Docker uses container ID or name as hostname)
	if hostname, err := os.Hostname(); err == nil && hostname != "" {
	if hostname, err := hostnameFn(); err == nil && hostname != "" {
		// Docker hostnames are either short container ID (12 chars) or custom name
		// If it looks like a container ID (hex), skip it - user needs to use name
		if !isHexString(hostname) || len(hostname) > 12 {
@@ -69,7 +76,7 @@ func DetectDockerContainerName() string {
	}

	// Method 2: Try reading from /proc/self/cgroup
	if data, err := os.ReadFile("/proc/self/cgroup"); err == nil {
	if data, err := readFileFn("/proc/self/cgroup"); err == nil {
		// Look for patterns like: 0::/docker/<container-id>
		// But we can't get name from cgroup, only ID
		_ = data // placeholder for future enhancement
@@ -91,7 +98,7 @@ func isHexString(s string) bool {
// Returns empty string if not in an LXC container or CTID cannot be determined.
func DetectLXCCTID() string {
	// Method 1: Parse /proc/1/cgroup for LXC container ID
	if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
	if data, err := readFileFn("/proc/1/cgroup"); err == nil {
		lines := strings.Split(string(data), "\n")
		for _, line := range lines {
			// Look for patterns like: 0::/lxc/123 or 0::/lxc.payload.123
@@ -119,7 +126,7 @@ func DetectLXCCTID() string {
	}

	// Method 2: Check hostname (some LXC containers use CTID as hostname)
	if hostname, err := os.Hostname(); err == nil && isNumeric(hostname) {
	if hostname, err := hostnameFn(); err == nil && isNumeric(hostname) {
		return hostname
	}

@@ -1,9 +1,18 @@
package system

import (
	"errors"
	"os"
	"testing"
)

func resetSystemFns() {
	envGetFn = os.Getenv
	statFn = os.Stat
	readFileFn = os.ReadFile
	hostnameFn = os.Hostname
}

func TestIsHexString(t *testing.T) {
	tests := []struct {
		input string
@@ -133,20 +142,226 @@ func TestContainerMarkers(t *testing.T) {
}

func TestInContainer(t *testing.T) {
	// This is an integration test that depends on the runtime environment
	// We can't mock the file system easily, but we can verify it doesn't panic
	result := InContainer()
	t.Logf("InContainer() = %v (depends on test environment)", result)
	t.Run("Forced", func(t *testing.T) {
		t.Cleanup(resetSystemFns)
		envGetFn = func(key string) string {
			if key == "PULSE_FORCE_CONTAINER" {
				return "true"
			}
			return ""
		}

		if !InContainer() {
			t.Fatal("expected forced container")
		}
	})

	t.Run("DockerEnvFile", func(t *testing.T) {
		t.Cleanup(resetSystemFns)
		envGetFn = func(string) string { return "" }
		statFn = func(path string) (os.FileInfo, error) {
			if path == "/.dockerenv" {
				return nil, nil
			}
			return nil, errors.New("missing")
		}
		readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") }

		if !InContainer() {
			t.Fatal("expected docker env file")
		}
	})

	t.Run("ContainerEnvFile", func(t *testing.T) {
		t.Cleanup(resetSystemFns)
		envGetFn = func(string) string { return "" }
		statFn = func(path string) (os.FileInfo, error) {
			if path == "/run/.containerenv" {
				return nil, nil
			}
			return nil, errors.New("missing")
		}
		readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") }

		if !InContainer() {
			t.Fatal("expected container env file")
		}
	})

	t.Run("ContainerEnvVar", func(t *testing.T) {
		t.Cleanup(resetSystemFns)
		envGetFn = func(key string) string {
			if key == "container" {
				return "lxc"
			}
			return ""
		}
		statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") }
		readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") }

		if !InContainer() {
			t.Fatal("expected container env var")
		}
	})

	t.Run("ProcEnviron", func(t *testing.T) {
		t.Cleanup(resetSystemFns)
		envGetFn = func(string) string { return "" }
		statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") }
		readFileFn = func(path string) ([]byte, error) {
			if path == "/proc/1/environ" {
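				// /proc/1/environ is NUL-separated, hence the trailing \x00.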
return []byte("container=lxc\x00"), nil
|
||||
}
|
||||
return nil, errors.New("missing")
|
||||
}
|
||||
|
||||
if !InContainer() {
|
||||
t.Fatal("expected container environ")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CgroupMarker", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
envGetFn = func(string) string { return "" }
|
||||
statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") }
|
||||
readFileFn = func(path string) ([]byte, error) {
|
||||
if path == "/proc/1/environ" {
|
||||
return []byte("container=host"), nil
|
||||
}
|
||||
if path == "/proc/1/cgroup" {
|
||||
return []byte("0::/docker/abc"), nil
|
||||
}
|
||||
return nil, errors.New("missing")
|
||||
}
|
||||
|
||||
if !InContainer() {
|
||||
t.Fatal("expected container cgroup marker")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NotContainer", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
envGetFn = func(string) string { return "" }
|
||||
statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") }
|
||||
readFileFn = func(path string) ([]byte, error) {
|
||||
if path == "/proc/1/environ" {
|
||||
return []byte("container=host"), nil
|
||||
}
|
||||
if path == "/proc/1/cgroup" {
|
||||
return []byte("0::/user.slice"), nil
|
||||
}
|
||||
return nil, errors.New("missing")
|
||||
}
|
||||
|
||||
if InContainer() {
|
||||
t.Fatal("expected non-container")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetectDockerContainerName(t *testing.T) {
|
||||
// This is an integration test that depends on the runtime environment
|
||||
result := DetectDockerContainerName()
|
||||
t.Logf("DetectDockerContainerName() = %q (depends on test environment)", result)
|
||||
t.Run("HostnameName", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
hostnameFn = func() (string, error) { return "my-container", nil }
|
||||
|
||||
if got := DetectDockerContainerName(); got != "my-container" {
|
||||
t.Fatalf("expected hostname name, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HostnameHexShort", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
hostnameFn = func() (string, error) { return "abcdef123456", nil }
|
||||
readFileFn = func(string) ([]byte, error) { return []byte("0::/docker/abcdef"), nil }
|
||||
|
||||
if got := DetectDockerContainerName(); got != "" {
|
||||
t.Fatalf("expected empty name, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HostnameHexLong", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
hostnameFn = func() (string, error) { return "abcdef1234567890abcdef1234567890", nil }
|
||||
|
||||
if got := DetectDockerContainerName(); got == "" {
|
||||
t.Fatal("expected hostname for long hex")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HostnameError", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
hostnameFn = func() (string, error) { return "", errors.New("fail") }
|
||||
readFileFn = func(string) ([]byte, error) { return []byte("0::/docker/abcdef"), nil }
|
||||
|
||||
if got := DetectDockerContainerName(); got != "" {
|
||||
t.Fatalf("expected empty name, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetectLXCCTID(t *testing.T) {
|
||||
// This is an integration test that depends on the runtime environment
|
||||
result := DetectLXCCTID()
|
||||
t.Logf("DetectLXCCTID() = %q (depends on test environment)", result)
|
||||
t.Run("FromCgroupLXC", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
readFileFn = func(path string) ([]byte, error) {
|
||||
if path == "/proc/1/cgroup" {
|
||||
return []byte("0::/lxc/123"), nil
|
||||
}
|
||||
return nil, errors.New("missing")
|
||||
}
|
||||
hostnameFn = func() (string, error) { return "999", nil }
|
||||
|
||||
if got := DetectLXCCTID(); got != "123" {
|
||||
t.Fatalf("expected CTID 123, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FromCgroupPayload", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
readFileFn = func(path string) ([]byte, error) {
|
||||
if path == "/proc/1/cgroup" {
|
||||
return []byte("0::/lxc.payload.456"), nil
|
||||
}
|
||||
return nil, errors.New("missing")
|
||||
}
|
||||
hostnameFn = func() (string, error) { return "999", nil }
|
||||
|
||||
if got := DetectLXCCTID(); got != "456" {
|
||||
t.Fatalf("expected CTID 456, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FromMachineLXC", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
readFileFn = func(path string) ([]byte, error) {
|
||||
if path == "/proc/1/cgroup" {
|
||||
return []byte("0::/machine.slice/machine-lxc-789"), nil
|
||||
}
|
||||
return nil, errors.New("missing")
|
||||
}
|
||||
hostnameFn = func() (string, error) { return "999", nil }
|
||||
|
||||
if got := DetectLXCCTID(); got != "789" {
|
||||
t.Fatalf("expected CTID 789, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FromHostname", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
readFileFn = func(string) ([]byte, error) { return []byte("0::/user.slice"), nil }
|
||||
hostnameFn = func() (string, error) { return "321", nil }
|
||||
|
||||
if got := DetectLXCCTID(); got != "321" {
|
||||
t.Fatalf("expected CTID 321, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoMatch", func(t *testing.T) {
|
||||
t.Cleanup(resetSystemFns)
|
||||
readFileFn = func(string) ([]byte, error) { return []byte("0::/user.slice"), nil }
|
||||
hostnameFn = func() (string, error) { return "host", nil }
|
||||
|
||||
if got := DetectLXCCTID(); got != "" {
|
||||
t.Fatalf("expected empty CTID, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -24,6 +24,13 @@ const (
	maxRetries = 3
)

var statFn = os.Stat

var dialContextFn = func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) {
	dialer := net.Dialer{Timeout: timeout}
	return dialer.DialContext(ctx, network, address)
}

// ErrorType classifies proxy errors for better error handling
type ErrorType int

@@ -66,9 +73,9 @@ type Client struct {
func NewClient() *Client {
	socketPath := os.Getenv("PULSE_SENSOR_PROXY_SOCKET")
	if socketPath == "" {
		if _, err := os.Stat(defaultSocketPath); err == nil {
		if _, err := statFn(defaultSocketPath); err == nil {
			socketPath = defaultSocketPath
		} else if _, err := os.Stat(containerSocketPath); err == nil {
		} else if _, err := statFn(containerSocketPath); err == nil {
			socketPath = containerSocketPath
		} else {
			socketPath = defaultSocketPath
@@ -83,7 +90,7 @@ func NewClient() *Client {

// IsAvailable checks if the proxy is running and accessible
func (c *Client) IsAvailable() bool {
	_, err := os.Stat(c.socketPath)
	_, err := statFn(c.socketPath)
	return err == nil
}

@@ -312,13 +319,8 @@ func (c *Client) callWithContext(ctx context.Context, method string, params map[

// callOnce sends a single RPC request without retries
func (c *Client) callOnce(ctx context.Context, method string, params map[string]interface{}) (*RPCResponse, error) {
	// Create a dialer with context
	dialer := net.Dialer{
		Timeout: c.timeout,
	}

	// Connect to unix socket with context
	conn, err := dialer.DialContext(ctx, "unix", c.socketPath)
	conn, err := dialContextFn(ctx, "unix", c.socketPath, c.timeout)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to proxy: %w", err)
	}
@@ -366,10 +368,6 @@ func (c *Client) GetStatus() (map[string]interface{}, error) {
		return nil, err
	}

	if !resp.Success {
		return nil, fmt.Errorf("proxy error: %s", resp.Error)
	}

	return resp.Data, nil
}

@@ -380,13 +378,6 @@ func (c *Client) RegisterNodes() ([]map[string]interface{}, error) {
		return nil, err
	}

	if !resp.Success {
		if proxyErr := classifyError(nil, resp.Error); proxyErr != nil {
			return nil, proxyErr
		}
		return nil, fmt.Errorf("proxy error: %s", resp.Error)
	}

	// Extract nodes array from data
	nodesRaw, ok := resp.Data["nodes"]
	if !ok {
@@ -422,18 +413,6 @@ func (c *Client) GetTemperature(nodeHost string) (string, error) {
		return "", err
	}

	if !resp.Success {
		if proxyErr := classifyError(nil, resp.Error); proxyErr != nil {
			return "", proxyErr
		}
		return "", &ProxyError{
			Type:      ErrorTypeUnknown,
			Message:   resp.Error,
			Retryable: false,
			Wrapped:   fmt.Errorf("%s", resp.Error),
		}
	}

	// Extract temperature JSON string
	tempRaw, ok := resp.Data["temperature"]
	if !ok {
@@ -455,17 +434,9 @@ func (c *Client) RequestCleanup(host string) error {
		params["host"] = host
	}

	resp, err := c.call("request_cleanup", params)
	if err != nil {
	if _, err := c.call("request_cleanup", params); err != nil {
		return err
	}

	if !resp.Success {
		if resp.Error != "" {
			return fmt.Errorf("proxy error: %s", resp.Error)
		}
		return fmt.Errorf("proxy rejected cleanup request")
	}

	return nil
}

internal/tempproxy/client_coverage_test.go (new file, 488 lines)
@@ -0,0 +1,488 @@
package tempproxy

import (
	"context"
	"encoding/json"
	"errors"
	"net"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"
)

type writeErrConn struct {
	net.Conn
}

func (c writeErrConn) Write(p []byte) (int, error) {
	return 0, errors.New("write failed")
}

|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
socketPath := filepath.Join(dir, "proxy.sock")
|
||||
ln, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go handler(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
cleanup := func() {
|
||||
ln.Close()
|
||||
<-done
|
||||
}
|
||||
return socketPath, cleanup
|
||||
}
|
||||
|
||||
func startRPCServer(t *testing.T, handler func(RPCRequest) RPCResponse) (string, func()) {
|
||||
t.Helper()
|
||||
return startUnixServer(t, func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
decoder := json.NewDecoder(conn)
|
||||
var req RPCRequest
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
return
|
||||
}
|
||||
resp := handler(req)
|
||||
encoder := json.NewEncoder(conn)
|
||||
_ = encoder.Encode(resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewClientSocketSelection(t *testing.T) {
|
||||
origStat := statFn
|
||||
t.Cleanup(func() { statFn = origStat })
|
||||
|
||||
t.Run("EnvOverride", func(t *testing.T) {
|
||||
origEnv := os.Getenv("PULSE_SENSOR_PROXY_SOCKET")
|
||||
t.Cleanup(func() {
|
||||
if origEnv == "" {
|
||||
os.Unsetenv("PULSE_SENSOR_PROXY_SOCKET")
|
||||
} else {
|
||||
os.Setenv("PULSE_SENSOR_PROXY_SOCKET", origEnv)
|
||||
}
|
||||
})
|
||||
os.Setenv("PULSE_SENSOR_PROXY_SOCKET", "/tmp/custom.sock")
|
||||
statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist }
|
||||
|
||||
client := NewClient()
|
||||
if client.socketPath != "/tmp/custom.sock" {
|
||||
t.Fatalf("expected env socket, got %q", client.socketPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DefaultExists", func(t *testing.T) {
|
||||
statFn = func(path string) (os.FileInfo, error) {
|
||||
if path == defaultSocketPath {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
client := NewClient()
|
||||
if client.socketPath != defaultSocketPath {
|
||||
t.Fatalf("expected default socket, got %q", client.socketPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ContainerExists", func(t *testing.T) {
|
||||
statFn = func(path string) (os.FileInfo, error) {
|
||||
if path == containerSocketPath {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
client := NewClient()
|
||||
if client.socketPath != containerSocketPath {
|
||||
t.Fatalf("expected container socket, got %q", client.socketPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FallbackDefault", func(t *testing.T) {
|
||||
statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist }
|
||||
client := NewClient()
|
||||
if client.socketPath != defaultSocketPath {
|
||||
t.Fatalf("expected fallback socket, got %q", client.socketPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientIsAvailable(t *testing.T) {
|
||||
origStat := statFn
|
||||
t.Cleanup(func() { statFn = origStat })
|
||||
|
||||
client := &Client{socketPath: "/tmp/socket"}
|
||||
statFn = func(string) (os.FileInfo, error) { return nil, nil }
|
||||
if !client.IsAvailable() {
|
||||
t.Fatalf("expected available")
|
||||
}
|
||||
|
||||
statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist }
|
||||
if client.IsAvailable() {
|
||||
t.Fatalf("expected unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallOnceSuccessAndDeadline(t *testing.T) {
|
||||
socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
|
||||
if req.Method != "ping" {
|
||||
t.Fatalf("unexpected method: %s", req.Method)
|
||||
}
|
||||
return RPCResponse{Success: true}
|
||||
})
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
|
||||
|
||||
resp, err := client.callOnce(context.Background(), "ping", nil)
|
||||
if err != nil || resp == nil || !resp.Success {
|
||||
t.Fatalf("expected success: resp=%v err=%v", resp, err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
resp, err = client.callOnce(ctx, "ping", nil)
|
||||
if err != nil || resp == nil || !resp.Success {
|
||||
t.Fatalf("expected success with deadline: resp=%v err=%v", resp, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallOnceErrors(t *testing.T) {
|
||||
t.Run("ConnectError", func(t *testing.T) {
|
||||
client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
|
||||
if _, err := client.callOnce(context.Background(), "ping", nil); err == nil {
|
||||
t.Fatalf("expected connect error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EncodeError", func(t *testing.T) {
|
||||
socketPath, cleanup := startUnixServer(t, func(conn net.Conn) {
|
||||
conn.Close()
|
||||
})
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
|
||||
if _, err := client.callOnce(context.Background(), "ping", nil); err == nil {
|
||||
t.Fatalf("expected encode error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EncodeErrorDial", func(t *testing.T) {
|
||||
origDial := dialContextFn
|
||||
t.Cleanup(func() { dialContextFn = origDial })
|
||||
|
||||
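		// net.Pipe provides an in-memory connection; wrapping its client end
		// in writeErrConn makes the JSON encode step fail deterministically.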
		clientConn, serverConn := net.Pipe()
		t.Cleanup(func() { _ = serverConn.Close() })

		dialContextFn = func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) {
			return writeErrConn{Conn: clientConn}, nil
		}

		client := &Client{socketPath: "/tmp/unused.sock", timeout: 50 * time.Millisecond}
		if _, err := client.callOnce(context.Background(), "ping", nil); err == nil {
			t.Fatalf("expected encode error")
		}
	})

	t.Run("DecodeError", func(t *testing.T) {
		socketPath, cleanup := startUnixServer(t, func(conn net.Conn) {
			defer conn.Close()
			_, _ = conn.Write([]byte("{"))
		})
		t.Cleanup(cleanup)

		client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
		if _, err := client.callOnce(context.Background(), "ping", nil); err == nil {
			t.Fatalf("expected decode error")
		}
	})
}

func TestCallWithContextBehavior(t *testing.T) {
	t.Run("Success", func(t *testing.T) {
		socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
			return RPCResponse{Success: true}
		})
		t.Cleanup(cleanup)

		client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
		resp, err := client.callWithContext(context.Background(), "ping", nil)
		if err != nil || resp == nil || !resp.Success {
			t.Fatalf("expected success: resp=%v err=%v", resp, err)
		}
	})

	t.Run("NonRetryable", func(t *testing.T) {
		socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
			return RPCResponse{Success: false, Error: "unauthorized"}
		})
		t.Cleanup(cleanup)

		client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
		if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil {
			t.Fatalf("expected auth error")
		}
	})

	t.Run("CancelledBeforeRetry", func(t *testing.T) {
		client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
		ctx, cancel := context.WithCancel(context.Background())
		cancel()

		if _, err := client.callWithContext(ctx, "ping", nil); err == nil {
			t.Fatalf("expected cancelled error")
		}
	})

	t.Run("CancelledDuringBackoff", func(t *testing.T) {
		client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
		defer cancel()

		if _, err := client.callWithContext(ctx, "ping", nil); err == nil {
			t.Fatalf("expected backoff cancel error")
		}
	})

	t.Run("MaxRetriesExhausted", func(t *testing.T) {
		client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
		if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil {
			t.Fatalf("expected retry error")
		}
	})
}

func TestGetStatus(t *testing.T) {
	socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true, Data: map[string]interface{}{"ok": true}}
	})
	t.Cleanup(cleanup)

	client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	data, err := client.GetStatus()
	if err != nil || data["ok"] != true {
		t.Fatalf("unexpected status: %v err=%v", data, err)
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: false, Error: "boom"}
	})
	t.Cleanup(cleanup)

	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.GetStatus(); err == nil {
		t.Fatalf("expected status error")
	}
}

func TestRegisterNodes(t *testing.T) {
	socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true, Data: map[string]interface{}{
			"nodes": []interface{}{map[string]interface{}{"name": "node1"}},
		}}
	})
	t.Cleanup(cleanup)

	client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	nodes, err := client.RegisterNodes()
	if err != nil || len(nodes) != 1 || nodes[0]["name"] != "node1" {
		t.Fatalf("unexpected nodes: %v err=%v", nodes, err)
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: false, Error: "unauthorized"}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.RegisterNodes(); err == nil {
		t.Fatalf("expected proxy error")
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true, Data: map[string]interface{}{}}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.RegisterNodes(); err == nil {
		t.Fatalf("expected missing nodes error")
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true, Data: map[string]interface{}{"nodes": "bad"}}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.RegisterNodes(); err == nil {
		t.Fatalf("expected nodes type error")
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true, Data: map[string]interface{}{"nodes": []interface{}{"bad"}}}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.RegisterNodes(); err == nil {
		t.Fatalf("expected node map error")
	}
}

func TestGetTemperature(t *testing.T) {
	socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true, Data: map[string]interface{}{"temperature": "42"}}
	})
	t.Cleanup(cleanup)

	client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	temp, err := client.GetTemperature("node1")
	if err != nil || temp != "42" {
		t.Fatalf("unexpected temp: %q err=%v", temp, err)
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: false, Error: "unauthorized"}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.GetTemperature("node1"); err == nil {
		t.Fatalf("expected proxy error")
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: false, Error: "boom"}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.GetTemperature("node1"); err == nil {
		t.Fatalf("expected unknown error")
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true, Data: map[string]interface{}{}}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.GetTemperature("node1"); err == nil {
		t.Fatalf("expected missing temperature error")
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true, Data: map[string]interface{}{"temperature": 12}}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.GetTemperature("node1"); err == nil {
		t.Fatalf("expected type error")
	}
}

func TestRequestCleanup(t *testing.T) {
	socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
		if req.Method != "request_cleanup" {
			t.Fatalf("unexpected method: %s", req.Method)
		}
		if req.Params["host"] != "node1" {
			t.Fatalf("unexpected params: %v", req.Params)
		}
		return RPCResponse{Success: true}
	})
	t.Cleanup(cleanup)

	client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if err := client.RequestCleanup("node1"); err != nil {
		t.Fatalf("unexpected cleanup error: %v", err)
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: false, Error: "boom"}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if err := client.RequestCleanup(""); err == nil {
		t.Fatalf("expected proxy error")
	}

	socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: false, Error: ""}
	})
	t.Cleanup(cleanup)
	client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if err := client.RequestCleanup(""); err == nil {
		t.Fatalf("expected rejected error")
	}

	client = &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond}
	if err := client.RequestCleanup(""); err == nil {
		t.Fatalf("expected call error")
	}
}

func TestCallWrapper(t *testing.T) {
	socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: true}
	})
	t.Cleanup(cleanup)

	client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	resp, err := client.call("ping", nil)
	if err != nil || resp == nil || !resp.Success {
		t.Fatalf("expected call success: resp=%v err=%v", resp, err)
	}
}

func TestCallWithContextRetryableResponseError(t *testing.T) {
	socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: false, Error: "ssh timeout"}
	})
	t.Cleanup(cleanup)

	client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil {
		t.Fatalf("expected retry exhaustion")
	}
}

func TestCallWithContextSuccessOnRetry(t *testing.T) {
	count := 0
	socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
		count++
		if count == 1 {
			return RPCResponse{Success: false, Error: "ssh timeout"}
		}
		return RPCResponse{Success: true}
	})
	t.Cleanup(cleanup)

	client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	resp, err := client.callWithContext(context.Background(), "ping", nil)
	if err != nil || resp == nil || !resp.Success {
		t.Fatalf("expected retry success: resp=%v err=%v", resp, err)
	}
}

func TestCallWithContextRespErrorString(t *testing.T) {
	socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse {
		return RPCResponse{Success: false, Error: strings.Repeat("x", 1)}
	})
	t.Cleanup(cleanup)

	client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond}
	if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil {
		t.Fatalf("expected error")
	}
}
@@ -200,6 +200,12 @@ func TestContains(t *testing.T) {
			substrs:  []string{"ssh error"},
			expected: true,
		},
		{
			name:     "case insensitive upper substr",
			s:        "connection refused",
			substrs:  []string{"REFUSED"},
			expected: true,
		},
		{
			name: "empty string",
			s:    "",

@@ -80,7 +80,11 @@ func (c *HTTPClient) GetTemperature(nodeHost string) (string, error) {
	req.Header.Set("Accept", "application/json")

	// Execute request with retries
	var lastErr error
	var lastErr error = &ProxyError{
		Type:      ErrorTypeUnknown,
		Message:   "all retry attempts failed",
		Retryable: false,
	}
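	// lastErr is now pre-seeded so the final return always carries a
	// classified error even when no attempt recorded one.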
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			backoff := calculateBackoff(attempt)
@@ -167,15 +171,7 @@ func (c *HTTPClient) GetTemperature(nodeHost string) (string, error) {
	}

	// All retries exhausted
	if lastErr != nil {
		return "", lastErr
	}

	return "", &ProxyError{
		Type:      ErrorTypeUnknown,
		Message:   "all retry attempts failed",
		Retryable: false,
	}
	return "", lastErr
}

// HealthCheck calls the proxy /health endpoint to verify connectivity.

@@ -2,6 +2,8 @@ package tempproxy

import (
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
@@ -360,6 +362,55 @@ func TestHTTPClient_GetTemperature_RetryOnTransportError(t *testing.T) {
	}
}

func TestHTTPClient_GetTemperature_RequestBuildError(t *testing.T) {
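	// "http://[::1" is a deliberately malformed URL (unclosed IPv6 bracket),
	// so building the *http.Request fails before any network I/O happens.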
	client := NewHTTPClient("http://[::1", "token")
	_, err := client.GetTemperature("node1")

	if err == nil {
		t.Fatal("Expected error for request creation")
	}

	proxyErr, ok := err.(*ProxyError)
	if !ok {
		t.Fatalf("Expected *ProxyError, got %T", err)
	}
	if proxyErr.Type != ErrorTypeTransport {
		t.Errorf("Type = %v, want ErrorTypeTransport", proxyErr.Type)
	}
	if !strings.Contains(proxyErr.Message, "failed to create HTTP request") {
		t.Errorf("Message = %q, want request creation failure", proxyErr.Message)
	}
}

func TestHTTPClient_GetTemperature_ReadBodyError(t *testing.T) {
	attempts := 0
	client := NewHTTPClient("https://example.com", "token")
	client.httpClient.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) {
		attempts++
		if attempts == 1 {
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       errReadCloser{},
				Header:     make(http.Header),
			}, nil
		}
		body := `{"node":"node1","temperature":"{}"}`
		return &http.Response{
			StatusCode: http.StatusOK,
			Body:       io.NopCloser(strings.NewReader(body)),
			Header:     make(http.Header),
		}, nil
	})

	_, err := client.GetTemperature("node1")
	if err != nil {
		t.Fatalf("Expected success after retry, got %v", err)
	}
	if attempts < 2 {
		t.Fatalf("Expected retry after read error, got %d attempts", attempts)
	}
}

func TestHTTPClient_HealthCheck_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/health" {
@@ -398,6 +449,26 @@ func TestHTTPClient_HealthCheck_NotConfigured(t *testing.T) {
	}
}

func TestHTTPClient_HealthCheck_RequestBuildError(t *testing.T) {
	client := NewHTTPClient("http://[::1", "token")
	err := client.HealthCheck()

	if err == nil {
		t.Fatal("Expected error for request creation")
	}

	proxyErr, ok := err.(*ProxyError)
	if !ok {
		t.Fatalf("Expected *ProxyError, got %T", err)
	}
	if proxyErr.Type != ErrorTypeTransport {
		t.Errorf("Type = %v, want ErrorTypeTransport", proxyErr.Type)
	}
	if !strings.Contains(proxyErr.Message, "failed to create HTTP request") {
		t.Errorf("Message = %q, want request creation failure", proxyErr.Message)
	}
}

func TestHTTPClient_HealthCheck_ServerError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusServiceUnavailable)
@@ -493,3 +564,19 @@ func TestHTTPClient_Fields(t *testing.T) {
		t.Errorf("timeout = %v, want 60s", client.timeout)
	}
}

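// roundTripperFunc adapts a plain function to http.RoundTripper so tests can
// stub the transport without a real server.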
type roundTripperFunc func(*http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}

type errReadCloser struct{}

func (errReadCloser) Read(p []byte) (int, error) {
	return 0, errors.New("read failed")
}

func (errReadCloser) Close() error {
	return nil
}

internal/updatedetection/manager_coverage_test.go (new file, 326 lines)
@@ -0,0 +1,326 @@
package updatedetection

import (
	"context"
	"testing"
	"time"

	"github.com/rs/zerolog"
)

func TestDefaultManagerConfig(t *testing.T) {
	cfg := DefaultManagerConfig()
	if !cfg.Enabled {
		t.Fatalf("expected enabled by default")
	}
	if cfg.CheckInterval != 6*time.Hour {
		t.Fatalf("expected check interval 6h, got %v", cfg.CheckInterval)
	}
	if cfg.AlertDelayHours != 24 {
		t.Fatalf("expected alert delay 24, got %d", cfg.AlertDelayHours)
	}
	if !cfg.EnableDockerUpdates {
		t.Fatalf("expected docker updates enabled by default")
	}
}

func TestManagerConfigAccessors(t *testing.T) {
	cfg := ManagerConfig{
		Enabled:             false,
		CheckInterval:       time.Minute,
		AlertDelayHours:     12,
		EnableDockerUpdates: false,
	}
	mgr := NewManager(cfg, zerolog.Nop())
	if mgr.Enabled() {
		t.Fatalf("expected disabled manager")
	}
	mgr.SetEnabled(true)
	if !mgr.Enabled() {
		t.Fatalf("expected enabled manager after SetEnabled")
	}
	if mgr.AlertDelayHours() != 12 {
		t.Fatalf("expected alert delay 12, got %d", mgr.AlertDelayHours())
	}
}

func TestManagerProcessDockerContainerUpdate(t *testing.T) {
	now := time.Now()
	status := &ContainerUpdateStatus{
		UpdateAvailable: true,
		CurrentDigest:   "sha256:old",
		LatestDigest:    "sha256:new",
		LastChecked:     now,
		Error:           "boom",
	}

	t.Run("Disabled", func(t *testing.T) {
		cfg := DefaultManagerConfig()
		cfg.Enabled = false
		mgr := NewManager(cfg, zerolog.Nop())
		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
		if mgr.store.Count() != 0 {
			t.Fatalf("expected no updates when disabled")
		}
	})

	t.Run("DockerUpdatesDisabled", func(t *testing.T) {
		cfg := DefaultManagerConfig()
		cfg.EnableDockerUpdates = false
		mgr := NewManager(cfg, zerolog.Nop())
		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
		if mgr.store.Count() != 0 {
			t.Fatalf("expected no updates when docker updates disabled")
		}
	})

	t.Run("NilStatus", func(t *testing.T) {
		mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", nil)
		if mgr.store.Count() != 0 {
			t.Fatalf("expected no updates for nil status")
		}
	})

	t.Run("NoUpdateDeletes", func(t *testing.T) {
		mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
		mgr.store.UpsertUpdate(&UpdateInfo{
			ID:         "docker:host-1:container-1",
			ResourceID: "container-1",
			HostID:     "host-1",
		})
		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", &ContainerUpdateStatus{
			UpdateAvailable: false,
			LastChecked:     now,
		})
		if mgr.store.Count() != 0 {
			t.Fatalf("expected update to be deleted when no update available")
		}
	})

	t.Run("UpdateAvailable", func(t *testing.T) {
		mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
		update := mgr.store.GetUpdatesForResource("container-1")
		if update == nil {
			t.Fatalf("expected update to be stored")
		}
		if update.ID != "docker:host-1:container-1" {
			t.Fatalf("unexpected update ID %q", update.ID)
		}
		if update.Error != "boom" {
			t.Fatalf("expected error to be stored")
		}
		if update.CurrentVersion != "nginx:latest" {
			t.Fatalf("expected current version to be stored")
		}
	})
}

func TestManagerCheckImageUpdate(t *testing.T) {
	ctx := context.Background()

	t.Run("Disabled", func(t *testing.T) {
		cfg := DefaultManagerConfig()
		cfg.Enabled = false
		mgr := NewManager(cfg, zerolog.Nop())
		info, err := mgr.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if info != nil {
			t.Fatalf("expected nil info when disabled")
		}
	})

	t.Run("Cached", func(t *testing.T) {
		mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
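		// Assumed cache-key shape: "nginx:latest" normalized to Docker Hub's
		// canonical registry-1.docker.io/library/... form.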
cacheKey := "registry-1.docker.io/library/nginx:latest"
|
||||
mgr.registry.cacheDigest(cacheKey, "sha256:new")
|
||||
|
||||
info, err := mgr.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info == nil || !info.UpdateAvailable {
|
||||
t.Fatalf("expected update available from cache")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerGetUpdatesAndAccessors(t *testing.T) {
|
||||
mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
|
||||
now := time.Now()
|
||||
|
||||
update1 := &UpdateInfo{
|
||||
ID: "update-1",
|
||||
ResourceID: "res-1",
|
||||
ResourceType: "docker",
|
||||
ResourceName: "nginx",
|
||||
HostID: "host-1",
|
||||
Type: UpdateTypeDockerImage,
|
||||
Severity: SeveritySecurity,
|
||||
LastChecked: now,
|
||||
}
|
||||
update2 := &UpdateInfo{
|
||||
ID: "update-2",
|
||||
ResourceID: "res-2",
|
||||
ResourceType: "vm",
|
||||
ResourceName: "vm-1",
|
||||
HostID: "host-2",
|
||||
Type: UpdateTypePackage,
|
||||
LastChecked: now.Add(-time.Minute),
|
||||
}
|
||||
mgr.store.UpsertUpdate(update1)
|
||||
mgr.store.UpsertUpdate(update2)
|
||||
|
||||
if mgr.GetTotalCount() != 2 {
|
||||
t.Fatalf("expected total count 2, got %d", mgr.GetTotalCount())
|
||||
}
|
||||
|
||||
all := mgr.GetUpdates(UpdateFilters{})
|
||||
if len(all) != 2 {
|
||||
t.Fatalf("expected 2 updates, got %d", len(all))
|
||||
}
|
||||
|
||||
filtered := mgr.GetUpdates(UpdateFilters{HostID: "host-1"})
|
||||
if len(filtered) != 1 || filtered[0].ID != "update-1" {
|
||||
t.Fatalf("expected one update for host-1")
|
||||
}
|
||||
|
||||
if len(mgr.GetUpdatesForHost("host-1")) != 1 {
|
||||
t.Fatalf("expected 1 update for host-1")
|
||||
}
|
||||
if mgr.GetUpdatesForResource("res-2") == nil {
|
||||
t.Fatalf("expected update for resource res-2")
|
||||
}
|
||||
|
||||
summary := mgr.GetSummary()
|
||||
if summary["host-1"].TotalUpdates != 1 {
|
||||
t.Fatalf("expected summary for host-1")
|
||||
}
|
||||
|
||||
mgr.AddRegistryConfig(RegistryConfig{Host: "example.com", Username: "user"})
|
||||
if _, ok := mgr.registry.configs["example.com"]; !ok {
|
||||
t.Fatalf("expected registry config to be stored")
|
||||
}
|
||||
|
||||
mgr.DeleteUpdatesForHost("host-1")
|
||||
if len(mgr.GetUpdatesForHost("host-1")) != 0 {
|
||||
t.Fatalf("expected updates removed for host-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCleanupStale(t *testing.T) {
	mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
	now := time.Now()

	mgr.store.UpsertUpdate(&UpdateInfo{
		ID:          "stale",
		ResourceID:  "res-1",
		HostID:      "host-1",
		LastChecked: now.Add(-2 * time.Hour),
	})
	mgr.store.UpsertUpdate(&UpdateInfo{
		ID:          "fresh",
		ResourceID:  "res-2",
		HostID:      "host-1",
		LastChecked: now,
	})

	removed := mgr.CleanupStale(time.Hour)
	if removed != 1 {
		t.Fatalf("expected 1 stale update removed, got %d", removed)
	}
	if mgr.GetTotalCount() != 1 {
		t.Fatalf("expected one update remaining")
	}
}

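// TestManagerGetUpdatesReadyForAlert checks the alert gating: an update must
// have been first detected more than alertDelayHours ago and carry no check error.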
func TestManagerGetUpdatesReadyForAlert(t *testing.T) {
	mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
	mgr.alertDelayHours = 1
	now := time.Now()

	mgr.store.UpsertUpdate(&UpdateInfo{
		ID:            "ready",
		ResourceID:    "res-1",
		HostID:        "host-1",
		FirstDetected: now.Add(-2 * time.Hour),
		LastChecked:   now,
	})
	mgr.store.UpsertUpdate(&UpdateInfo{
		ID:            "recent",
		ResourceID:    "res-2",
		HostID:        "host-1",
		FirstDetected: now.Add(-30 * time.Minute),
		LastChecked:   now,
	})
	mgr.store.UpsertUpdate(&UpdateInfo{
		ID:            "error",
		ResourceID:    "res-3",
		HostID:        "host-1",
		FirstDetected: now.Add(-2 * time.Hour),
		LastChecked:   now,
		Error:         "rate limited",
	})

	ready := mgr.GetUpdatesReadyForAlert()
	if len(ready) != 1 || ready[0].ID != "ready" {
		t.Fatalf("expected only ready update, got %d", len(ready))
	}
}

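// TestUpdateFilters walks each UpdateFilters field through a mismatch and then
// a full match against a single update.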
func TestUpdateFilters(t *testing.T) {
	update := &UpdateInfo{
		HostID:       "host-1",
		ResourceType: "docker",
		Type:         UpdateTypeDockerImage,
		Severity:     SeveritySecurity,
		Error:        "",
	}

	if !(&UpdateFilters{}).IsEmpty() {
		t.Fatalf("expected empty filters to be empty")
	}
	if (&UpdateFilters{HostID: "host-1"}).IsEmpty() {
		t.Fatalf("expected non-empty filters")
	}

	if (&UpdateFilters{HostID: "other"}).Matches(update) {
		t.Fatalf("expected host mismatch to fail")
	}
	if (&UpdateFilters{ResourceType: "vm"}).Matches(update) {
		t.Fatalf("expected resource type mismatch to fail")
	}
	if (&UpdateFilters{UpdateType: UpdateTypePackage}).Matches(update) {
		t.Fatalf("expected update type mismatch to fail")
	}
	if (&UpdateFilters{Severity: SeverityBugfix}).Matches(update) {
		t.Fatalf("expected severity mismatch to fail")
	}

	hasError := true
	if (&UpdateFilters{HasError: &hasError}).Matches(update) {
		t.Fatalf("expected error filter to fail")
	}

	update.Error = "boom"
	hasError = false
	if (&UpdateFilters{HasError: &hasError}).Matches(update) {
		t.Fatalf("expected error=false filter to fail")
	}

	update.Error = ""
	hasError = false
	filters := UpdateFilters{
		HostID:       "host-1",
		ResourceType: "docker",
		UpdateType:   UpdateTypeDockerImage,
		Severity:     SeveritySecurity,
		HasError:     &hasError,
	}
	if !filters.Matches(update) {
		t.Fatalf("expected filters to match update")
	}
}

465 internal/updatedetection/registry_coverage_test.go Normal file
@@ -0,0 +1,465 @@
package updatedetection

import (
	"bytes"
	"context"
	"errors"
	"io"
	"net/http"
	"strings"
	"testing"
	"time"

	"github.com/rs/zerolog"
)

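// roundTripperFunc adapts a plain function into an http.RoundTripper so the
// tests can stub registry transport behavior.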
type roundTripperFunc func(*http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}

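// errorReader always fails, simulating a response body that cannot be read.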
type errorReader struct {
	err error
}

func (e *errorReader) Read(p []byte) (int, error) {
	return 0, e.err
}

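// newResponse builds a minimal *http.Response for the stubbed transports,
// defaulting the header map and body when nil. The tests below pair it with
// roundTripperFunc, along the lines of:
//
//	r.httpClient = &http.Client{
//		Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
//			return newResponse(http.StatusOK, nil, nil), nil
//		}),
//	}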
func newResponse(status int, headers http.Header, body io.Reader) *http.Response {
	if headers == nil {
		headers = http.Header{}
	}
	if body == nil {
		body = bytes.NewReader(nil)
	}
	return &http.Response{
		StatusCode: status,
		Status:     http.StatusText(status),
		Header:     headers,
		Body:       io.NopCloser(body),
	}
}

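// TestRegistryCheckerConfig verifies constructor defaults and that
// AddRegistryConfig stores per-host credentials.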
func TestRegistryCheckerConfig(t *testing.T) {
	r := NewRegistryChecker(zerolog.Nop())
	if r.httpClient == nil || r.cache == nil || r.configs == nil {
		t.Fatalf("expected registry checker to initialize")
	}

	cfg := RegistryConfig{Host: "example.com", Username: "user", Password: "pass", Insecure: true}
	r.AddRegistryConfig(cfg)

	r.mu.RLock()
	stored := r.configs["example.com"]
	r.mu.RUnlock()
	if stored.Username != "user" || !stored.Insecure {
		t.Fatalf("expected config to be stored")
	}
}

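// TestRegistryCache covers digest and error caching, lookups of missing and
// expired keys, and removal of expired entries by CleanupCache.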
func TestRegistryCache(t *testing.T) {
	r := NewRegistryChecker(zerolog.Nop())
	r.cacheDigest("key-digest", "sha256:abc")
	r.cacheError("key-error", "boom")

	if r.CacheSize() != 2 {
		t.Fatalf("expected cache size 2, got %d", r.CacheSize())
	}
	if entry := r.getCached("key-digest"); entry == nil || entry.digest != "sha256:abc" {
		t.Fatalf("expected cached digest")
	}
	if entry := r.getCached("missing"); entry != nil {
		t.Fatalf("expected missing cache entry to be nil")
	}

	r.cache.entries["expired"] = cacheEntry{
		digest:    "old",
		expiresAt: time.Now().Add(-time.Minute),
	}
	if entry := r.getCached("expired"); entry != nil {
		t.Fatalf("expected expired cache entry to be nil")
	}

	r.CleanupCache()
	if r.CacheSize() != 2 {
		t.Fatalf("expected expired cache entry to be removed")
	}
}

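// TestRegistryCheckImageUpdate exercises digest-pinned references, cache hits
// for both digests and errors, and caching of fetch failures and successes.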
func TestRegistryCheckImageUpdate(t *testing.T) {
	ctx := context.Background()

	t.Run("DigestPinned", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		info, err := r.CheckImageUpdate(ctx, "nginx@sha256:abc", "sha256:old")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if info.Error == "" {
			t.Fatalf("expected digest-pinned error")
		}
	})

	t.Run("CachedError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		key := "registry-1.docker.io/library/nginx:latest"
		r.cache.entries[key] = cacheEntry{
			err:       "cached error",
			expiresAt: time.Now().Add(time.Hour),
		}

		info, err := r.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if info.Error != "cached error" {
			t.Fatalf("expected cached error")
		}
	})

	t.Run("CachedDigest", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		key := "registry-1.docker.io/library/nginx:latest"
		r.cache.entries[key] = cacheEntry{
			digest:    "sha256:new",
			expiresAt: time.Now().Add(time.Hour),
		}

		info, err := r.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if !info.UpdateAvailable || info.LatestDigest != "sha256:new" {
			t.Fatalf("expected update available from cache")
		}
	})

	t.Run("FetchErrorCaches", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return newResponse(http.StatusInternalServerError, nil, nil), nil
			}),
		}

		info, err := r.CheckImageUpdate(ctx, "example.com/repo:tag", "sha256:old")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if info.Error == "" {
			t.Fatalf("expected error from fetch")
		}
		if cached := r.getCached("example.com/repo:tag"); cached == nil || cached.err == "" {
			t.Fatalf("expected error to be cached")
		}
	})

	t.Run("FetchSuccessCaches", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				headers := http.Header{"Docker-Content-Digest": []string{"sha256:new"}}
				return newResponse(http.StatusOK, headers, nil), nil
			}),
		}

		info, err := r.CheckImageUpdate(ctx, "example.com/repo:tag", "sha256:old")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if !info.UpdateAvailable {
			t.Fatalf("expected update available")
		}
		if cached := r.getCached("example.com/repo:tag"); cached == nil || cached.digest == "" {
			t.Fatalf("expected digest to be cached")
		}
	})
}

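// TestFetchDigest drives fetchDigest through auth, scheme, request, and
// status-code failures, plus both digest header sources and the missing-digest case.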
func TestFetchDigest(t *testing.T) {
	ctx := context.Background()

	t.Run("AuthError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return nil, errors.New("auth fail")
			}),
		}

		if _, err := r.fetchDigest(ctx, "registry-1.docker.io", "library/nginx", "latest"); err == nil {
			t.Fatalf("expected auth error")
		}
	})

	t.Run("InsecureScheme", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.AddRegistryConfig(RegistryConfig{Host: "insecure.local", Insecure: true})
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
				if req.URL.Scheme != "http" {
					t.Fatalf("expected http scheme, got %q", req.URL.Scheme)
				}
				headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}}
				return newResponse(http.StatusOK, headers, nil), nil
			}),
		}

		if _, err := r.fetchDigest(ctx, "insecure.local", "repo", "latest"); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	})

	t.Run("RequestError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		if _, err := r.fetchDigest(ctx, "bad host", "repo", "latest"); err == nil {
			t.Fatalf("expected request creation error")
		}
	})

	t.Run("TokenHeader", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.AddRegistryConfig(RegistryConfig{Host: "ghcr.io", Password: "pat"})
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
				if got := req.Header.Get("Authorization"); got != "Bearer pat" {
					t.Fatalf("expected bearer token, got %q", got)
				}
				headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}}
				return newResponse(http.StatusOK, headers, nil), nil
			}),
		}

		if _, err := r.fetchDigest(ctx, "ghcr.io", "owner/repo", "latest"); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	})

	t.Run("DoError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return nil, errors.New("network")
			}),
		}

		if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
			t.Fatalf("expected request error")
		}
	})

	t.Run("StatusUnauthorized", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return newResponse(http.StatusUnauthorized, nil, nil), nil
			}),
		}

		if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
			t.Fatalf("expected unauthorized error")
		}
	})

	t.Run("StatusNotFound", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return newResponse(http.StatusNotFound, nil, nil), nil
			}),
		}

		if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
			t.Fatalf("expected not found error")
		}
	})

	t.Run("StatusRateLimited", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return newResponse(http.StatusTooManyRequests, nil, nil), nil
			}),
		}

		if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
			t.Fatalf("expected rate limit error")
		}
	})

	t.Run("StatusOtherError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return newResponse(http.StatusInternalServerError, nil, nil), nil
			}),
		}

		if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
			t.Fatalf("expected registry error")
		}
	})

	t.Run("DigestHeader", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}}
				return newResponse(http.StatusOK, headers, nil), nil
			}),
		}

		digest, err := r.fetchDigest(ctx, "example.com", "repo", "latest")
		if err != nil || digest != "sha256:abc" {
			t.Fatalf("expected digest, got %q err %v", digest, err)
		}
	})

	t.Run("EtagHeader", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				headers := http.Header{"Etag": []string{`"sha256:etag"`}}
				return newResponse(http.StatusOK, headers, nil), nil
			}),
		}

		digest, err := r.fetchDigest(ctx, "example.com", "repo", "latest")
		if err != nil || digest != "sha256:etag" {
			t.Fatalf("expected etag digest, got %q err %v", digest, err)
		}
	})

	t.Run("NoDigest", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return newResponse(http.StatusOK, nil, nil), nil
			}),
		}

		if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil {
			t.Fatalf("expected missing digest error")
		}
	})
}

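// TestGetAuthToken covers the Docker Hub token exchange and its failure modes,
// GHCR bearer tokens, and registries that use basic auth or no auth at all.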
func TestGetAuthToken(t *testing.T) {
	ctx := context.Background()

	t.Run("DockerHubSuccess", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.AddRegistryConfig(RegistryConfig{Host: "registry-1.docker.io", Username: "user", Password: "pass"})
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
				if !strings.HasPrefix(req.Header.Get("Authorization"), "Basic ") {
					t.Fatalf("expected basic auth header")
				}
				body := `{"token":"tok"}`
				return newResponse(http.StatusOK, nil, strings.NewReader(body)), nil
			}),
		}

		token, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx")
		if err != nil || token != "tok" {
			t.Fatalf("expected token, got %q err %v", token, err)
		}
	})

	t.Run("DockerHubRequestError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "bad\nrepo"); err == nil {
			t.Fatalf("expected request error")
		}
	})

	t.Run("DockerHubDoError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return nil, errors.New("network")
			}),
		}

		if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil {
			t.Fatalf("expected network error")
		}
	})

	t.Run("DockerHubStatusError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				return newResponse(http.StatusInternalServerError, nil, nil), nil
			}),
		}

		if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil {
			t.Fatalf("expected status error")
		}
	})

	t.Run("DockerHubReadError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				body := &errorReader{err: errors.New("read")}
				return newResponse(http.StatusOK, nil, body), nil
			}),
		}

		if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil {
			t.Fatalf("expected read error")
		}
	})

	t.Run("DockerHubJSONError", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.httpClient = &http.Client{
			Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) {
				body := "{"
				return newResponse(http.StatusOK, nil, strings.NewReader(body)), nil
			}),
		}

		if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil {
			t.Fatalf("expected json error")
		}
	})

	t.Run("GHCRToken", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.AddRegistryConfig(RegistryConfig{Host: "ghcr.io", Password: "pat"})
		token, err := r.getAuthToken(ctx, "ghcr.io", "owner/repo")
		if err != nil || token != "pat" {
			t.Fatalf("expected ghcr token")
		}
	})

	t.Run("GHCRAnonymous", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		token, err := r.getAuthToken(ctx, "ghcr.io", "owner/repo")
		if err != nil || token != "" {
			t.Fatalf("expected empty ghcr token")
		}
	})

	t.Run("OtherRegistryBasicAuth", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		r.AddRegistryConfig(RegistryConfig{Host: "example.com", Username: "user", Password: "pass"})
		token, err := r.getAuthToken(ctx, "example.com", "repo")
		if err != nil || token != "" {
			t.Fatalf("expected empty token for basic auth registry")
		}
	})

	t.Run("OtherRegistryNoAuth", func(t *testing.T) {
		r := NewRegistryChecker(zerolog.Nop())
		token, err := r.getAuthToken(ctx, "example.com", "repo")
		if err != nil || token != "" {
			t.Fatalf("expected empty token for registry without auth")
		}
	})
}

43 internal/updatedetection/store_coverage_test.go Normal file
@@ -0,0 +1,43 @@
package updatedetection

import "testing"

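// TestStore_DeleteUpdateNotFound verifies that deleting an unknown update ID is a no-op.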
func TestStore_DeleteUpdateNotFound(t *testing.T) {
	store := NewStore()
	store.DeleteUpdate("missing")
	if store.Count() != 0 {
		t.Fatalf("expected empty store")
	}
}

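// TestStore_DeleteUpdatesForResourceMissing verifies that deleting by an unknown resource ID is a no-op.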
func TestStore_DeleteUpdatesForResourceMissing(t *testing.T) {
	store := NewStore()
	store.DeleteUpdatesForResource("missing")
	if store.Count() != 0 {
		t.Fatalf("expected empty store")
	}
}

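// TestStore_DeleteUpdatesForResourceNilUpdate guards against a nil entry left
// in the updates map: the byResource index entry must still be removed.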
func TestStore_DeleteUpdatesForResourceNilUpdate(t *testing.T) {
	store := NewStore()
	store.byResource["res-1"] = "update-1"
	store.updates["update-1"] = nil

	store.DeleteUpdatesForResource("res-1")
	if _, ok := store.byResource["res-1"]; ok {
		t.Fatalf("expected byResource entry to be removed")
	}
}

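// TestStore_CountForHost counts updates per host and returns zero for unknown hosts.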
func TestStore_CountForHost(t *testing.T) {
	store := NewStore()
	store.UpsertUpdate(&UpdateInfo{ID: "update-1", ResourceID: "res-1", HostID: "host-1"})
	store.UpsertUpdate(&UpdateInfo{ID: "update-2", ResourceID: "res-2", HostID: "host-1"})

	if store.CountForHost("host-1") != 2 {
		t.Fatalf("expected count 2 for host-1")
	}
	if store.CountForHost("missing") != 0 {
		t.Fatalf("expected count 0 for missing host")
	}
}