diff --git a/internal/agentbinaries/host_agent.go b/internal/agentbinaries/host_agent.go index b51d73543..e75b43086 100644 --- a/internal/agentbinaries/host_agent.go +++ b/internal/agentbinaries/host_agent.go @@ -51,6 +51,25 @@ var requiredHostAgentBinaries = []HostAgentBinary{ var downloadMu sync.Mutex +var ( + httpClient = &http.Client{Timeout: 2 * time.Minute} + downloadURLForVersion = func(version string) string { + return fmt.Sprintf("https://github.com/rcourtman/Pulse/releases/download/%[1]s/pulse-%[1]s.tar.gz", version) + } + downloadAndInstallHostAgentBinariesFn = DownloadAndInstallHostAgentBinaries + findMissingHostAgentBinariesFn = findMissingHostAgentBinaries + mkdirAllFn = os.MkdirAll + createTempFn = os.CreateTemp + removeFn = os.Remove + openFileFn = os.Open + openFileModeFn = os.OpenFile + renameFn = os.Rename + symlinkFn = os.Symlink + copyFn = io.Copy + chmodFileFn = func(f *os.File, mode os.FileMode) error { return f.Chmod(mode) } + closeFileFn = func(f *os.File) error { return f.Close() } +) + // HostAgentSearchPaths returns the directories to search for host agent binaries. func HostAgentSearchPaths() []string { primary := strings.TrimSpace(os.Getenv("PULSE_BIN_DIR")) @@ -77,7 +96,7 @@ func HostAgentSearchPaths() []string { // The returned map contains any binaries that remain missing after the attempt. func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary { binDirs := HostAgentSearchPaths() - missing := findMissingHostAgentBinaries(binDirs) + missing := findMissingHostAgentBinariesFn(binDirs) if len(missing) == 0 { return nil } @@ -86,7 +105,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary { defer downloadMu.Unlock() // Re-check after acquiring the lock in case another goroutine restored them. - missing = findMissingHostAgentBinaries(binDirs) + missing = findMissingHostAgentBinariesFn(binDirs) if len(missing) == 0 { return nil } @@ -101,7 +120,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary { Strs("missing_platforms", missingPlatforms). Msg("Host agent binaries missing - attempting to download bundle from GitHub release") - if err := DownloadAndInstallHostAgentBinaries(version, binDirs[0]); err != nil { + if err := downloadAndInstallHostAgentBinariesFn(version, binDirs[0]); err != nil { log.Error(). Err(err). Str("target_dir", binDirs[0]). 
@@ -110,7 +129,7 @@ func EnsureHostAgentBinaries(version string) map[string]HostAgentBinary { return missing } - if remaining := findMissingHostAgentBinaries(binDirs); len(remaining) > 0 { + if remaining := findMissingHostAgentBinariesFn(binDirs); len(remaining) > 0 { stillMissing := make([]string, 0, len(remaining)) for key := range remaining { stillMissing = append(stillMissing, key) @@ -133,19 +152,18 @@ func DownloadAndInstallHostAgentBinaries(version string, targetDir string) error return fmt.Errorf("cannot download host agent bundle for non-release version %q", version) } - if err := os.MkdirAll(targetDir, 0o755); err != nil { + if err := mkdirAllFn(targetDir, 0o755); err != nil { return fmt.Errorf("failed to ensure bin directory %s: %w", targetDir, err) } - url := fmt.Sprintf("https://github.com/rcourtman/Pulse/releases/download/%[1]s/pulse-%[1]s.tar.gz", normalizedVersion) - tempFile, err := os.CreateTemp("", "pulse-host-agent-*.tar.gz") + url := downloadURLForVersion(normalizedVersion) + tempFile, err := createTempFn("", "pulse-host-agent-*.tar.gz") if err != nil { return fmt.Errorf("failed to create temporary archive file: %w", err) } - defer os.Remove(tempFile.Name()) + defer removeFn(tempFile.Name()) - client := &http.Client{Timeout: 2 * time.Minute} - resp, err := client.Get(url) + resp, err := httpClient.Get(url) if err != nil { return fmt.Errorf("failed to download host agent bundle from %s: %w", url, err) } @@ -160,7 +178,7 @@ func DownloadAndInstallHostAgentBinaries(version string, targetDir string) error return fmt.Errorf("failed to save host agent bundle: %w", err) } - if err := tempFile.Close(); err != nil { + if err := closeFileFn(tempFile); err != nil { return fmt.Errorf("failed to close temporary bundle file: %w", err) } @@ -204,7 +222,7 @@ func normalizeVersionTag(version string) string { } func extractHostAgentBinaries(archivePath, targetDir string) error { - file, err := os.Open(archivePath) + file, err := openFileFn(archivePath) if err != nil { return fmt.Errorf("failed to open host agent bundle: %w", err) } @@ -232,10 +250,6 @@ func extractHostAgentBinaries(archivePath, targetDir string) error { return fmt.Errorf("failed to read host agent bundle: %w", err) } - if header == nil { - continue - } - if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeSymlink { continue } @@ -265,10 +279,10 @@ func extractHostAgentBinaries(archivePath, targetDir string) error { } for _, link := range symlinks { - if err := os.Remove(link.path); err != nil && !os.IsNotExist(err) { + if err := removeFn(link.path); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to replace existing symlink %s: %w", link.path, err) } - if err := os.Symlink(link.target, link.path); err != nil { + if err := symlinkFn(link.target, link.path); err != nil { // Fallback: copy the referenced file if symlinks are not permitted source := filepath.Join(targetDir, link.target) if err := copyHostAgentFile(source, link.path); err != nil { @@ -281,31 +295,31 @@ func extractHostAgentBinaries(archivePath, targetDir string) error { } func writeHostAgentFile(destination string, reader io.Reader, mode os.FileMode) error { - if err := os.MkdirAll(filepath.Dir(destination), 0o755); err != nil { + if err := mkdirAllFn(filepath.Dir(destination), 0o755); err != nil { return fmt.Errorf("failed to create directory for %s: %w", destination, err) } - tmpFile, err := os.CreateTemp(filepath.Dir(destination), "pulse-host-agent-*") + tmpFile, err := createTempFn(filepath.Dir(destination), 
"pulse-host-agent-*") if err != nil { return fmt.Errorf("failed to create temporary file for %s: %w", destination, err) } - defer os.Remove(tmpFile.Name()) + defer removeFn(tmpFile.Name()) - if _, err := io.Copy(tmpFile, reader); err != nil { - tmpFile.Close() + if _, err := copyFn(tmpFile, reader); err != nil { + closeFileFn(tmpFile) return fmt.Errorf("failed to extract %s: %w", destination, err) } - if err := tmpFile.Chmod(normalizeExecutableMode(mode)); err != nil { - tmpFile.Close() + if err := chmodFileFn(tmpFile, normalizeExecutableMode(mode)); err != nil { + closeFileFn(tmpFile) return fmt.Errorf("failed to set permissions on %s: %w", destination, err) } - if err := tmpFile.Close(); err != nil { + if err := closeFileFn(tmpFile); err != nil { return fmt.Errorf("failed to finalize %s: %w", destination, err) } - if err := os.Rename(tmpFile.Name(), destination); err != nil { + if err := renameFn(tmpFile.Name(), destination); err != nil { return fmt.Errorf("failed to install %s: %w", destination, err) } @@ -313,23 +327,23 @@ func writeHostAgentFile(destination string, reader io.Reader, mode os.FileMode) } func copyHostAgentFile(source, destination string) error { - src, err := os.Open(source) + src, err := openFileFn(source) if err != nil { return fmt.Errorf("failed to open %s for fallback copy: %w", source, err) } defer src.Close() - if err := os.MkdirAll(filepath.Dir(destination), 0o755); err != nil { + if err := mkdirAllFn(filepath.Dir(destination), 0o755); err != nil { return fmt.Errorf("failed to prepare directory for %s: %w", destination, err) } - dst, err := os.OpenFile(destination, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + dst, err := openFileModeFn(destination, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) if err != nil { return fmt.Errorf("failed to create fallback copy %s: %w", destination, err) } defer dst.Close() - if _, err := io.Copy(dst, src); err != nil { + if _, err := copyFn(dst, src); err != nil { return fmt.Errorf("failed to copy %s to %s: %w", source, destination, err) } diff --git a/internal/agentbinaries/host_agent_coverage_test.go b/internal/agentbinaries/host_agent_coverage_test.go new file mode 100644 index 000000000..16d1fdc05 --- /dev/null +++ b/internal/agentbinaries/host_agent_coverage_test.go @@ -0,0 +1,698 @@ +package agentbinaries + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +type tarEntry struct { + name string + body []byte + mode int64 + typeflag byte + linkname string +} + +type errorReader struct{} + +func (e *errorReader) Read([]byte) (int, error) { + return 0, errors.New("read fail") +} + +func buildTarGz(t *testing.T, entries []tarEntry) []byte { + t.Helper() + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + for _, entry := range entries { + typeflag := entry.typeflag + if typeflag == 0 { + typeflag = tar.TypeReg + } + size := int64(len(entry.body)) + if typeflag != tar.TypeReg { + size = 0 + } + hdr := &tar.Header{ + Name: entry.name, + Mode: entry.mode, + Size: size, + Typeflag: typeflag, + Linkname: entry.linkname, + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("WriteHeader: %v", err) + } + if typeflag == tar.TypeReg { + if _, err := tw.Write(entry.body); err != nil { + t.Fatalf("Write: %v", err) + } + } + } + + 
if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + return buf.Bytes() +} + +func saveHostAgentHooks() func() { + origRequired := requiredHostAgentBinaries + origDownloadFn := downloadAndInstallHostAgentBinariesFn + origFindMissing := findMissingHostAgentBinariesFn + origURL := downloadURLForVersion + origClient := httpClient + origMkdirAll := mkdirAllFn + origCreateTemp := createTempFn + origRemove := removeFn + origOpen := openFileFn + origOpenMode := openFileModeFn + origRename := renameFn + origSymlink := symlinkFn + origCopy := copyFn + origChmod := chmodFileFn + origClose := closeFileFn + + return func() { + requiredHostAgentBinaries = origRequired + downloadAndInstallHostAgentBinariesFn = origDownloadFn + findMissingHostAgentBinariesFn = origFindMissing + downloadURLForVersion = origURL + httpClient = origClient + mkdirAllFn = origMkdirAll + createTempFn = origCreateTemp + removeFn = origRemove + openFileFn = origOpen + openFileModeFn = origOpenMode + renameFn = origRename + symlinkFn = origSymlink + copyFn = origCopy + chmodFileFn = origChmod + closeFileFn = origClose + } +} + +func TestEnsureHostAgentBinaries_NoMissing(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmpDir := t.TempDir() + requiredHostAgentBinaries = []HostAgentBinary{ + {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}}, + } + + if err := os.WriteFile(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64"), []byte("bin"), 0o755); err != nil { + t.Fatalf("write file: %v", err) + } + + origEnv := os.Getenv("PULSE_BIN_DIR") + t.Cleanup(func() { + if origEnv == "" { + os.Unsetenv("PULSE_BIN_DIR") + } else { + os.Setenv("PULSE_BIN_DIR", origEnv) + } + }) + os.Setenv("PULSE_BIN_DIR", tmpDir) + + if missing := EnsureHostAgentBinaries("v1.0.0"); missing != nil { + t.Fatalf("expected no missing binaries, got %v", missing) + } +} + +func TestEnsureHostAgentBinaries_DownloadError(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmpDir := t.TempDir() + requiredHostAgentBinaries = []HostAgentBinary{ + {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}}, + } + downloadAndInstallHostAgentBinariesFn = func(string, string) error { + return errors.New("download failed") + } + + origEnv := os.Getenv("PULSE_BIN_DIR") + t.Cleanup(func() { + if origEnv == "" { + os.Unsetenv("PULSE_BIN_DIR") + } else { + os.Setenv("PULSE_BIN_DIR", origEnv) + } + }) + os.Setenv("PULSE_BIN_DIR", tmpDir) + + missing := EnsureHostAgentBinaries("v1.0.0") + if len(missing) != 1 { + t.Fatalf("expected missing map, got %v", missing) + } +} + +func TestEnsureHostAgentBinaries_StillMissing(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmpDir := t.TempDir() + requiredHostAgentBinaries = []HostAgentBinary{ + {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}}, + } + downloadAndInstallHostAgentBinariesFn = func(string, string) error { return nil } + + origEnv := os.Getenv("PULSE_BIN_DIR") + t.Cleanup(func() { + if origEnv == "" { + os.Unsetenv("PULSE_BIN_DIR") + } else { + os.Setenv("PULSE_BIN_DIR", origEnv) + } + }) + os.Setenv("PULSE_BIN_DIR", tmpDir) + + missing := EnsureHostAgentBinaries("v1.0.0") + if len(missing) != 1 { + t.Fatalf("expected still missing") + } +} + +func TestEnsureHostAgentBinaries_RecheckAfterLock(t *testing.T) { + restore := saveHostAgentHooks() + 
t.Cleanup(restore) + + requiredHostAgentBinaries = []HostAgentBinary{ + {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}}, + } + + calls := 0 + findMissingHostAgentBinariesFn = func([]string) map[string]HostAgentBinary { + calls++ + if calls == 1 { + return map[string]HostAgentBinary{ + "linux-amd64": requiredHostAgentBinaries[0], + } + } + return nil + } + + origEnv := os.Getenv("PULSE_BIN_DIR") + t.Cleanup(func() { + if origEnv == "" { + os.Unsetenv("PULSE_BIN_DIR") + } else { + os.Setenv("PULSE_BIN_DIR", origEnv) + } + }) + os.Setenv("PULSE_BIN_DIR", t.TempDir()) + + if result := EnsureHostAgentBinaries("v1.0.0"); result != nil { + t.Fatalf("expected nil after recheck, got %v", result) + } +} + +func TestEnsureHostAgentBinaries_RestoreSuccess(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmpDir := t.TempDir() + requiredHostAgentBinaries = []HostAgentBinary{ + {Platform: "linux", Arch: "amd64", Filenames: []string{"pulse-host-agent-linux-amd64"}}, + } + downloadAndInstallHostAgentBinariesFn = func(string, string) error { + return os.WriteFile(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64"), []byte("bin"), 0o755) + } + + origEnv := os.Getenv("PULSE_BIN_DIR") + t.Cleanup(func() { + if origEnv == "" { + os.Unsetenv("PULSE_BIN_DIR") + } else { + os.Setenv("PULSE_BIN_DIR", origEnv) + } + }) + os.Setenv("PULSE_BIN_DIR", tmpDir) + + if missing := EnsureHostAgentBinaries("v1.0.0"); missing != nil { + t.Fatalf("expected restore success") + } +} + +func TestDownloadAndInstallHostAgentBinariesErrors(t *testing.T) { + t.Run("MkdirAllError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") } + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected mkdir error") + } + }) + + t.Run("CreateTempError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + createTempFn = func(string, string) (*os.File, error) { return nil, errors.New("temp fail") } + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected temp error") + } + }) + + t.Run("DownloadError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("network") + }), + } + downloadURLForVersion = func(string) string { return "http://example/bundle.tar.gz" } + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected download error") + } + }) + + t.Run("StatusError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("bad")) + })) + t.Cleanup(server.Close) + + httpClient = server.Client() + downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" } + + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected status error") + } + }) + + t.Run("CopyError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("data")) + })) + t.Cleanup(server.Close) + + 
httpClient = server.Client() + downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" } + copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") } + + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected copy error") + } + }) + + t.Run("CopyReadError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + body := io.NopCloser(&errorReader{}) + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: body, + }, nil + }), + } + downloadURLForVersion = func(string) string { return "http://example/bundle.tar.gz" } + + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected copy read error") + } + }) + + t.Run("CloseError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("data")) + })) + t.Cleanup(server.Close) + + httpClient = server.Client() + downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" } + closeFileFn = func(*os.File) error { return errors.New("close fail") } + + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected close error") + } + }) + + t.Run("ExtractError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("not a tarball")) + })) + t.Cleanup(server.Close) + + httpClient = server.Client() + downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" } + + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected extract error") + } + }) +} + +func TestDownloadAndInstallHostAgentBinariesSuccess(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmpDir := t.TempDir() + payload := buildTarGz(t, []tarEntry{ + {name: "README.md", body: []byte("skip"), mode: 0o644}, + {name: "bin/", typeflag: tar.TypeDir, mode: 0o755}, + {name: "bin/not-agent.txt", body: []byte("skip"), mode: 0o644}, + {name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644}, + {name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"}, + }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(payload) + })) + t.Cleanup(server.Close) + + httpClient = server.Client() + downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" } + symlinkFn = func(string, string) error { return errors.New("no symlink") } + + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", tmpDir); err != nil { + t.Fatalf("expected success, got %v", err) + } + + if _, err := os.Stat(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64")); err != nil { + t.Fatalf("expected binary installed: %v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "pulse-host-agent-linux-amd64.exe")); err != nil { + t.Fatalf("expected symlink fallback copy: %v", err) + } +} + +func TestExtractHostAgentBinariesErrors(t *testing.T) { + t.Run("OpenError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + 
openFileFn = func(string) (*os.File, error) { return nil, errors.New("open fail") } + if err := extractHostAgentBinaries("missing.tar.gz", t.TempDir()); err == nil { + t.Fatalf("expected open error") + } + }) + + t.Run("GzipError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmp := filepath.Join(t.TempDir(), "bad.gz") + if err := os.WriteFile(tmp, []byte("bad"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := extractHostAgentBinaries(tmp, t.TempDir()); err == nil { + t.Fatalf("expected gzip error") + } + }) + + t.Run("TarReadError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + _, _ = gzw.Write([]byte("not a tar")) + _ = gzw.Close() + + tmp := filepath.Join(t.TempDir(), "bad.tar.gz") + if err := os.WriteFile(tmp, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := extractHostAgentBinaries(tmp, t.TempDir()); err == nil { + t.Fatalf("expected tar read error") + } + }) +} + +func TestExtractHostAgentBinariesRemoveError(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmpDir := t.TempDir() + payload := buildTarGz(t, []tarEntry{ + {name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644}, + {name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"}, + }) + + archive := filepath.Join(t.TempDir(), "bundle.tar.gz") + if err := os.WriteFile(archive, payload, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + removeFn = func(string) error { return errors.New("remove fail") } + if err := extractHostAgentBinaries(archive, tmpDir); err == nil { + t.Fatalf("expected remove error") + } +} + +func TestExtractHostAgentBinariesSymlinkCopyError(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmpDir := t.TempDir() + payload := buildTarGz(t, []tarEntry{ + {name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644}, + {name: "bin/pulse-host-agent-linux-amd64.exe", typeflag: tar.TypeSymlink, linkname: "pulse-host-agent-linux-amd64"}, + }) + archive := filepath.Join(t.TempDir(), "bundle.tar.gz") + if err := os.WriteFile(archive, payload, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + symlinkFn = func(string, string) error { return errors.New("no symlink") } + openFileFn = func(path string) (*os.File, error) { + if path == archive { + return os.Open(path) + } + return nil, errors.New("open fail") + } + + if err := extractHostAgentBinaries(archive, tmpDir); err == nil { + t.Fatalf("expected symlink fallback error") + } +} + +func TestWriteHostAgentFileErrors(t *testing.T) { + t.Run("MkdirAllError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") } + if err := writeHostAgentFile("dest", strings.NewReader("data"), 0o644); err == nil { + t.Fatalf("expected mkdir error") + } + }) + + t.Run("CreateTempError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + createTempFn = func(string, string) (*os.File, error) { return nil, errors.New("temp fail") } + if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil { + t.Fatalf("expected temp error") + } + }) + + t.Run("CopyError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + copyFn = func(io.Writer, io.Reader) (int64, 
error) { return 0, errors.New("copy fail") } + if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil { + t.Fatalf("expected copy error") + } + }) + + t.Run("ChmodError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + chmodFileFn = func(*os.File, os.FileMode) error { return errors.New("chmod fail") } + if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil { + t.Fatalf("expected chmod error") + } + }) + + t.Run("CloseError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + closeFileFn = func(*os.File) error { return errors.New("close fail") } + if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil { + t.Fatalf("expected close error") + } + }) + + t.Run("RenameError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + renameFn = func(string, string) error { return errors.New("rename fail") } + if err := writeHostAgentFile(filepath.Join(t.TempDir(), "dest"), strings.NewReader("data"), 0o644); err == nil { + t.Fatalf("expected rename error") + } + }) +} + +func TestWriteHostAgentFileSuccess(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + dest := filepath.Join(t.TempDir(), "pulse-host-agent-linux-amd64") + if err := writeHostAgentFile(dest, strings.NewReader("data"), 0o644); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, err := os.Stat(dest); err != nil { + t.Fatalf("expected file written: %v", err) + } +} + +func TestCopyHostAgentFileErrors(t *testing.T) { + t.Run("OpenError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + openFileFn = func(string) (*os.File, error) { return nil, errors.New("open fail") } + if err := copyHostAgentFile("missing", filepath.Join(t.TempDir(), "dest")); err == nil { + t.Fatalf("expected open error") + } + }) + + t.Run("MkdirAllError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + src := filepath.Join(t.TempDir(), "src") + if err := os.WriteFile(src, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") } + if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil { + t.Fatalf("expected mkdir error") + } + }) + + t.Run("OpenFileError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + src := filepath.Join(t.TempDir(), "src") + if err := os.WriteFile(src, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + openFileModeFn = func(string, int, os.FileMode) (*os.File, error) { + return nil, errors.New("create fail") + } + if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil { + t.Fatalf("expected open file error") + } + }) + + t.Run("CopyError", func(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + src := filepath.Join(t.TempDir(), "src") + if err := os.WriteFile(src, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") } + if err := copyHostAgentFile(src, filepath.Join(t.TempDir(), "dest")); err == nil { + t.Fatalf("expected copy error") + } + }) +} + +func TestCopyHostAgentFileSuccess(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + src := 
filepath.Join(t.TempDir(), "src") + if err := os.WriteFile(src, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + dest := filepath.Join(t.TempDir(), "dest") + if err := copyHostAgentFile(src, dest); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, err := os.Stat(dest); err != nil { + t.Fatalf("expected dest file: %v", err) + } +} + +func TestExtractHostAgentBinariesWriteError(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + tmpDir := t.TempDir() + payload := buildTarGz(t, []tarEntry{ + {name: "bin/pulse-host-agent-linux-amd64", body: []byte("binary"), mode: 0o644}, + }) + archive := filepath.Join(t.TempDir(), "bundle.tar.gz") + if err := os.WriteFile(archive, payload, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + mkdirAllFn = func(string, os.FileMode) error { return errors.New("mkdir fail") } + if err := extractHostAgentBinaries(archive, tmpDir); err == nil { + t.Fatalf("expected write error") + } +} + +func TestDownloadAndInstallHostAgentBinaries_Context(t *testing.T) { + restore := saveHostAgentHooks() + t.Cleanup(restore) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Context().Err() != nil { + t.Fatalf("unexpected context error") + } + _, _ = w.Write([]byte("data")) + })) + t.Cleanup(server.Close) + + httpClient = server.Client() + downloadURLForVersion = func(string) string { return server.URL + "/bundle.tar.gz" } + copyFn = func(io.Writer, io.Reader) (int64, error) { return 0, errors.New("copy fail") } + + if err := DownloadAndInstallHostAgentBinaries("v1.0.0", t.TempDir()); err == nil { + t.Fatalf("expected copy error") + } +} diff --git a/internal/agentexec/server.go b/internal/agentexec/server.go index 219a763aa..ce5299f23 100644 --- a/internal/agentexec/server.go +++ b/internal/agentexec/server.go @@ -20,19 +20,26 @@ var upgrader = websocket.Upgrader{ }, } +var ( + jsonMarshal = json.Marshal + pingInterval = 5 * time.Second + pingWriteWait = 5 * time.Second + readFileTimeout = 30 * time.Second +) + // Server manages WebSocket connections from agents type Server struct { mu sync.RWMutex - agents map[string]*agentConn // agentID -> connection + agents map[string]*agentConn // agentID -> connection pendingReqs map[string]chan CommandResultPayload // requestID -> response channel validateToken func(token string) bool } type agentConn struct { - conn *websocket.Conn - agent ConnectedAgent - writeMu sync.Mutex - done chan struct{} + conn *websocket.Conn + agent ConnectedAgent + writeMu sync.Mutex + done chan struct{} } // NewServer creates a new agent execution server @@ -94,7 +101,7 @@ func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { } // Parse registration payload - payloadBytes, err := json.Marshal(msg.Payload) + payloadBytes, err := jsonMarshal(msg.Payload) if err != nil { log.Error().Err(err).Msg("Failed to marshal registration payload") conn.Close() @@ -258,7 +265,7 @@ func (s *Server) readLoop(ac *agentConn) { } func (s *Server) pingLoop(ac *agentConn, done chan struct{}) { - ticker := time.NewTicker(5 * time.Second) + ticker := time.NewTicker(pingInterval) defer ticker.Stop() // Track consecutive ping failures to detect dead connections faster @@ -273,7 +280,7 @@ func (s *Server) pingLoop(ac *agentConn, done chan struct{}) { return case <-ticker.C: ac.writeMu.Lock() - err := ac.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)) + err := ac.conn.WriteControl(websocket.PingMessage, []byte{}, 
time.Now().Add(pingWriteWait)) ac.writeMu.Unlock() if err != nil { consecutiveFailures++ @@ -283,14 +290,14 @@ func (s *Server) pingLoop(ac *agentConn, done chan struct{}) { Str("hostname", ac.agent.Hostname). Int("consecutive_failures", consecutiveFailures). Msg("Failed to send ping to agent") - + if consecutiveFailures >= maxConsecutiveFailures { log.Error(). Str("agent_id", ac.agent.AgentID). Str("hostname", ac.agent.Hostname). Int("failures", consecutiveFailures). Msg("Agent connection appears dead after multiple ping failures, closing connection") - + // Close the connection - this will cause readLoop to exit and clean up ac.conn.Close() return @@ -404,7 +411,7 @@ func (s *Server) ReadFile(ctx context.Context, agentID string, req ReadFilePaylo } // Wait for result - timeout := 30 * time.Second + timeout := readFileTimeout select { case result := <-respCh: return &result, nil diff --git a/internal/agentexec/server_coverage_test.go b/internal/agentexec/server_coverage_test.go new file mode 100644 index 000000000..448f131e7 --- /dev/null +++ b/internal/agentexec/server_coverage_test.go @@ -0,0 +1,537 @@ +package agentexec + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +type noHijackResponseWriter struct { + header http.Header +} + +func (w *noHijackResponseWriter) Header() http.Header { + return w.header +} + +func (w *noHijackResponseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func (w *noHijackResponseWriter) WriteHeader(int) {} + +func newConnPair(t *testing.T) (*websocket.Conn, *websocket.Conn, func()) { + t.Helper() + + serverConnCh := make(chan *websocket.Conn, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade: %v", err) + return + } + serverConnCh <- conn + })) + + clientConn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + ts.Close() + t.Fatalf("Dial: %v", err) + } + + var serverConn *websocket.Conn + select { + case serverConn = <-serverConnCh: + case <-time.After(2 * time.Second): + clientConn.Close() + ts.Close() + t.Fatal("timed out waiting for server connection") + } + + cleanup := func() { + clientConn.Close() + serverConn.Close() + ts.Close() + } + + return serverConn, clientConn, cleanup +} + +func TestHandleWebSocket_UpgradeFailureAndDeadlineErrors(t *testing.T) { + s := NewServer(nil) + req := httptest.NewRequest(http.MethodGet, "http://example/ws", nil) + s.HandleWebSocket(&noHijackResponseWriter{header: make(http.Header)}, req) +} + +func TestHandleWebSocket_RegistrationReadError(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + conn.Close() +} + +func TestHandleWebSocket_RegistrationMessageJSONError(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + if err := conn.WriteMessage(websocket.TextMessage, []byte("{")); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + if _, _, err := conn.ReadMessage(); err == nil { + t.Fatalf("expected server to close on invalid JSON") + } +} + 
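// Aside, a minimal sketch of the seam pattern these tests rely on (the diff
// itself introduces it via `var jsonMarshal = json.Marshal` in server.go):
// production code calls a package-level function variable instead of the
// stdlib function directly, and a test swaps it for a stub, restoring the
// original with t.Cleanup. The names below (package seam, encode,
// TestEncodeMarshalError) are illustrative only and not part of this diff.
//
//	package seam
//
//	import (
//		"encoding/json"
//		"errors"
//		"testing"
//	)
//
//	// Production code calls jsonMarshal rather than json.Marshal directly,
//	// giving tests a single point at which to inject failures.
//	var jsonMarshal = json.Marshal
//
//	func encode(v any) ([]byte, error) { return jsonMarshal(v) }
//
//	func TestEncodeMarshalError(t *testing.T) {
//		orig := jsonMarshal
//		t.Cleanup(func() { jsonMarshal = orig }) // always restore the real function
//		jsonMarshal = func(any) ([]byte, error) { return nil, errors.New("boom") }
//
//		if _, err := encode(struct{}{}); err == nil {
//			t.Fatal("expected injected marshal error")
//		}
//	}
//
// The save/restore discipline matters because these variables are package
// globals: a test that forgets to restore them poisons every later test in
// the package, which is why saveHostAgentHooks() snapshots all hooks at once.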
+func TestHandleWebSocket_RegistrationPayloadMarshalError(t *testing.T) { + orig := jsonMarshal + t.Cleanup(func() { jsonMarshal = orig }) + jsonMarshal = func(any) ([]byte, error) { + return nil, errors.New("boom") + } + + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Token: "any", + }, + }) + + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + if _, _, err := conn.ReadMessage(); err == nil { + t.Fatalf("expected server to close on marshal error") + } +} + +func TestHandleWebSocket_RegistrationPayloadUnmarshalError(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + if err := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"agent_register","payload":"oops"}`)); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + if _, _, err := conn.ReadMessage(); err == nil { + t.Fatalf("expected server to close on invalid payload") + } +} + +func TestHandleWebSocket_PongHandler(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Token: "any", + }, + }) + _ = wsReadRegisteredPayload(t, conn) + + if err := conn.WriteControl(websocket.PongMessage, []byte("pong"), time.Now().Add(time.Second)); err != nil { + t.Fatalf("WriteControl pong: %v", err) + } + + conn.Close() + waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") }) +} + +func TestReadLoopDone(t *testing.T) { + s := NewServer(nil) + serverConn, clientConn, cleanup := newConnPair(t) + defer cleanup() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: make(chan struct{}), + } + close(ac.done) + + s.mu.Lock() + s.agents["a1"] = ac + s.mu.Unlock() + + s.readLoop(ac) + + if s.IsAgentConnected("a1") { + t.Fatalf("expected agent to be removed") + } + clientConn.Close() +} + +func TestReadLoopUnexpectedCloseError(t *testing.T) { + s := NewServer(nil) + serverConn, clientConn, cleanup := newConnPair(t) + defer cleanup() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: make(chan struct{}), + } + + s.mu.Lock() + s.agents["a1"] = ac + s.mu.Unlock() + + done := make(chan struct{}) + go func() { + s.readLoop(ac) + close(done) + }() + + _ = clientConn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseProtocolError, "bye"), + time.Now().Add(time.Second), + ) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("readLoop did not exit") + } +} + +func TestReadLoopCommandResultBranches(t *testing.T) { + s := NewServer(nil) + serverConn, clientConn, cleanup := newConnPair(t) + defer cleanup() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: 
make(chan struct{}), + } + + s.mu.Lock() + s.agents["a1"] = ac + s.pendingReqs["req-full"] = make(chan CommandResultPayload) + s.mu.Unlock() + + done := make(chan struct{}) + go func() { + s.readLoop(ac) + close(done) + }() + + _ = clientConn.WriteMessage(websocket.TextMessage, []byte("{")) + _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":123}}`)) + _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":"req-full","success":true}}`)) + _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":"req-missing","success":true}}`)) + + clientConn.Close() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("readLoop did not exit") + } + + s.mu.Lock() + delete(s.pendingReqs, "req-full") + s.mu.Unlock() +} + +func TestPingLoopSuccessAndStop(t *testing.T) { + origInterval := pingInterval + t.Cleanup(func() { pingInterval = origInterval }) + pingInterval = 5 * time.Millisecond + + s := NewServer(nil) + serverConn, _, cleanup := newConnPair(t) + defer cleanup() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: make(chan struct{}), + } + + stop := make(chan struct{}) + exited := make(chan struct{}) + go func() { + s.pingLoop(ac, stop) + close(exited) + }() + + time.Sleep(2 * pingInterval) + close(stop) + + select { + case <-exited: + case <-time.After(2 * time.Second): + t.Fatalf("pingLoop did not exit") + } +} + +func TestPingLoopFailuresClose(t *testing.T) { + origInterval := pingInterval + t.Cleanup(func() { pingInterval = origInterval }) + pingInterval = 5 * time.Millisecond + + s := NewServer(nil) + serverConn, _, cleanup := newConnPair(t) + defer cleanup() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: make(chan struct{}), + } + + serverConn.Close() + + stop := make(chan struct{}) + exited := make(chan struct{}) + go func() { + s.pingLoop(ac, stop) + close(exited) + }() + + select { + case <-exited: + case <-time.After(2 * time.Second): + t.Fatalf("pingLoop did not exit after failures") + } +} + +func TestSendMessageMarshalError(t *testing.T) { + s := NewServer(nil) + if err := s.sendMessage(nil, Message{Payload: make(chan int)}); err == nil { + t.Fatalf("expected marshal error") + } +} + +func TestExecuteCommandSendError(t *testing.T) { + s := NewServer(nil) + serverConn, _, cleanup := newConnPair(t) + defer cleanup() + + serverConn.Close() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: make(chan struct{}), + } + s.mu.Lock() + s.agents["a1"] = ac + s.mu.Unlock() + + _, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r1", Timeout: 1}) + if err == nil { + t.Fatalf("expected send error") + } +} + +func TestExecuteCommandTimeoutAndCancel(t *testing.T) { + s := NewServer(nil) + serverConn, _, cleanup := newConnPair(t) + defer cleanup() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: make(chan struct{}), + } + s.mu.Lock() + s.agents["a1"] = ac + s.mu.Unlock() + + _, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-timeout", Timeout: 1}) + if err == nil || !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error, got %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = s.ExecuteCommand(ctx, "a1", 
ExecuteCommandPayload{RequestID: "r-cancel", Timeout: 1}) + if err == nil { + t.Fatalf("expected cancel error") + } +} + +func TestExecuteCommandDefaultTimeout(t *testing.T) { + s := NewServer(nil) + serverConn, _, cleanup := newConnPair(t) + defer cleanup() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: make(chan struct{}), + } + s.mu.Lock() + s.agents["a1"] = ac + s.mu.Unlock() + + go func() { + for { + s.mu.RLock() + ch := s.pendingReqs["r-default"] + s.mu.RUnlock() + if ch != nil { + ch <- CommandResultPayload{RequestID: "r-default", Success: true} + return + } + time.Sleep(2 * time.Millisecond) + } + }() + + result, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-default"}) + if err != nil || result == nil || !result.Success { + t.Fatalf("expected success, got result=%v err=%v", result, err) + } +} + +func TestReadFileRoundTrip(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Token: "any", + }, + }) + _ = wsReadRegisteredPayload(t, conn) + + agentDone := make(chan error, 1) + go func() { + for { + msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second) + if err != nil { + agentDone <- err + return + } + if msg.Type != MsgTypeReadFile || msg.Payload == nil { + continue + } + var payload ReadFilePayload + if err := json.Unmarshal(*msg.Payload, &payload); err != nil { + agentDone <- err + return + } + agentDone <- conn.WriteJSON(Message{ + Type: MsgTypeCommandResult, + Timestamp: time.Now(), + Payload: CommandResultPayload{ + RequestID: payload.RequestID, + Success: true, + Stdout: "data", + ExitCode: 0, + }, + }) + return + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, err := s.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-1", Path: "/etc/hosts"}) + if err != nil || result == nil || result.Stdout != "data" { + t.Fatalf("unexpected read file result=%v err=%v", result, err) + } + + if err := <-agentDone; err != nil { + t.Fatalf("agent error: %v", err) + } +} + +func TestReadFileTimeoutCancelAndSendError(t *testing.T) { + origTimeout := readFileTimeout + t.Cleanup(func() { readFileTimeout = origTimeout }) + readFileTimeout = 10 * time.Millisecond + + s := NewServer(nil) + serverConn, _, cleanup := newConnPair(t) + defer cleanup() + + ac := &agentConn{ + conn: serverConn, + agent: ConnectedAgent{AgentID: "a1"}, + done: make(chan struct{}), + } + s.mu.Lock() + s.agents["a1"] = ac + s.mu.Unlock() + + if _, err := s.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-timeout"}); err == nil { + t.Fatalf("expected timeout error") + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := s.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-cancel"}); err == nil { + t.Fatalf("expected cancel error") + } + + serverConn.Close() + if _, err := s.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-send"}); err == nil { + t.Fatalf("expected send error") + } +} diff --git a/internal/agentexec/server_test.go b/internal/agentexec/server_test.go index 17ceb5e5e..4c78b6c86 100644 --- a/internal/agentexec/server_test.go +++ 
b/internal/agentexec/server_test.go @@ -42,6 +42,9 @@ func TestConnectedAgentLookups(t *testing.T) { if !ok || agentID != "a2" { t.Fatalf("expected GetAgentForHost(host2) = (a2, true), got (%q, %v)", agentID, ok) } + if _, ok := s.GetAgentForHost("missing"); ok { + t.Fatalf("expected missing host to return false") + } agents := s.GetConnectedAgents() if len(agents) != 2 { diff --git a/internal/agentupdate/coverage_test.go b/internal/agentupdate/coverage_test.go new file mode 100644 index 000000000..937053f9f --- /dev/null +++ b/internal/agentupdate/coverage_test.go @@ -0,0 +1,1033 @@ +package agentupdate + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +func testBinary() []byte { + return []byte{0x7f, 'E', 'L', 'F', 0x01, 0x02, 0x03, 0x04} +} + +func checksum(data []byte) string { + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +func writeTempExec(t *testing.T) (string, string) { + t.Helper() + + dir := t.TempDir() + execPath := filepath.Join(dir, "agent") + if err := os.WriteFile(execPath, []byte("old-binary"), 0755); err != nil { + t.Fatalf("write exec: %v", err) + } + return dir, execPath +} + +func newUpdaterForTest(serverURL string) *Updater { + cfg := Config{ + PulseURL: serverURL, + AgentName: "pulse-agent", + CurrentVersion: "1.0.0", + CheckInterval: 10 * time.Millisecond, + } + return New(cfg) +} + +func TestRestartProcess(t *testing.T) { + orig := execFn + t.Cleanup(func() { execFn = orig }) + + execFn = func(string, []string, []string) error { return nil } + if err := restartProcess("/bin/true"); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + execFn = func(string, []string, []string) error { return errors.New("boom") } + if err := restartProcess("/bin/false"); err == nil { + t.Fatalf("expected error") + } +} + +func TestDetermineArchOverrides(t *testing.T) { + origOS, origArch, origUname := runtimeGOOS, runtimeGOARCH, unameCommand + t.Cleanup(func() { + runtimeGOOS = origOS + runtimeGOARCH = origArch + unameCommand = origUname + }) + + runtimeGOOS = "linux" + runtimeGOARCH = "arm" + if got := determineArch(); got != "linux-armv7" { + t.Fatalf("expected linux-armv7, got %q", got) + } + + runtimeGOARCH = "386" + if got := determineArch(); got != "linux-386" { + t.Fatalf("expected linux-386, got %q", got) + } + + runtimeGOOS = "solaris" + runtimeGOARCH = "amd64" + unameCommand = func() ([]byte, error) { return []byte("aarch64"), nil } + if got := determineArch(); got != "linux-arm64" { + t.Fatalf("expected linux-arm64, got %q", got) + } + + unameCommand = func() ([]byte, error) { return []byte("x86_64"), nil } + if got := determineArch(); got != "linux-amd64" { + t.Fatalf("expected linux-amd64, got %q", got) + } + + unameCommand = func() ([]byte, error) { return []byte("armv7l"), nil } + if got := determineArch(); got != "linux-armv7" { + t.Fatalf("expected linux-armv7, got %q", got) + } + + unameCommand = func() ([]byte, error) { return []byte("mips"), nil } + if got := determineArch(); got != "" { + t.Fatalf("expected empty arch for unknown uname, got %q", got) + } + + unameCommand = func() ([]byte, error) { return nil, errors.New("fail") } + if got := determineArch(); got != "" { + t.Fatalf("expected empty arch 
on uname error, got %q", got) + } +} + +func TestUnameCommandDefault(t *testing.T) { + if _, err := exec.LookPath("uname"); err != nil { + t.Skip("uname not available") + } + + orig := unameCommand + t.Cleanup(func() { unameCommand = orig }) + unameCommand = orig + + if _, err := unameCommand(); err != nil { + t.Fatalf("expected uname to run, got %v", err) + } +} + +func TestVerifyBinaryMagicOverrides(t *testing.T) { + origOS := runtimeGOOS + t.Cleanup(func() { runtimeGOOS = origOS }) + + tmpDir := t.TempDir() + + runtimeGOOS = "darwin" + machoPath := filepath.Join(tmpDir, "macho") + machoData := []byte{0xcf, 0xfa, 0xed, 0xfe, 0x00} + if err := os.WriteFile(machoPath, machoData, 0644); err != nil { + t.Fatalf("write macho: %v", err) + } + if err := verifyBinaryMagic(machoPath); err != nil { + t.Fatalf("expected macho to validate, got %v", err) + } + + badPath := filepath.Join(tmpDir, "macho-bad") + if err := os.WriteFile(badPath, []byte{0x00, 0x00, 0x00, 0x00}, 0644); err != nil { + t.Fatalf("write bad macho: %v", err) + } + if err := verifyBinaryMagic(badPath); err == nil { + t.Fatalf("expected macho error") + } + + runtimeGOOS = "windows" + pePath := filepath.Join(tmpDir, "pe.exe") + if err := os.WriteFile(pePath, []byte{'M', 'Z', 0x00, 0x00}, 0644); err != nil { + t.Fatalf("write pe: %v", err) + } + if err := verifyBinaryMagic(pePath); err != nil { + t.Fatalf("expected PE to validate, got %v", err) + } + + badPEPath := filepath.Join(tmpDir, "bad-pe.exe") + if err := os.WriteFile(badPEPath, []byte{0x00, 0x00, 0x00, 0x00}, 0644); err != nil { + t.Fatalf("write bad pe: %v", err) + } + if err := verifyBinaryMagic(badPEPath); err == nil { + t.Fatalf("expected PE error") + } + + runtimeGOOS = "plan9" + planPath := filepath.Join(tmpDir, "plan9") + if err := os.WriteFile(planPath, []byte{0x00, 0x01, 0x02, 0x03}, 0644); err != nil { + t.Fatalf("write plan9: %v", err) + } + if err := verifyBinaryMagic(planPath); err != nil { + t.Fatalf("expected unknown OS to skip verification, got %v", err) + } +} + +func TestIsUnraidOverride(t *testing.T) { + orig := unraidVersionPath + t.Cleanup(func() { unraidVersionPath = orig }) + + tmpDir := t.TempDir() + unraidVersionPath = filepath.Join(tmpDir, "unraid-version") + if isUnraid() { + t.Fatalf("expected false when file missing") + } + if err := os.WriteFile(unraidVersionPath, []byte("6.12"), 0644); err != nil { + t.Fatalf("write unraid: %v", err) + } + if !isUnraid() { + t.Fatalf("expected true when file exists") + } +} + +func TestGetServerVersion(t *testing.T) { + t.Run("Success", func(t *testing.T) { + var sawToken bool + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-API-Token") == "token" && strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") { + sawToken = true + } + _, _ = w.Write([]byte(`{"version":"1.2.3"}`)) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.cfg.APIToken = "token" + u.client = server.Client() + + version, err := u.getServerVersion(context.Background()) + if err != nil { + t.Fatalf("getServerVersion error: %v", err) + } + if version != "1.2.3" { + t.Fatalf("expected version 1.2.3, got %q", version) + } + if !sawToken { + t.Fatalf("expected token headers to be set") + } + }) + + t.Run("InvalidURL", func(t *testing.T) { + u := newUpdaterForTest("http://[::1") + if _, err := u.getServerVersion(context.Background()); err == nil { + t.Fatalf("expected error for invalid URL") + } + }) + + t.Run("RequestError", func(t *testing.T) { + u := 
newUpdaterForTest("http://example") + u.client = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("boom") + }), + } + if _, err := u.getServerVersion(context.Background()); err == nil { + t.Fatalf("expected request error") + } + }) + + t.Run("StatusError", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = server.Client() + if _, err := u.getServerVersion(context.Background()); err == nil { + t.Fatalf("expected status error") + } + }) + + t.Run("DecodeError", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("{")) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = server.Client() + if _, err := u.getServerVersion(context.Background()); err == nil { + t.Fatalf("expected decode error") + } + }) +} + +func TestCheckAndUpdateBranches(t *testing.T) { + t.Run("Disabled", func(t *testing.T) { + u := newUpdaterForTest("http://example") + u.cfg.Disabled = true + u.performUpdateFn = func(context.Context) error { + t.Fatalf("should not update when disabled") + return nil + } + u.CheckAndUpdate(context.Background()) + }) + + t.Run("DevCurrent", func(t *testing.T) { + u := newUpdaterForTest("http://example") + u.cfg.CurrentVersion = "dev" + u.performUpdateFn = func(context.Context) error { + t.Fatalf("should not update in dev mode") + return nil + } + u.CheckAndUpdate(context.Background()) + }) + + t.Run("NoPulseURL", func(t *testing.T) { + u := newUpdaterForTest("") + u.performUpdateFn = func(context.Context) error { + t.Fatalf("should not update without URL") + return nil + } + u.CheckAndUpdate(context.Background()) + }) + + t.Run("ServerError", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = server.Client() + u.performUpdateFn = func(context.Context) error { + t.Fatalf("should not update on server error") + return nil + } + u.CheckAndUpdate(context.Background()) + }) + + t.Run("ServerDev", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"version":"dev"}`)) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = server.Client() + u.performUpdateFn = func(context.Context) error { + t.Fatalf("should not update when server dev") + return nil + } + u.CheckAndUpdate(context.Background()) + }) + + t.Run("UpToDate", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"version":"1.0.0"}`)) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = server.Client() + u.performUpdateFn = func(context.Context) error { + t.Fatalf("should not update when up to date") + return nil + } + u.CheckAndUpdate(context.Background()) + }) + + t.Run("ServerOlder", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"version":"0.9.0"}`)) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = 
server.Client() + u.performUpdateFn = func(context.Context) error { + t.Fatalf("should not downgrade") + return nil + } + u.CheckAndUpdate(context.Background()) + }) + + t.Run("ServerNewer", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"version":"1.1.0"}`)) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = server.Client() + var called int32 + u.performUpdateFn = func(context.Context) error { + atomic.AddInt32(&called, 1) + return nil + } + u.CheckAndUpdate(context.Background()) + if called != 1 { + t.Fatalf("expected update to be called") + } + }) + + t.Run("UpdateError", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"version":"1.1.0"}`)) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = server.Client() + var called int32 + u.performUpdateFn = func(context.Context) error { + atomic.AddInt32(&called, 1) + return errors.New("fail") + } + u.CheckAndUpdate(context.Background()) + if called != 1 { + t.Fatalf("expected update to be called") + } + }) +} + +func TestRunLoop(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"version":"1.1.0"}`)) + })) + t.Cleanup(server.Close) + + u := newUpdaterForTest(server.URL) + u.client = server.Client() + u.initialDelay = 0 + u.newTicker = func(d time.Duration) *time.Ticker { + return time.NewTicker(5 * time.Millisecond) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var called int32 + u.performUpdateFn = func(context.Context) error { + if atomic.AddInt32(&called, 1) >= 2 { + cancel() + } + return nil + } + + done := make(chan struct{}) + go func() { + u.RunLoop(ctx) + close(done) + }() + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fatalf("RunLoop did not exit") + } + + if atomic.LoadInt32(&called) < 1 { + t.Fatalf("expected RunLoop to invoke update") + } +} + +func TestRunLoopEarlyExit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + u := newUpdaterForTest("http://example") + u.cfg.Disabled = true + + done := make(chan struct{}) + go func() { + u.RunLoop(ctx) + close(done) + }() + + select { + case <-done: + case <-time.After(50 * time.Millisecond): + t.Fatalf("expected RunLoop to exit quickly") + } + + u2 := newUpdaterForTest("http://example") + u2.cfg.CurrentVersion = "dev" + + done2 := make(chan struct{}) + go func() { + u2.RunLoop(context.Background()) + close(done2) + }() + + select { + case <-done2: + case <-time.After(50 * time.Millisecond): + t.Fatalf("expected RunLoop to exit for dev") + } +} + +func TestRunLoopInitialCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + u := newUpdaterForTest("http://example") + u.initialDelay = 5 * time.Second + + done := make(chan struct{}) + go func() { + u.RunLoop(ctx) + close(done) + }() + + select { + case <-done: + case <-time.After(50 * time.Millisecond): + t.Fatalf("expected RunLoop to exit on cancel") + } +} + +func TestGetUpdatedFromVersion(t *testing.T) { + origExec := osExecutableFn + origEval := evalSymlinksFn + t.Cleanup(func() { + osExecutableFn = origExec + evalSymlinksFn = origEval + }) + + tmpDir := t.TempDir() + execPath := filepath.Join(tmpDir, "agent") + infoPath := filepath.Join(tmpDir, 
".pulse-update-info") + if err := os.WriteFile(infoPath, []byte("1.0.0"), 0644); err != nil { + t.Fatalf("write update info: %v", err) + } + + osExecutableFn = func() (string, error) { return execPath, nil } + evalSymlinksFn = func(string) (string, error) { return execPath, nil } + + version := GetUpdatedFromVersion() + if version != "1.0.0" { + t.Fatalf("expected version 1.0.0, got %q", version) + } + if _, err := os.Stat(infoPath); err == nil { + t.Fatalf("expected update info file to be removed") + } + + missingPath := filepath.Join(tmpDir, "agent-missing") + osExecutableFn = func() (string, error) { return missingPath, nil } + evalSymlinksFn = func(string) (string, error) { return "", errors.New("fail") } + if GetUpdatedFromVersion() != "" { + t.Fatalf("expected empty version on missing file") + } + + osExecutableFn = func() (string, error) { return "", errors.New("fail") } + if GetUpdatedFromVersion() != "" { + t.Fatalf("expected empty version on error") + } +} + +func TestPerformUpdateWrapper(t *testing.T) { + t.Run("ExecPathError", func(t *testing.T) { + origExec := osExecutableFn + t.Cleanup(func() { osExecutableFn = origExec }) + osExecutableFn = func() (string, error) { return "", errors.New("fail") } + + u := newUpdaterForTest("http://example") + if err := u.performUpdate(context.Background()); err == nil { + t.Fatalf("expected exec path error") + } + }) + + t.Run("Success", func(t *testing.T) { + data := testBinary() + check := checksum(data) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Checksum-Sha256", check) + _, _ = w.Write(data) + })) + t.Cleanup(server.Close) + + _, execPath := writeTempExec(t) + u := newUpdaterForTest(server.URL) + u.client = server.Client() + + origExec := osExecutableFn + origRestart := restartProcessFn + t.Cleanup(func() { + osExecutableFn = origExec + restartProcessFn = origRestart + }) + osExecutableFn = func() (string, error) { return execPath, nil } + restartProcessFn = func(string) error { return nil } + + if err := u.performUpdate(context.Background()); err != nil { + t.Fatalf("expected update success, got %v", err) + } + }) +} + +func TestPerformUpdateInvalidRequest(t *testing.T) { + u := newUpdaterForTest("http://[::1") + if err := u.performUpdateWithExecPath(context.Background(), ""); err == nil { + t.Fatalf("expected error for invalid URL") + } +} + +func TestPerformUpdateDownloadErrorAndHeaders(t *testing.T) { + _, execPath := writeTempExec(t) + u := newUpdaterForTest("http://example") + u.cfg.APIToken = "token" + + var sawToken bool + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Header.Get("X-API-Token") == "token" && strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") { + sawToken = true + } + return nil, errors.New("download fail") + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected download error") + } + if !sawToken { + t.Fatalf("expected auth headers to be set") + } +} + +func TestPerformUpdateStatusFallbackAndSuccess(t *testing.T) { + data := testBinary() + check := checksum(data) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.RawQuery, "arch=") { + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Set("X-Checksum-Sha256", check) + _, _ = w.Write(data) + })) + t.Cleanup(server.Close) + + _, execPath := writeTempExec(t) + + u := 
newUpdaterForTest(server.URL) + u.client = server.Client() + + origRestart := restartProcessFn + t.Cleanup(func() { restartProcessFn = origRestart }) + restartProcessFn = func(string) error { return nil } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil { + t.Fatalf("expected update success, got %v", err) + } + + updated, err := os.ReadFile(execPath) + if err != nil { + t.Fatalf("read updated exec: %v", err) + } + if !bytes.HasPrefix(updated, []byte{0x7f, 'E', 'L', 'F'}) { + t.Fatalf("expected updated binary content") + } +} + +func TestPerformUpdateSymlinkFallback(t *testing.T) { + data := testBinary() + check := checksum(data) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Checksum-Sha256", check) + _, _ = w.Write(data) + })) + t.Cleanup(server.Close) + + _, execPath := writeTempExec(t) + u := newUpdaterForTest(server.URL) + u.client = server.Client() + + origEval := evalSymlinksFn + origRestart := restartProcessFn + t.Cleanup(func() { + evalSymlinksFn = origEval + restartProcessFn = origRestart + }) + evalSymlinksFn = func(string) (string, error) { return "", errors.New("fail") } + restartProcessFn = func(string) error { return nil } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil { + t.Fatalf("expected update success with symlink fallback, got %v", err) + } +} + +func TestPerformUpdateErrors(t *testing.T) { + t.Run("CreateTempError", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Checksum-Sha256", checksum(testBinary())) + _, _ = w.Write(testBinary()) + })) + t.Cleanup(server.Close) + + _, execPath := writeTempExec(t) + u := newUpdaterForTest(server.URL) + u.client = server.Client() + + origCreate := createTempFn + t.Cleanup(func() { createTempFn = origCreate }) + createTempFn = func(string, string) (*os.File, error) { + return nil, errors.New("temp fail") + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected create temp error") + } + }) + + t.Run("CopyError", func(t *testing.T) { + _, execPath := writeTempExec(t) + u := newUpdaterForTest("http://example") + + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + body := io.NopCloser(&errorReader{}) + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: body, + Header: http.Header{"X-Checksum-Sha256": []string{checksum(testBinary())}}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected copy error") + } + }) + + t.Run("TooLarge", func(t *testing.T) { + _, execPath := writeTempExec(t) + u := newUpdaterForTest("http://example") + + origMax := maxBinarySizeBytes + t.Cleanup(func() { maxBinarySizeBytes = origMax }) + maxBinarySizeBytes = 4 + + data := []byte{0x7f, 'E', 'L', 'F', 0x00, 0x01} + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(bytes.NewReader(data)), + Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected size error") + } + }) + + t.Run("CloseError", func(t *testing.T) { + data := testBinary() + u := 
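The 404 on the arch-qualified query above exercises performUpdate's candidate-URL fallback (see the update.go hunks later in this diff). Reduced to a sketch, with tryDownload as a hypothetical helper standing in for the request/response handling, the shape of that loop is:

// downloadFirst tries each candidate URL in order and returns the first
// usable response. lastErr starts non-nil so a caller always receives a
// concrete error when every candidate fails. tryDownload is a hypothetical
// helper, not part of this patch.
func downloadFirst(ctx context.Context, candidates []string) (*http.Response, error) {
	lastErr := errors.New("failed to download binary")
	for _, url := range candidates {
		resp, err := tryDownload(ctx, url)
		if err != nil {
			lastErr = err
			continue
		}
		return resp, nil
	}
	return nil, lastErr
}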
newUpdaterForTest("http://example") + _, execPath := writeTempExec(t) + + origClose := closeFileFn + origRestart := restartProcessFn + t.Cleanup(func() { + closeFileFn = origClose + restartProcessFn = origRestart + }) + closeFileFn = func(*os.File) error { return errors.New("close fail") } + restartProcessFn = func(string) error { return nil } + + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(bytes.NewReader(data)), + Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected close error") + } + }) + + t.Run("InvalidBinary", func(t *testing.T) { + u := newUpdaterForTest("http://example") + _, execPath := writeTempExec(t) + + data := []byte{0x00, 0x00, 0x00, 0x00} + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(bytes.NewReader(data)), + Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected invalid binary error") + } + }) + + t.Run("MissingChecksum", func(t *testing.T) { + u := newUpdaterForTest("http://example") + _, execPath := writeTempExec(t) + + data := testBinary() + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(bytes.NewReader(data)), + Header: http.Header{}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected checksum missing error") + } + }) + + t.Run("ChecksumMismatch", func(t *testing.T) { + u := newUpdaterForTest("http://example") + _, execPath := writeTempExec(t) + + data := testBinary() + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(bytes.NewReader(data)), + Header: http.Header{"X-Checksum-Sha256": []string{"deadbeef"}}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected checksum mismatch error") + } + }) + + t.Run("ChmodError", func(t *testing.T) { + u := newUpdaterForTest("http://example") + _, execPath := writeTempExec(t) + data := testBinary() + + origChmod := chmodFn + origRestart := restartProcessFn + t.Cleanup(func() { + chmodFn = origChmod + restartProcessFn = origRestart + }) + chmodFn = func(string, os.FileMode) error { return errors.New("chmod fail") } + restartProcessFn = func(string) error { return nil } + + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(bytes.NewReader(data)), + Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected chmod error") + } + }) + + t.Run("BackupRenameError", func(t *testing.T) { + u := newUpdaterForTest("http://example") + _, 
execPath := writeTempExec(t) + data := testBinary() + + origRename := renameFn + origRestart := restartProcessFn + t.Cleanup(func() { + renameFn = origRename + restartProcessFn = origRestart + }) + renameFn = func(oldPath, newPath string) error { + if strings.HasSuffix(newPath, ".backup") { + return errors.New("backup fail") + } + return origRename(oldPath, newPath) + } + restartProcessFn = func(string) error { return nil } + + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(bytes.NewReader(data)), + Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected backup rename error") + } + }) + + t.Run("ReplaceRenameError", func(t *testing.T) { + u := newUpdaterForTest("http://example") + _, execPath := writeTempExec(t) + data := testBinary() + + origRename := renameFn + origRestart := restartProcessFn + t.Cleanup(func() { + renameFn = origRename + restartProcessFn = origRestart + }) + var calls int32 + renameFn = func(oldPath, newPath string) error { + if atomic.AddInt32(&calls, 1) == 2 { + return errors.New("replace fail") + } + return origRename(oldPath, newPath) + } + restartProcessFn = func(string) error { return nil } + + u.client = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(bytes.NewReader(data)), + Header: http.Header{"X-Checksum-Sha256": []string{checksum(data)}}, + }, nil + }), + } + + if err := u.performUpdateWithExecPath(context.Background(), execPath); err == nil { + t.Fatalf("expected replace rename error") + } + }) +} + +func TestPerformUpdateUnraidPaths(t *testing.T) { + data := testBinary() + check := checksum(data) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Checksum-Sha256", check) + _, _ = w.Write(data) + })) + t.Cleanup(server.Close) + + _, execPath := writeTempExec(t) + u := newUpdaterForTest(server.URL) + u.client = server.Client() + + origUnraid := unraidVersionPath + origPersist := unraidPersistentPathFn + origRead := readFileFn + origWrite := writeFileFn + origRename := renameFn + origRestart := restartProcessFn + t.Cleanup(func() { + unraidVersionPath = origUnraid + unraidPersistentPathFn = origPersist + readFileFn = origRead + writeFileFn = origWrite + renameFn = origRename + restartProcessFn = origRestart + }) + + tmpDir := t.TempDir() + unraidVersionPath = filepath.Join(tmpDir, "unraid-version") + if err := os.WriteFile(unraidVersionPath, []byte("6.12"), 0644); err != nil { + t.Fatalf("write unraid version: %v", err) + } + + persistDir := filepath.Join(tmpDir, "persist") + if err := os.MkdirAll(persistDir, 0755); err != nil { + t.Fatalf("mkdir persist: %v", err) + } + persistPath := filepath.Join(persistDir, "pulse-agent") + if err := os.WriteFile(persistPath, []byte("old"), 0644); err != nil { + t.Fatalf("write persist: %v", err) + } + unraidPersistentPathFn = func(string) string { return persistPath } + restartProcessFn = func(string) error { return nil } + + t.Run("ReadError", func(t *testing.T) { + readFileFn = func(string) ([]byte, error) { return nil, errors.New("read fail") } + if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil { + 
t.Fatalf("expected update success, got %v", err) + } + }) + + t.Run("WriteError", func(t *testing.T) { + readFileFn = func(string) ([]byte, error) { return []byte("new"), nil } + writeFileFn = func(string, []byte, os.FileMode) error { return errors.New("write fail") } + if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil { + t.Fatalf("expected update success, got %v", err) + } + }) + + t.Run("RenameError", func(t *testing.T) { + readFileFn = func(string) ([]byte, error) { return []byte("new"), nil } + writeFileFn = os.WriteFile + renameFn = func(oldPath, newPath string) error { + if newPath == persistPath { + return errors.New("rename fail") + } + return os.Rename(oldPath, newPath) + } + if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil { + t.Fatalf("expected update success, got %v", err) + } + }) + + t.Run("Success", func(t *testing.T) { + readFileFn = func(string) ([]byte, error) { return []byte("new"), nil } + writeFileFn = os.WriteFile + renameFn = os.Rename + if err := u.performUpdateWithExecPath(context.Background(), execPath); err != nil { + t.Fatalf("expected update success, got %v", err) + } + }) +} + +type errorReader struct { + sent bool +} + +func (e *errorReader) Read(p []byte) (int, error) { + if e.sent { + return 0, errors.New("read fail") + } + e.sent = true + copy(p, testBinary()) + return len(testBinary()), errors.New("read fail") +} diff --git a/internal/agentupdate/restart_unix.go b/internal/agentupdate/restart_unix.go index 26bd1aa29..aa0c23725 100644 --- a/internal/agentupdate/restart_unix.go +++ b/internal/agentupdate/restart_unix.go @@ -8,13 +8,15 @@ import ( "syscall" ) +var execFn = syscall.Exec + // restartProcess replaces the current process with a new instance. // On Unix-like systems, this uses syscall.Exec for an in-place restart. func restartProcess(execPath string) error { args := os.Args env := os.Environ() - if err := syscall.Exec(execPath, args, env); err != nil { + if err := execFn(execPath, args, env); err != nil { return fmt.Errorf("failed to restart: %w", err) } diff --git a/internal/agentupdate/update.go b/internal/agentupdate/update.go index b1efb73ad..a23921a4d 100644 --- a/internal/agentupdate/update.go +++ b/internal/agentupdate/update.go @@ -32,6 +32,24 @@ const ( downloadTimeout = 5 * time.Minute ) +var ( + maxBinarySizeBytes int64 = maxBinarySize + runtimeGOOS = runtime.GOOS + runtimeGOARCH = runtime.GOARCH + unameCommand = func() ([]byte, error) { return exec.Command("uname", "-m").Output() } + unraidVersionPath = "/etc/unraid-version" + unraidPersistentPathFn = unraidPersistentPath + restartProcessFn = restartProcess + osExecutableFn = os.Executable + evalSymlinksFn = filepath.EvalSymlinks + createTempFn = os.CreateTemp + chmodFn = os.Chmod + renameFn = os.Rename + closeFileFn = func(f *os.File) error { return f.Close() } + readFileFn = os.ReadFile + writeFileFn = os.WriteFile +) + // Config holds the configuration for the updater. type Config struct { // PulseURL is the base URL of the Pulse server (e.g., "https://pulse.example.com:7655") @@ -66,6 +84,8 @@ type Updater struct { logger zerolog.Logger performUpdateFn func(context.Context) error + initialDelay time.Duration + newTicker func(time.Duration) *time.Ticker } // New creates a new Updater with the given configuration. 
diff --git a/internal/agentupdate/restart_unix.go b/internal/agentupdate/restart_unix.go
index 26bd1aa29..aa0c23725 100644
--- a/internal/agentupdate/restart_unix.go
+++ b/internal/agentupdate/restart_unix.go
@@ -8,13 +8,15 @@ import (
 	"syscall"
 )
 
+var execFn = syscall.Exec
+
 // restartProcess replaces the current process with a new instance.
 // On Unix-like systems, this uses syscall.Exec for an in-place restart.
 func restartProcess(execPath string) error {
 	args := os.Args
 	env := os.Environ()
 
-	if err := syscall.Exec(execPath, args, env); err != nil {
+	if err := execFn(execPath, args, env); err != nil {
 		return fmt.Errorf("failed to restart: %w", err)
 	}
 
diff --git a/internal/agentupdate/update.go b/internal/agentupdate/update.go
index b1efb73ad..a23921a4d 100644
--- a/internal/agentupdate/update.go
+++ b/internal/agentupdate/update.go
@@ -32,6 +32,24 @@ const (
 	downloadTimeout = 5 * time.Minute
 )
 
+var (
+	maxBinarySizeBytes int64 = maxBinarySize
+	runtimeGOOS              = runtime.GOOS
+	runtimeGOARCH            = runtime.GOARCH
+	unameCommand             = func() ([]byte, error) { return exec.Command("uname", "-m").Output() }
+	unraidVersionPath        = "/etc/unraid-version"
+	unraidPersistentPathFn   = unraidPersistentPath
+	restartProcessFn         = restartProcess
+	osExecutableFn           = os.Executable
+	evalSymlinksFn           = filepath.EvalSymlinks
+	createTempFn             = os.CreateTemp
+	chmodFn                  = os.Chmod
+	renameFn                 = os.Rename
+	closeFileFn              = func(f *os.File) error { return f.Close() }
+	readFileFn               = os.ReadFile
+	writeFileFn              = os.WriteFile
+)
+
 // Config holds the configuration for the updater.
 type Config struct {
 	// PulseURL is the base URL of the Pulse server (e.g., "https://pulse.example.com:7655")
@@ -66,6 +84,8 @@ type Updater struct {
 	logger zerolog.Logger
 
 	performUpdateFn func(context.Context) error
+	initialDelay    time.Duration
+	newTicker       func(time.Duration) *time.Ticker
 }
 
 // New creates a new Updater with the given configuration.
@@ -95,6 +115,8 @@ func New(cfg Config) *Updater {
 		logger: logger,
 	}
 	u.performUpdateFn = u.performUpdate
+	u.initialDelay = 5 * time.Second
+	u.newTicker = time.NewTicker
 	return u
 }
@@ -114,11 +136,11 @@ func (u *Updater) RunLoop(ctx context.Context) {
 	select {
 	case <-ctx.Done():
 		return
-	case <-time.After(5 * time.Second):
+	case <-time.After(u.initialDelay):
 		u.CheckAndUpdate(ctx)
 	}
 
-	ticker := time.NewTicker(u.cfg.CheckInterval)
+	ticker := u.newTicker(u.cfg.CheckInterval)
 	defer ticker.Stop()
 
 	for {
@@ -228,7 +250,7 @@ func (u *Updater) getServerVersion(ctx context.Context) (string, error) {
 
 // isUnraid checks if we're running on Unraid by looking for /etc/unraid-version
 func isUnraid() bool {
-	_, err := os.Stat("/etc/unraid-version")
+	_, err := os.Stat(unraidVersionPath)
 	return err == nil
 }
 
@@ -245,7 +267,7 @@ func verifyBinaryMagic(path string) error {
 		return fmt.Errorf("failed to read magic bytes: %w", err)
 	}
 
-	switch runtime.GOOS {
+	switch runtimeGOOS {
 	case "linux":
 		// ELF magic: 0x7f 'E' 'L' 'F'
 		if magic[0] == 0x7f && magic[1] == 'E' && magic[2] == 'L' && magic[3] == 'F' {
@@ -286,10 +308,14 @@ func unraidPersistentPath(agentName string) string {
 
 // performUpdate downloads and installs the new agent binary.
 func (u *Updater) performUpdate(ctx context.Context) error {
-	execPath, err := os.Executable()
+	execPath, err := osExecutableFn()
 	if err != nil {
 		return fmt.Errorf("failed to get executable path: %w", err)
 	}
+	return u.performUpdateWithExecPath(ctx, execPath)
+}
+
+func (u *Updater) performUpdateWithExecPath(ctx context.Context, execPath string) error {
 	// Build download URL
 	downloadBase := fmt.Sprintf("%s/download/%s", strings.TrimRight(u.cfg.PulseURL, "/"), u.cfg.AgentName)
 
@@ -303,7 +329,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
 	candidates = append(candidates, downloadBase)
 
 	var resp *http.Response
-	var lastErr error
+	lastErr := errors.New("failed to download binary")
 
 	for _, url := range candidates {
 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
@@ -335,9 +361,6 @@ func (u *Updater) performUpdate(ctx context.Context) error {
 	}
 
 	if resp == nil {
-		if lastErr == nil {
-			lastErr = errors.New("failed to download binary")
-		}
 		return lastErr
 	}
 	defer resp.Body.Close()
@@ -346,7 +369,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
 	checksumHeader := strings.TrimSpace(resp.Header.Get("X-Checksum-Sha256"))
 
 	// Resolve symlinks to get the real path for atomic rename
-	realExecPath, err := filepath.EvalSymlinks(execPath)
+	realExecPath, err := evalSymlinksFn(execPath)
 	if err != nil {
 		// Fall back to original path if symlink resolution fails
 		realExecPath = execPath
@@ -355,7 +378,7 @@ func (u *Updater) performUpdate(ctx context.Context) error {
 	// Create temporary file in the same directory as the target binary
 	// to ensure atomic rename works (os.Rename fails across filesystems)
 	targetDir := filepath.Dir(realExecPath)
-	tmpFile, err := os.CreateTemp(targetDir, u.cfg.AgentName+"-*.tmp")
+	tmpFile, err := createTempFn(targetDir, u.cfg.AgentName+"-*.tmp")
 	if err != nil {
 		return fmt.Errorf("failed to create temp file: %w", err)
 	}
@@ -364,17 +387,17 @@ func (u *Updater) performUpdate(ctx context.Context) error {
 
 	// Write downloaded binary with checksum calculation and size limit
 	hasher := sha256.New()
-	limitedReader := io.LimitReader(resp.Body, maxBinarySize+1) // +1 to detect overflow
+	limitedReader := io.LimitReader(resp.Body, maxBinarySizeBytes+1) // +1 to detect overflow
 	written, err := io.Copy(tmpFile, io.TeeReader(limitedReader, hasher))
 	if err != nil {
-		tmpFile.Close()
+		closeFileFn(tmpFile)
 		return fmt.Errorf("failed to write binary: %w", err)
 	}
 
-	if written > maxBinarySize {
-		tmpFile.Close()
-		return fmt.Errorf("downloaded binary exceeds maximum size (%d bytes)", maxBinarySize)
+	if written > maxBinarySizeBytes {
+		closeFileFn(tmpFile)
+		return fmt.Errorf("downloaded binary exceeds maximum size (%d bytes)", maxBinarySizeBytes)
 	}
 
-	if err := tmpFile.Close(); err != nil {
+	if err := closeFileFn(tmpFile); err != nil {
 		return fmt.Errorf("failed to close temp file: %w", err)
 	}
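The size guard above reads at most maxBinarySizeBytes+1 bytes: seeing the extra byte proves the body was oversized without ever buffering the whole payload. The same pattern in isolation, as a sketch with assumed names (imports: crypto/sha256, encoding/hex, fmt, io):

// copyLimited copies src into dst while hashing it, and fails if more than
// max bytes arrive. Reading max+1 detects overflow with O(1) extra work.
func copyLimited(dst io.Writer, src io.Reader, max int64) (string, error) {
	hasher := sha256.New()
	written, err := io.Copy(dst, io.TeeReader(io.LimitReader(src, max+1), hasher))
	if err != nil {
		return "", err
	}
	if written > max {
		return "", fmt.Errorf("payload exceeds maximum size (%d bytes)", max)
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}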
@@ -397,19 +420,19 @@ func (u *Updater) performUpdate(ctx context.Context) error {
 	u.logger.Debug().Str("checksum", downloadChecksum).Msg("Checksum verified")
 
 	// Make executable
-	if err := os.Chmod(tmpPath, 0755); err != nil {
+	if err := chmodFn(tmpPath, 0755); err != nil {
 		return fmt.Errorf("failed to chmod: %w", err)
 	}
 
 	// Atomic replacement with backup (use realExecPath for rename operations)
 	backupPath := realExecPath + ".backup"
-	if err := os.Rename(realExecPath, backupPath); err != nil {
+	if err := renameFn(realExecPath, backupPath); err != nil {
 		return fmt.Errorf("failed to backup current binary: %w", err)
 	}
 
-	if err := os.Rename(tmpPath, realExecPath); err != nil {
+	if err := renameFn(tmpPath, realExecPath); err != nil {
 		// Restore backup on failure
-		os.Rename(backupPath, realExecPath)
+		renameFn(backupPath, realExecPath)
 		return fmt.Errorf("failed to replace binary: %w", err)
 	}
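The rename choreography above is the classic backup-then-swap: move the live binary aside, move the new one in, and roll back if the second rename fails. As a standalone sketch (names assumed):

// swapWithBackup replaces live with candidate, keeping the old file at
// live+".backup" and rolling back if the swap itself fails, so the running
// binary is never lost.
func swapWithBackup(live, candidate string) error {
	backup := live + ".backup"
	if err := os.Rename(live, backup); err != nil {
		return fmt.Errorf("failed to backup current binary: %w", err)
	}
	if err := os.Rename(candidate, live); err != nil {
		// Best-effort rollback so the old binary stays in place on failure.
		_ = os.Rename(backup, live)
		return fmt.Errorf("failed to replace binary: %w", err)
	}
	return nil
}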
@@ -418,26 +441,26 @@ func (u *Updater) performUpdate(ctx context.Context) error {
 
 	// Write previous version to a file so the agent can report "updated from X" on next start
 	updateInfoPath := filepath.Join(targetDir, ".pulse-update-info")
-	_ = os.WriteFile(updateInfoPath, []byte(u.cfg.CurrentVersion), 0644)
+	_ = writeFileFn(updateInfoPath, []byte(u.cfg.CurrentVersion), 0644)
 
 	// On Unraid, also update the persistent copy on the flash drive
 	// This ensures the update survives reboots
 	if isUnraid() {
-		persistPath := unraidPersistentPath(u.cfg.AgentName)
+		persistPath := unraidPersistentPathFn(u.cfg.AgentName)
 		if _, err := os.Stat(persistPath); err == nil {
 			// Persistent path exists, update it
 			u.logger.Debug().Str("path", persistPath).Msg("Updating Unraid persistent binary")
 
 			// Read the newly installed binary
-			newBinary, err := os.ReadFile(execPath)
+			newBinary, err := readFileFn(execPath)
 			if err != nil {
 				u.logger.Warn().Err(err).Msg("Failed to read new binary for Unraid persistence")
 			} else {
 				// Write to persistent storage (atomic via temp file)
 				tmpPersist := persistPath + ".tmp"
-				if err := os.WriteFile(tmpPersist, newBinary, 0644); err != nil {
+				if err := writeFileFn(tmpPersist, newBinary, 0644); err != nil {
 					u.logger.Warn().Err(err).Msg("Failed to write Unraid persistent binary")
-				} else if err := os.Rename(tmpPersist, persistPath); err != nil {
+				} else if err := renameFn(tmpPersist, persistPath); err != nil {
 					u.logger.Warn().Err(err).Msg("Failed to rename Unraid persistent binary")
 					os.Remove(tmpPersist)
 				} else {
@@ -448,19 +471,19 @@ func (u *Updater) performUpdate(ctx context.Context) error {
 	}
 
 	// Restart the process using platform-specific implementation
-	return restartProcess(execPath)
+	return restartProcessFn(execPath)
 }
 
 // GetUpdatedFromVersion checks if the agent was recently updated and returns the previous version.
 // Returns empty string if no update info exists. Clears the info file after reading.
 func GetUpdatedFromVersion() string {
-	execPath, err := os.Executable()
+	execPath, err := osExecutableFn()
 	if err != nil {
 		return ""
 	}
 
 	// Resolve symlinks to get the real path
-	realExecPath, err := filepath.EvalSymlinks(execPath)
+	realExecPath, err := evalSymlinksFn(execPath)
 	if err != nil {
 		realExecPath = execPath
 	}
@@ -479,8 +502,8 @@ func GetUpdatedFromVersion() string {
 
 // determineArch returns the architecture string for download URLs (e.g., "linux-amd64", "darwin-arm64").
 func determineArch() string {
-	os := runtime.GOOS
-	arch := runtime.GOARCH
+	os := runtimeGOOS
+	arch := runtimeGOARCH
 
 	// Normalize architecture
 	switch arch {
@@ -497,7 +520,7 @@ func determineArch() string {
 	}
 
 	// Fall back to uname for edge cases on unknown OS
-	out, err := exec.Command("uname", "-m").Output()
+	out, err := unameCommand()
 	if err != nil {
 		return ""
 	}
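determineArch's uname fallback maps raw `uname -m` machine strings onto Go arch names; the exact table is elided by the hunk context above. A typical normalization looks like this (mapping values assumed, not taken from the patch):

// normalizeMachine converts common `uname -m` outputs to GOARCH-style names.
func normalizeMachine(machine string) string {
	switch strings.TrimSpace(machine) {
	case "x86_64", "amd64":
		return "amd64"
	case "aarch64", "arm64":
		return "arm64"
	default:
		return ""
	}
}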
diff --git a/internal/ai/intelligence.go b/internal/ai/intelligence.go
index ec20e4223..1f380a740 100644
--- a/internal/ai/intelligence.go
+++ b/internal/ai/intelligence.go
@@ -40,68 +40,68 @@ const (
 // HealthFactor represents a single component affecting health
 type HealthFactor struct {
 	Name        string  `json:"name"`
-	Impact      float64 `json:"impact"` // -1 to 1, negative is bad
+	Impact      float64 `json:"impact"`      // -1 to 1, negative is bad
 	Description string  `json:"description"`
 	Category    string  `json:"category"` // finding, prediction, baseline, incident
 }
 
 // HealthScore represents the overall health of a resource or system
 type HealthScore struct {
-	Score   float64     `json:"score"` // 0-100
-	Grade   HealthGrade `json:"grade"` // A, B, C, D, F
-	Trend   HealthTrend `json:"trend"` // improving, stable, declining
+	Score   float64        `json:"score"` // 0-100
+	Grade   HealthGrade    `json:"grade"` // A, B, C, D, F
+	Trend   HealthTrend    `json:"trend"` // improving, stable, declining
 	Factors []HealthFactor `json:"factors"` // What's affecting the score
-	Prediction string `json:"prediction"` // Human-readable outlook
+	Prediction string      `json:"prediction"` // Human-readable outlook
 }
 
 // ResourceIntelligence aggregates all AI knowledge about a single resource
 type ResourceIntelligence struct {
-	ResourceID      string                       `json:"resource_id"`
-	ResourceName    string                       `json:"resource_name,omitempty"`
-	ResourceType    string                       `json:"resource_type,omitempty"`
-	Health          HealthScore                  `json:"health"`
-	ActiveFindings  []*Finding                   `json:"active_findings,omitempty"`
-	Predictions     []patterns.FailurePrediction `json:"predictions,omitempty"`
-	Dependencies    []string                     `json:"dependencies,omitempty"` // Resources this depends on
-	Dependents      []string                     `json:"dependents,omitempty"`   // Resources that depend on this
-	Correlations    []*correlation.Correlation   `json:"correlations,omitempty"`
+	ResourceID      string                            `json:"resource_id"`
+	ResourceName    string                            `json:"resource_name,omitempty"`
+	ResourceType    string                            `json:"resource_type,omitempty"`
+	Health          HealthScore                       `json:"health"`
+	ActiveFindings  []*Finding                        `json:"active_findings,omitempty"`
+	Predictions     []patterns.FailurePrediction      `json:"predictions,omitempty"`
+	Dependencies    []string                          `json:"dependencies,omitempty"` // Resources this depends on
+	Dependents      []string                          `json:"dependents,omitempty"`   // Resources that depend on this
+	Correlations    []*correlation.Correlation        `json:"correlations,omitempty"`
 	Baselines       map[string]*baseline.FlatBaseline `json:"baselines,omitempty"`
-	Anomalies       []AnomalyReport           `json:"anomalies,omitempty"`
-	RecentIncidents []*memory.Incident        `json:"recent_incidents,omitempty"`
-	Knowledge       *knowledge.GuestKnowledge `json:"knowledge,omitempty"`
-	NoteCount       int                       `json:"note_count"`
+	Anomalies       []AnomalyReport                   `json:"anomalies,omitempty"`
+	RecentIncidents []*memory.Incident                `json:"recent_incidents,omitempty"`
+	Knowledge       *knowledge.GuestKnowledge         `json:"knowledge,omitempty"`
+	NoteCount       int                               `json:"note_count"`
 }
 
 // AnomalyReport describes a metric that's deviating from baseline
 type AnomalyReport struct {
-	Metric       string  `json:"metric"`
-	CurrentValue float64 `json:"current_value"`
-	BaselineMean float64 `json:"baseline_mean"`
-	ZScore       float64 `json:"z_score"`
+	Metric       string                   `json:"metric"`
+	CurrentValue float64                  `json:"current_value"`
+	BaselineMean float64                  `json:"baseline_mean"`
+	ZScore       float64                  `json:"z_score"`
 	Severity     baseline.AnomalySeverity `json:"severity"`
-	Description  string  `json:"description"`
+	Description  string                   `json:"description"`
 }
 
 // IntelligenceSummary provides a system-wide intelligence overview
 type IntelligenceSummary struct {
-	Timestamp     time.Time   `json:"timestamp"`
-	OverallHealth HealthScore `json:"overall_health"`
-	
+	Timestamp     time.Time   `json:"timestamp"`
+	OverallHealth HealthScore `json:"overall_health"`
+
 	// Findings summary
-	FindingsCount FindingsCounts `json:"findings_count"`
-	TopFindings   []*Finding     `json:"top_findings,omitempty"` // Most critical
-	
+	FindingsCount FindingsCounts `json:"findings_count"`
+	TopFindings   []*Finding     `json:"top_findings,omitempty"` // Most critical
+
 	// Predictions
-	PredictionsCount int                          `json:"predictions_count"`
-	UpcomingRisks    []patterns.FailurePrediction `json:"upcoming_risks,omitempty"`
-	
+	PredictionsCount int                          `json:"predictions_count"`
+	UpcomingRisks    []patterns.FailurePrediction `json:"upcoming_risks,omitempty"`
+
 	// Recent activity
-	RecentChangesCount int                        `json:"recent_changes_count"`
-	RecentRemediations []memory.RemediationRecord `json:"recent_remediations,omitempty"`
-	
+	RecentChangesCount int                        `json:"recent_changes_count"`
+	RecentRemediations []memory.RemediationRecord `json:"recent_remediations,omitempty"`
+
 	// Learning progress
 	Learning LearningStats `json:"learning"`
-	
+
 	// Resources needing attention
 	ResourcesAtRisk []ResourceRiskSummary `json:"resources_at_risk,omitempty"`
 }
@@ -137,7 +137,7 @@ type ResourceRiskSummary struct {
 // Intelligence orchestrates all AI subsystems into a unified system
 type Intelligence struct {
 	mu sync.RWMutex
-	
+
 	// Core subsystems
 	findings *FindingsStore
 	patterns *patterns.Detector
@@ -147,10 +147,13 @@ type Intelligence struct {
 	knowledge    *knowledge.Store
 	changes      *memory.ChangeDetector
 	remediations *memory.RemediationLog
-	
+
 	// State access
 	stateProvider StateProvider
-	
+
+	// Optional hook for anomaly detection (used by patrol integration/tests)
+	anomalyDetector func(resourceID string) []AnomalyReport
+
 	// Configuration
 	dataDir string
 }
@@ -180,7 +183,7 @@ func (i *Intelligence) SetSubsystems(
 ) {
 	i.mu.Lock()
 	defer i.mu.Unlock()
-	
+
 	i.findings = findings
 	i.patterns = patternsDetector
 	i.correlations = correlationsDetector
@@ -202,45 +205,45 @@ func (i *Intelligence) SetStateProvider(sp StateProvider) {
 func (i *Intelligence) GetSummary() *IntelligenceSummary {
 	i.mu.RLock()
 	defer i.mu.RUnlock()
-	
+
 	summary := &IntelligenceSummary{
 		Timestamp: time.Now(),
 	}
-	
+
 	// Aggregate findings
 	if i.findings != nil {
 		all := i.findings.GetActive(FindingSeverityInfo) // Get all active findings
 		summary.FindingsCount = i.countFindings(all)
 		summary.TopFindings = i.getTopFindings(all, 5)
 	}
-	
+
 	// Aggregate predictions
 	if i.patterns != nil {
 		predictions := i.patterns.GetPredictions()
 		summary.PredictionsCount = len(predictions)
 		summary.UpcomingRisks = i.getUpcomingRisks(predictions, 5)
 	}
-	
+
 	// Aggregate recent activity
 	if i.changes != nil {
 		recent := i.changes.GetRecentChanges(100, time.Now().Add(-24*time.Hour))
 		summary.RecentChangesCount = len(recent)
 	}
-	
+
 	if i.remediations != nil {
 		recent := i.remediations.GetRecentRemediations(5, time.Now().Add(-24*time.Hour))
 		summary.RecentRemediations = recent
 	}
-	
+
 	// Learning stats
 	summary.Learning = i.getLearningStats()
-	
+
 	// Calculate overall health
 	summary.OverallHealth = i.calculateOverallHealth(summary)
-	
+
 	// Resources at risk
 	summary.ResourcesAtRisk = i.getResourcesAtRisk(5)
-	
+
 	return summary
 }
@@ -248,11 +251,11 @@ func (i *Intelligence) GetSummary() *IntelligenceSummary {
 func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntelligence {
 	i.mu.RLock()
 	defer i.mu.RUnlock()
-	
+
 	intel := &ResourceIntelligence{
 		ResourceID: resourceID,
 	}
-	
+
 	// Active findings
 	if i.findings != nil {
 		intel.ActiveFindings = i.findings.GetByResource(resourceID)
@@ -261,19 +264,19 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
 			intel.ResourceType = intel.ActiveFindings[0].ResourceType
 		}
 	}
-	
+
 	// Predictions
 	if i.patterns != nil {
 		intel.Predictions = i.patterns.GetPredictionsForResource(resourceID)
 	}
-	
+
 	// Correlations and dependencies
 	if i.correlations != nil {
 		intel.Correlations = i.correlations.GetCorrelationsForResource(resourceID)
 		intel.Dependencies = i.correlations.GetDependsOn(resourceID)
 		intel.Dependents = i.correlations.GetDependencies(resourceID)
 	}
-	
+
 	// Baselines
 	if i.baselines != nil {
 		if rb, ok := i.baselines.GetResourceBaseline(resourceID); ok {
@@ -290,12 +293,12 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
 			}
 		}
 	}
-	
+
 	// Recent incidents
 	if i.incidents != nil {
 		intel.RecentIncidents = i.incidents.ListIncidentsByResource(resourceID, 5)
 	}
-	
+
 	// Knowledge
 	if i.knowledge != nil {
 		if k, err := i.knowledge.GetKnowledge(resourceID); err == nil && k != nil {
@@ -309,10 +312,10 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
 			}
 		}
 	}
-	
+
 	// Calculate health score
 	intel.Health = i.calculateResourceHealth(intel)
-	
+
 	return intel
 }
@@ -320,49 +323,49 @@ func (i *Intelligence) GetResourceIntelligence(resourceID string) *ResourceIntel
 func (i *Intelligence) FormatContext(resourceID string) string {
 	i.mu.RLock()
 	defer i.mu.RUnlock()
-	
+
 	var sections []string
-	
+
 	// Knowledge (most important - what we've learned)
 	if i.knowledge != nil {
 		if ctx := i.knowledge.FormatForContext(resourceID); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	// Baselines (what's normal for this resource)
 	if i.baselines != nil {
 		if ctx := i.formatBaselinesForContext(resourceID); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	// Current anomalies
 	if anomalies := i.detectCurrentAnomalies(resourceID); len(anomalies) > 0 {
 		sections = append(sections, i.formatAnomaliesForContext(anomalies))
 	}
-	
+
 	// Patterns/Predictions
 	if i.patterns != nil {
 		if ctx := i.patterns.FormatForContext(resourceID); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	// Correlations
 	if i.correlations != nil {
 		if ctx := i.correlations.FormatForContext(resourceID); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	// Incidents
 	if i.incidents != nil {
 		if ctx := i.incidents.FormatForResource(resourceID, 5); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	return strings.Join(sections, "\n")
 }
@@ -370,37 +373,37 @@ func (i *Intelligence) FormatContext(resourceID string) string {
 func (i *Intelligence) FormatGlobalContext() string {
 	i.mu.RLock()
 	defer i.mu.RUnlock()
-	
+
 	var sections []string
-	
+
 	// All saved knowledge (limited)
 	if i.knowledge != nil {
 		if ctx := i.knowledge.FormatAllForContext(); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	// Recent incidents across infrastructure
 	if i.incidents != nil {
 		if ctx := i.incidents.FormatForPatrol(8); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	// Top correlations
 	if i.correlations != nil {
 		if ctx := i.correlations.FormatForContext(""); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	// Top predictions
 	if i.patterns != nil {
 		if ctx := i.patterns.FormatForContext(""); ctx != "" {
 			sections = append(sections, ctx)
 		}
 	}
-	
+
 	return strings.Join(sections, "\n")
 }
@@ -408,11 +411,11 @@ func (i *Intelligence) FormatGlobalContext() string {
 func (i *Intelligence) RecordLearning(resourceID, resourceName, resourceType, title, content string) error {
 	i.mu.RLock()
 	defer i.mu.RUnlock()
-	
+
 	if i.knowledge == nil {
 		return nil
 	}
-	
+
 	return i.knowledge.SaveNote(resourceID, resourceName, resourceType, "learning", title, content)
 }
@@ -420,11 +423,11 @@ func (i *Intelligence) RecordLearning(resourceID, resourceName, resourceType, ti
 func (i *Intelligence) CheckBaselinesForResource(resourceID string, metrics map[string]float64) []AnomalyReport {
 	i.mu.RLock()
 	defer i.mu.RUnlock()
-	
+
 	if i.baselines == nil {
 		return nil
 	}
-	
+
 	var anomalies []AnomalyReport
 	for metric, value := range metrics {
 		severity, zScore, bl := i.baselines.CheckAnomaly(resourceID, metric, value)
@@ -439,7 +442,7 @@ func (i *Intelligence) CheckBaselinesForResource(resourceID string, metrics map[
 			})
 		}
 	}
-	
+
 	return anomalies
 }
@@ -452,7 +455,7 @@ func (i *Intelligence) CreatePredictionFinding(pred patterns.FailurePrediction)
 	if pred.Confidence > 0.8 && pred.DaysUntil < 1 {
 		severity = FindingSeverityCritical
 	}
-	
+
 	return &Finding{
 		ID:       fmt.Sprintf("pred-%s-%s", pred.ResourceID, pred.EventType),
 		Key:      fmt.Sprintf("prediction:%s:%s", pred.ResourceID, pred.EventType),
@@ -493,7 +496,7 @@ func (i *Intelligence) getTopFindings(findings []*Finding, limit int) []*Finding
 	if len(findings) == 0 {
 		return nil
 	}
-	
+
 	// Sort by severity (critical first) then by detection time (newest first)
 	sorted := make([]*Finding, len(findings))
 	copy(sorted, findings)
@@ -505,7 +508,7 @@ func (i *Intelligence) getTopFindings(findings []*Finding, limit int) []*Finding
 		}
 		return sorted[a].DetectedAt.After(sorted[b].DetectedAt)
 	})
-	
+
 	if len(sorted) > limit {
 		sorted = sorted[:limit]
 	}
@@ -531,7 +534,7 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction
 	if len(predictions) == 0 {
 		return nil
 	}
-	
+
 	// Filter to next 7 days and sort by days until
 	var upcoming []patterns.FailurePrediction
 	for _, p := range predictions {
@@ -539,11 +542,11 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction
 			upcoming = append(upcoming, p)
 		}
 	}
-	
+
 	sort.Slice(upcoming, func(a, b int) bool {
 		return upcoming[a].DaysUntil < upcoming[b].DaysUntil
 	})
-	
+
 	if len(upcoming) > limit {
 		upcoming = upcoming[:limit]
 	}
@@ -552,7 +555,7 @@ func (i *Intelligence) getUpcomingRisks(predictions []patterns.FailurePrediction
 
 func (i *Intelligence) getLearningStats() LearningStats {
 	stats := LearningStats{}
-	
+
 	if i.knowledge != nil {
 		guests, _ := i.knowledge.ListGuests()
 		for _, guestID := range guests {
@@ -562,27 +565,27 @@ func (i *Intelligence) getLearningStats() LearningStats {
 			}
 		}
 	}
-	
+
 	if i.baselines != nil {
 		stats.ResourcesWithBaselines = i.baselines.ResourceCount()
 	}
-	
+
 	if i.patterns != nil {
 		p := i.patterns.GetPatterns()
 		stats.PatternsDetected = len(p)
 	}
-	
+
 	if i.correlations != nil {
 		c := i.correlations.GetCorrelations()
 		stats.CorrelationsLearned = len(c)
 	}
-	
+
 	if i.incidents != nil {
 		// Count is not available, so we skip this stat for now
 		// Could be added to IncidentStore if needed
 		stats.IncidentsTracked = 0
 	}
-	
+
 	return stats
 }
@@ -593,7 +596,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
 		Trend:   HealthTrendStable,
 		Factors: []HealthFactor{},
 	}
-	
+
 	// Deduct for findings
 	if summary.FindingsCount.Critical > 0 {
 		impact := float64(summary.FindingsCount.Critical) * 20
@@ -608,7 +611,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
 			Category: "finding",
 		})
 	}
-	
+
 	if summary.FindingsCount.Warning > 0 {
 		impact := float64(summary.FindingsCount.Warning) * 10
 		if impact > 20 {
@@ -622,7 +625,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
 			Category: "finding",
 		})
 	}
-	
+
 	// Deduct for imminent predictions
 	for _, pred := range summary.UpcomingRisks {
 		if pred.DaysUntil < 3 && pred.Confidence > 0.7 {
@@ -636,7 +639,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
 			})
 		}
 	}
-	
+
 	// Bonus for learning progress
 	if summary.Learning.ResourcesWithKnowledge > 5 {
 		bonus := 5.0
@@ -648,7 +651,7 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
 			Category: "learning",
 		})
 	}
-	
+
 	// Clamp score
 	if health.Score < 0 {
 		health.Score = 0
@@ -656,13 +659,13 @@ func (i *Intelligence) calculateOverallHealth(summary *IntelligenceSummary) Heal
 	if health.Score > 100 {
 		health.Score = 100
 	}
-	
+
 	// Assign grade
 	health.Grade = scoreToGrade(health.Score)
-	
+
 	// Generate prediction text
 	health.Prediction = i.generateHealthPrediction(health, summary)
-	
+
 	return health
 }
@@ -673,7 +676,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
 		Trend:   HealthTrendStable,
 		Factors: []HealthFactor{},
 	}
-	
+
 	// Deduct for active findings
 	for _, f := range intel.ActiveFindings {
 		if f == nil {
@@ -698,7 +701,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
 			Category: "finding",
 		})
 	}
-	
+
 	// Deduct for predictions
 	for _, p := range intel.Predictions {
 		if p.DaysUntil < 7 && p.Confidence > 0.5 {
@@ -712,7 +715,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
 			})
 		}
 	}
-	
+
 	// Deduct for anomalies
 	for _, a := range intel.Anomalies {
 		var impact float64
@@ -734,7 +737,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
 			Category: "baseline",
 		})
 	}
-	
+
 	// Bonus for having knowledge
 	if intel.NoteCount > 0 {
 		bonus := 2.0
@@ -746,7 +749,7 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
 			Category: "learning",
 		})
 	}
-	
+
 	// Clamp
 	if health.Score < 0 {
 		health.Score = 0
@@ -754,9 +757,9 @@ func (i *Intelligence) calculateResourceHealth(intel *ResourceIntelligence) Heal
 	if health.Score > 100 {
 		health.Score = 100
 	}
-	
+
 	health.Grade = scoreToGrade(health.Score)
-	
+
 	return health
 }
@@ -779,21 +782,21 @@ func (i *Intelligence) generateHealthPrediction(health HealthScore, summary *Int
 	if health.Grade == HealthGradeA {
 		return "Infrastructure is healthy with no significant issues detected."
 	}
-	
+
 	if summary.FindingsCount.Critical > 0 {
 		return fmt.Sprintf("Immediate attention required: %d critical issues.", summary.FindingsCount.Critical)
 	}
-	
+
 	if len(summary.UpcomingRisks) > 0 {
 		risk := summary.UpcomingRisks[0]
 		return fmt.Sprintf("Predicted %s event on resource within %.1f days (%.0f%% confidence).",
 			risk.EventType, risk.DaysUntil, risk.Confidence*100)
 	}
-	
+
 	if summary.FindingsCount.Warning > 0 {
 		return fmt.Sprintf("%d warnings should be addressed soon to maintain stability.", summary.FindingsCount.Warning)
 	}
-	
+
 	return "Infrastructure is stable with minor issues to monitor."
 }
@@ -801,16 +804,13 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
 	if i.findings == nil {
 		return nil
 	}
-	
+
 	// Group findings by resource
 	byResource := make(map[string][]*Finding)
 	for _, f := range i.findings.GetActive(FindingSeverityInfo) {
-		if f == nil {
-			continue
-		}
 		byResource[f.ResourceID] = append(byResource[f.ResourceID], f)
 	}
-	
+
 	// Calculate risk for each resource
 	type resourceRisk struct {
 		id    string
@@ -819,13 +819,9 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
 		score float64
 		top   string
 	}
-	
+
 	var risks []resourceRisk
 	for id, findings := range byResource {
-		if len(findings) == 0 {
-			continue
-		}
-
 		score := 0.0
 		var topFinding *Finding
 		for _, f := range findings {
@@ -843,7 +839,7 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
 				topFinding = f
 			}
 		}
-	
+
 		if score > 0 && topFinding != nil {
 			risks = append(risks, resourceRisk{
 				id:    id,
@@ -854,16 +850,16 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
 			})
 		}
 	}
-	
+
 	// Sort by risk score descending
 	sort.Slice(risks, func(a, b int) bool {
 		return risks[a].score > risks[b].score
 	})
-	
+
 	if len(risks) > limit {
 		risks = risks[:limit]
 	}
-	
+
 	// Convert to summaries
 	var summaries []ResourceRiskSummary
 	for _, r := range risks {
@@ -879,11 +875,14 @@ func (i *Intelligence) getResourcesAtRisk(limit int) []ResourceRiskSummary {
 			TopIssue: r.top,
 		})
 	}
-	
+
 	return summaries
 }
 
 func (i *Intelligence) detectCurrentAnomalies(resourceID string) []AnomalyReport {
+	if i.anomalyDetector != nil {
+		return i.anomalyDetector(resourceID)
+	}
 	// This would be called with current metrics from state
 	// For now, return empty - will be integrated with patrol
 	return nil
@@ -893,21 +892,21 @@ func (i *Intelligence) formatBaselinesForContext(resourceID string) string {
 	if i.baselines == nil {
 		return ""
 	}
-	
+
 	rb, ok := i.baselines.GetResourceBaseline(resourceID)
 	if !ok || len(rb.Metrics) == 0 {
 		return ""
 	}
-	
+
 	var lines []string
 	lines = append(lines, "\n## Learned Baselines")
 	lines = append(lines, "Normal operating ranges for this resource:")
-	
+
 	for metric, mb := range rb.Metrics {
 		lines = append(lines, fmt.Sprintf("- %s: mean %.1f, stddev %.1f (samples: %d)",
 			metric, mb.Mean, mb.StdDev, mb.SampleCount))
 	}
-	
+
 	return strings.Join(lines, "\n")
 }
@@ -915,15 +914,15 @@ func (i *Intelligence) formatAnomaliesForContext(anomalies []AnomalyReport) stri
 	if len(anomalies) == 0 {
 		return ""
 	}
-	
+
 	var lines []string
 	lines = append(lines, "\n## Current Anomalies")
 	lines = append(lines, "Metrics deviating from normal:")
-	
+
 	for _, a := range anomalies {
 		lines = append(lines, fmt.Sprintf("- %s: %s", a.Metric, a.Description))
 	}
-	
+
 	return strings.Join(lines, "\n")
 }
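The anomalyDetector field added above is the seam the new coverage tests lean on: when set, detectCurrentAnomalies delegates to it instead of returning nil. Wiring it up looks roughly like this (the report values are placeholders; the patrol integration will supply real metrics):

intel := NewIntelligence(IntelligenceConfig{})
intel.anomalyDetector = func(resourceID string) []AnomalyReport {
	// Placeholder report; a real detector derives these from live metrics.
	return []AnomalyReport{{
		Metric:      "cpu",
		Description: "cpu well above learned baseline",
	}}
}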
diff --git a/internal/ai/intelligence_coverage_test.go b/internal/ai/intelligence_coverage_test.go
new file mode 100644
index 000000000..318efdde6
--- /dev/null
+++ b/internal/ai/intelligence_coverage_test.go
@@ -0,0 +1,552 @@
+package ai
+
+import (
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/rcourtman/pulse-go-rewrite/internal/ai/baseline"
+	"github.com/rcourtman/pulse-go-rewrite/internal/ai/correlation"
+	"github.com/rcourtman/pulse-go-rewrite/internal/ai/knowledge"
+	"github.com/rcourtman/pulse-go-rewrite/internal/ai/memory"
+	"github.com/rcourtman/pulse-go-rewrite/internal/ai/patterns"
+	"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
+)
+
+func TestIntelligence_formatBaselinesForContext(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	store := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
+	if err := store.Learn("res-1", "vm", "cpu", []baseline.MetricPoint{{Value: 12.5}}); err != nil {
+		t.Fatalf("Learn: %v", err)
+	}
+	intel.SetSubsystems(nil, nil, nil, store, nil, nil, nil, nil)
+
+	ctx := intel.formatBaselinesForContext("res-1")
+	if !strings.Contains(ctx, "Learned Baselines") {
+		t.Fatalf("expected baseline header, got %q", ctx)
+	}
+	if !strings.Contains(ctx, "cpu: mean") {
+		t.Fatalf("expected cpu baseline line, got %q", ctx)
+	}
+
+	if got := intel.formatBaselinesForContext("missing"); got != "" {
+		t.Errorf("expected empty context for missing baseline, got %q", got)
+	}
+
+	empty := NewIntelligence(IntelligenceConfig{})
+	if got := empty.formatBaselinesForContext("res-1"); got != "" {
+		t.Errorf("expected empty context with no baseline store, got %q", got)
+	}
+}
+
+func TestIntelligence_formatAnomaliesForContext(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	if got := intel.formatAnomaliesForContext(nil); got != "" {
+		t.Errorf("expected empty anomalies context, got %q", got)
+	}
+
+	anomalies := []AnomalyReport{{Metric: "cpu", Description: "CPU high"}}
+	ctx := intel.formatAnomaliesForContext(anomalies)
+	if !strings.Contains(ctx, "Current Anomalies") {
+		t.Fatalf("expected anomalies header, got %q", ctx)
+	}
+	if !strings.Contains(ctx, "CPU high") {
+		t.Fatalf("expected anomaly description, got %q", ctx)
+	}
+}
+
+func TestIntelligence_formatAnomalyDescription(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	bl := &baseline.MetricBaseline{Mean: 10}
+
+	above := intel.formatAnomalyDescription("cpu", 20, bl, 2.5)
+	if !strings.Contains(above, "above baseline") {
+		t.Errorf("expected above-baseline description, got %q", above)
+	}
+
+	below := intel.formatAnomalyDescription("cpu", 5, bl, -1.5)
+	if !strings.Contains(below, "below baseline") {
+		t.Errorf("expected below-baseline description, got %q", below)
+	}
+}
+
+func TestIntelligence_FormatContext_Anomalies(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	intel.anomalyDetector = func(resourceID string) []AnomalyReport {
+		return []AnomalyReport{{Metric: "cpu", Description: "CPU high"}}
+	}
+
+	ctx := intel.FormatContext("res-1")
+	if !strings.Contains(ctx, "Current Anomalies") {
+		t.Fatalf("expected anomalies context, got %q", ctx)
+	}
+	if !strings.Contains(ctx, "CPU high") {
+		t.Fatalf("expected anomaly description, got %q", ctx)
+	}
+}
+
+func TestIntelligence_getUpcomingRisks(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	predictions := []patterns.FailurePrediction{
+		{ResourceID: "r1", EventType: patterns.EventHighCPU, DaysUntil: 2, Confidence: 0.6},
+		{ResourceID: "r2", EventType: patterns.EventHighMemory, DaysUntil: 9, Confidence: 0.9},
+		{ResourceID: "r3", EventType: patterns.EventDiskFull, DaysUntil: 4, Confidence: 0.4},
+		{ResourceID: "r4", EventType: patterns.EventOOM, DaysUntil: 1, Confidence: 0.8},
+		{ResourceID: "r5", EventType: patterns.EventRestart, DaysUntil: 5, Confidence: 0.7},
+	}
+
+	risk := intel.getUpcomingRisks(predictions, 2)
+	if len(risk) != 2 {
+		t.Fatalf("expected 2 risks, got %d", len(risk))
+	}
+	if risk[0].ResourceID != "r4" || risk[1].ResourceID != "r1" {
+		t.Errorf("unexpected ordering: %v", []string{risk[0].ResourceID, risk[1].ResourceID})
+	}
+}
+
+func TestIntelligence_countFindings(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	findings := []*Finding{
+		nil,
+		{Severity: FindingSeverityCritical},
+		{Severity: FindingSeverityWarning},
+		{Severity: FindingSeverityWatch},
+		{Severity: FindingSeverityInfo},
+	}
+	counts := intel.countFindings(findings)
+	if counts.Total != 4 {
+		t.Errorf("expected total 4, got %d", counts.Total)
+	}
+	if counts.Critical != 1 || counts.Warning != 1 || counts.Watch != 1 || counts.Info != 1 {
+		t.Errorf("unexpected counts: %+v", counts)
+	}
+}
+
+func TestIntelligence_getTopFindings(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	base := time.Now()
+	findings := []*Finding{
+		{Severity: FindingSeverityWarning, DetectedAt: base.Add(-2 * time.Hour)},
+		{Severity: FindingSeverityCritical, DetectedAt: base.Add(-3 * time.Hour), Title: "older critical"},
+		{Severity: FindingSeverityCritical, DetectedAt: base.Add(-1 * time.Hour), Title: "newer critical"},
+		{Severity: FindingSeverityWatch, DetectedAt: base.Add(-30 * time.Minute)},
+	}
+
+	top := intel.getTopFindings(findings, 3)
+	if len(top) != 3 {
+		t.Fatalf("expected 3 findings, got %d", len(top))
+	}
+	if top[0].Title != "newer critical" || top[1].Title != "older critical" {
+		t.Errorf("unexpected ordering: %q, %q", top[0].Title, top[1].Title)
+	}
+	if top[2].Severity != FindingSeverityWarning {
+		t.Errorf("expected warning in third position, got %s", top[2].Severity)
+	}
+}
+
+func TestIntelligence_getTopFindings_Empty(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	if got := intel.getTopFindings(nil, 5); got != nil {
+		t.Errorf("expected nil for empty findings, got %v", got)
+	}
+}
+
+func TestIntelligence_getLearningStats(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	knowledgeStore, err := knowledge.NewStore(t.TempDir())
+	if err != nil {
+		t.Fatalf("NewStore: %v", err)
+	}
+	knowledgeStore.SaveNote("vm-1", "vm-1", "vm", "general", "Note", "Content")
+	knowledgeStore.SaveNote("vm-2", "vm-2", "vm", "general", "Note", "Content")
+
+	baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
+	baselineStore.Learn("vm-1", "vm", "cpu", []baseline.MetricPoint{{Value: 10}})
+
+	patternDetector := patterns.NewDetector(patterns.DetectorConfig{
+		MinOccurrences:  2,
+		PatternWindow:   48 * time.Hour,
+		PredictionLimit: 30 * 24 * time.Hour,
+	})
+	patternStart := time.Now().Add(-90 * time.Minute)
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-1", EventType: patterns.EventHighCPU, Timestamp: patternStart})
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-1", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
+
+	correlationDetector := correlation.NewDetector(correlation.Config{
+		MinOccurrences:    1,
+		CorrelationWindow: 2 * time.Hour,
+		RetentionWindow:   24 * time.Hour,
+	})
+	corrStart := time.Now().Add(-30 * time.Minute)
+	for i := 0; i < 2; i++ {
+		base := corrStart.Add(time.Duration(i) * 10 * time.Minute)
+		correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-a", ResourceName: "node-a", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: base})
+		correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-b", ResourceName: "vm-b", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: base.Add(1 * time.Minute)})
+	}
+
+	incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 5})
+
+	intel.SetSubsystems(nil, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil)
+
+	stats := intel.getLearningStats()
+	if stats.ResourcesWithKnowledge != 2 {
+		t.Errorf("expected 2 resources with knowledge, got %d", stats.ResourcesWithKnowledge)
+	}
+	if stats.TotalNotes != 2 {
+		t.Errorf("expected 2 total notes, got %d", stats.TotalNotes)
+	}
+	if stats.ResourcesWithBaselines != 1 {
+		t.Errorf("expected 1 resource with baseline, got %d", stats.ResourcesWithBaselines)
+	}
+	if stats.PatternsDetected != 1 {
+		t.Errorf("expected 1 pattern, got %d", stats.PatternsDetected)
+	}
+	if stats.CorrelationsLearned != 1 {
+		t.Errorf("expected 1 correlation, got %d", stats.CorrelationsLearned)
+	}
+}
+
+func TestIntelligence_getResourcesAtRisk(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	findings := NewFindingsStore()
+	findings.Add(&Finding{
+		ID:           "crit-1",
+		Key:          "crit-1",
+		Severity:     FindingSeverityCritical,
+		Category:     FindingCategoryReliability,
+		ResourceID:   "res-critical",
+		ResourceName: "critical-vm",
+		ResourceType: "vm",
+		Title:        "Critical outage",
+		DetectedAt:   time.Now(),
+		LastSeenAt:   time.Now(),
+		Source:       "test",
+	})
+	findings.Add(&Finding{
+		ID:           "warn-1",
+		Key:          "warn-1",
+		Severity:     FindingSeverityWarning,
+		Category:     FindingCategoryPerformance,
+		ResourceID:   "res-warning",
+		ResourceName: "warning-vm",
+		ResourceType: "vm",
+		Title:        "Warning issue",
+		DetectedAt:   time.Now(),
+		LastSeenAt:   time.Now(),
+		Source:       "test",
+	})
+	findings.Add(&Finding{
+		ID:           "watch-1",
+		Key:          "watch-1",
+		Severity:     FindingSeverityWatch,
+		Category:     FindingCategoryPerformance,
+		ResourceID:   "res-warning",
+		ResourceName: "warning-vm",
+		ResourceType: "vm",
+		Title:        "Watch issue",
+		DetectedAt:   time.Now(),
+		LastSeenAt:   time.Now(),
+		Source:       "test",
+	})
+	findings.Add(&Finding{
+		ID:           "info-1",
+		Key:          "info-1",
+		Severity:     FindingSeverityInfo,
+		Category:     FindingCategoryPerformance,
+		ResourceID:   "res-warning",
+		ResourceName: "warning-vm",
+		ResourceType: "vm",
+		Title:        "Info issue",
+		DetectedAt:   time.Now(),
+		LastSeenAt:   time.Now(),
+		Source:       "test",
+	})
+
+	intel.SetSubsystems(findings, nil, nil, nil, nil, nil, nil, nil)
+	risks := intel.getResourcesAtRisk(1)
+	if len(risks) != 1 {
+		t.Fatalf("expected 1 risk, got %d", len(risks))
+	}
+	if risks[0].ResourceID != "res-critical" {
+		t.Errorf("expected critical resource first, got %s", risks[0].ResourceID)
+	}
+	if risks[0].TopIssue != "Critical outage" {
+		t.Errorf("expected top issue to be critical, got %s", risks[0].TopIssue)
+	}
+}
+
+func TestIntelligence_calculateResourceHealth_ClampAndSeverities(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+
+	resourceIntel := &ResourceIntelligence{
+		ResourceID: "test-vm",
+		ActiveFindings: []*Finding{
+			nil,
+			{Severity: FindingSeverityCritical, Title: "crit-1"},
+			{Severity: FindingSeverityCritical, Title: "crit-2"},
+			{Severity: FindingSeverityCritical, Title: "crit-3"},
+			{Severity: FindingSeverityCritical, Title: "crit-4"},
+			{Severity: FindingSeverityWatch, Title: "watch"},
+			{Severity: FindingSeverityInfo, Title: "info"},
+		},
+		Anomalies: []AnomalyReport{
+			{Metric: "cpu", Severity: baseline.AnomalyHigh, Description: "high"},
+			{Metric: "disk", Severity: baseline.AnomalyLow, Description: "low"},
+		},
+	}
+
+	health := intel.calculateResourceHealth(resourceIntel)
+	if health.Score != 0 {
+		t.Errorf("expected score clamped to 0, got %f", health.Score)
+	}
+}
+
+func TestIntelligence_calculateOverallHealth_Clamps(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	var predictions []patterns.FailurePrediction
+	for i := 0; i < 10; i++ {
+		predictions = append(predictions, patterns.FailurePrediction{EventType: patterns.EventHighCPU, DaysUntil: 1, Confidence: 0.9})
+	}
+	negative := intel.calculateOverallHealth(&IntelligenceSummary{
+		FindingsCount: FindingsCounts{Critical: 10, Warning: 10},
+		UpcomingRisks: predictions,
+	})
+	if negative.Score != 0 {
+		t.Errorf("expected score clamped to 0, got %f", negative.Score)
+	}
+
+	positive := intel.calculateOverallHealth(&IntelligenceSummary{
+		Learning: LearningStats{ResourcesWithKnowledge: 10},
+	})
+	if positive.Score != 100 {
+		t.Errorf("expected score clamped to 100, got %f", positive.Score)
+	}
+	if len(positive.Factors) == 0 {
+		t.Error("expected learning factor for positive health")
+	}
+}
+
+func TestIntelligence_generateHealthPrediction_Branches(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+
+	gradeA := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeA}, &IntelligenceSummary{})
+	if !strings.Contains(gradeA, "healthy") {
+		t.Errorf("expected healthy prediction, got %q", gradeA)
+	}
+
+	critical := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeB}, &IntelligenceSummary{
+		FindingsCount: FindingsCounts{Critical: 2},
+	})
+	if !strings.Contains(critical, "Immediate attention") {
+		t.Errorf("expected critical prediction, got %q", critical)
+	}
+
+	risk := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeB}, &IntelligenceSummary{
+		UpcomingRisks: []patterns.FailurePrediction{{EventType: patterns.EventHighCPU, DaysUntil: 2, Confidence: 0.8}},
+	})
+	if !strings.Contains(risk, "Predicted") {
+		t.Errorf("expected prediction text, got %q", risk)
+	}
+
+	stable := intel.generateHealthPrediction(HealthScore{Grade: HealthGradeC}, &IntelligenceSummary{})
+	if !strings.Contains(stable, "stable") {
+		t.Errorf("expected stable prediction, got %q", stable)
+	}
+}
+
+func TestIntelligence_FormatContext_AllSubsystems(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	knowledgeStore, _ := knowledge.NewStore(t.TempDir())
+	knowledgeStore.SaveNote("vm-ctx", "vm-ctx", "vm", "general", "Note", "Content")
+
+	baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1})
+	baselineStore.Learn("vm-ctx", "vm", "cpu", []baseline.MetricPoint{{Value: 10}})
+
+	patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
+	patternStart := time.Now().Add(-90 * time.Minute)
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-ctx", EventType: patterns.EventHighCPU, Timestamp: patternStart})
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-ctx", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
+
+	correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
+	corrBase := time.Now().Add(-30 * time.Minute)
+	correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-ctx", ResourceName: "vm-ctx", ResourceType: "vm", EventType: correlation.EventHighCPU, Timestamp: corrBase})
+	correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-ctx", ResourceName: "node-ctx", ResourceType: "node", EventType: correlation.EventRestart, Timestamp: corrBase.Add(1 * time.Minute)})
+
+	incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10})
+	incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-ctx", ResourceID: "vm-ctx", ResourceName: "vm-ctx", Type: "cpu", StartTime: time.Now()})
+
+	intel.SetSubsystems(nil, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil)
+	ctx := intel.FormatContext("vm-ctx")
+	if !strings.Contains(ctx, "Learned Baselines") {
+		t.Fatalf("expected baseline context, got %q", ctx)
+	}
+	if !strings.Contains(ctx, "Failure Predictions") {
+		t.Fatalf("expected predictions context, got %q", ctx)
+	}
+	if !strings.Contains(ctx, "Resource Correlations") {
+		t.Fatalf("expected correlations context, got %q", ctx)
+	}
+	if !strings.Contains(ctx, "Incident Memory") {
+		t.Fatalf("expected incident context, got %q", ctx)
+	}
+}
+
+func TestIntelligence_FormatGlobalContext_AllSubsystems(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	knowledgeStore, _ := knowledge.NewStore(t.TempDir())
+	knowledgeStore.SaveNote("vm-global", "vm-global", "vm", "general", "Note", "Content")
+
+	incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10})
+	incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-global", ResourceID: "vm-global", ResourceName: "vm-global", Type: "cpu", StartTime: time.Now()})
+
+	correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
+	corrBase := time.Now().Add(-30 * time.Minute)
+	for i := 0; i < 2; i++ {
+		base := corrBase.Add(time.Duration(i) * 10 * time.Minute)
+		correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-a", ResourceName: "node-a", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: base})
+		correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-global", ResourceName: "vm-global", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: base.Add(1 * time.Minute)})
+	}
+
+	patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
+	patternStart := time.Now().Add(-90 * time.Minute)
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-global", EventType: patterns.EventHighCPU, Timestamp: patternStart})
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-global", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
+
+	intel.SetSubsystems(nil, patternDetector, correlationDetector, nil, incidentStore, knowledgeStore, nil, nil)
+	ctx := intel.FormatGlobalContext()
+	if ctx == "" {
+		t.Fatal("expected non-empty global context")
+	}
+	if !strings.Contains(ctx, "Resource Correlations") {
+		t.Fatalf("expected correlations in global context, got %q", ctx)
+	}
+}
+
+func TestIntelligence_GetSummary_WithSubsystems(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	findings := NewFindingsStore()
+	findings.Add(&Finding{
+		ID:           "f1",
+		Key:          "f1",
+		Severity:     FindingSeverityWarning,
+		Category:     FindingCategoryPerformance,
+		ResourceID:   "vm-sum",
+		ResourceName: "vm-sum",
+		ResourceType: "vm",
+		Title:        "Warning",
+		DetectedAt:   time.Now(),
+		LastSeenAt:   time.Now(),
+		Source:       "test",
+	})
+
+	patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
+	patternStart := time.Now().Add(-90 * time.Minute)
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-sum", EventType: patterns.EventHighCPU, Timestamp: patternStart})
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-sum", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
+
+	changes := memory.NewChangeDetector(memory.ChangeDetectorConfig{MaxChanges: 10})
+	changes.DetectChanges([]memory.ResourceSnapshot{{ID: "vm-sum", Name: "vm-sum", Type: "vm", Status: "running", SnapshotTime: time.Now()}})
+
+	remediations := memory.NewRemediationLog(memory.RemediationLogConfig{MaxRecords: 10})
+	remediations.Log(memory.RemediationRecord{ResourceID: "vm-sum", Problem: "cpu", Action: "restart", Outcome: memory.OutcomeResolved})
+
+	intel.SetSubsystems(findings, patternDetector, nil, nil, nil, nil, changes, remediations)
+
+	summary := intel.GetSummary()
+	if summary.FindingsCount.Total == 0 {
+		t.Error("expected findings in summary")
+	}
+	if summary.PredictionsCount == 0 {
+		t.Error("expected predictions in summary")
+	}
+	if len(summary.UpcomingRisks) == 0 {
+		t.Error("expected upcoming risks in summary")
+	}
+	if summary.RecentChangesCount == 0 {
+		t.Error("expected recent changes in summary")
+	}
+	if len(summary.RecentRemediations) == 0 {
+		t.Error("expected recent remediations in summary")
+	}
+	if len(summary.ResourcesAtRisk) == 0 {
+		t.Error("expected resources at risk in summary")
+	}
+}
+
+func TestIntelligence_GetResourceIntelligence_WithAllSubsystems(t *testing.T) {
+	intel := NewIntelligence(IntelligenceConfig{})
+	findings := NewFindingsStore()
+	findings.Add(&Finding{
+		ID:           "f1",
+		Key:          "f1",
+		Severity:     FindingSeverityWarning,
+		Category:     FindingCategoryPerformance,
+		ResourceID:   "vm-intel",
+		ResourceName: "vm-intel",
+		ResourceType: "vm",
+		Title:        "Warning",
+		DetectedAt:   time.Now(),
+		LastSeenAt:   time.Now(),
+		Source:       "test",
+	})
+
+	patternDetector := patterns.NewDetector(patterns.DetectorConfig{MinOccurrences: 2, PatternWindow: 48 * time.Hour, PredictionLimit: 30 * 24 * time.Hour})
+	patternStart := time.Now().Add(-90 * time.Minute)
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-intel", EventType: patterns.EventHighCPU, Timestamp: patternStart})
+	patternDetector.RecordEvent(patterns.HistoricalEvent{ResourceID: "vm-intel", EventType: patterns.EventHighCPU, Timestamp: patternStart.Add(60 * time.Minute)})
+
+	correlationDetector := correlation.NewDetector(correlation.Config{MinOccurrences: 1, CorrelationWindow: 2 * time.Hour, RetentionWindow: 24 * time.Hour})
+	corrBase := time.Now().Add(-30 * time.Minute)
+	correlationDetector.RecordEvent(correlation.Event{ResourceID: "node-intel", ResourceName: "node-intel", ResourceType: "node", EventType: correlation.EventHighCPU, Timestamp: corrBase})
+	correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-intel", ResourceName: "vm-intel", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: corrBase.Add(1 * time.Minute)})
+	correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-intel", ResourceName: "vm-intel", ResourceType: "vm", EventType: correlation.EventHighCPU, Timestamp: corrBase.Add(2 * time.Minute)})
+	correlationDetector.RecordEvent(correlation.Event{ResourceID: "vm-child", ResourceName:
"vm-child", ResourceType: "vm", EventType: correlation.EventRestart, Timestamp: corrBase.Add(3 * time.Minute)}) + + baselineStore := baseline.NewStore(baseline.StoreConfig{MinSamples: 1}) + baselineStore.Learn("vm-intel", "vm", "cpu", []baseline.MetricPoint{{Value: 10}}) + + incidentStore := memory.NewIncidentStore(memory.IncidentStoreConfig{MaxIncidents: 10}) + incidentStore.RecordAlertFired(&alerts.Alert{ID: "alert-intel", ResourceID: "vm-intel", ResourceName: "vm-intel", Type: "cpu", StartTime: time.Now()}) + + knowledgeStore, _ := knowledge.NewStore(t.TempDir()) + knowledgeStore.SaveNote("vm-intel", "vm-intel", "vm", "general", "Note", "Content") + + intel.SetSubsystems(findings, patternDetector, correlationDetector, baselineStore, incidentStore, knowledgeStore, nil, nil) + res := intel.GetResourceIntelligence("vm-intel") + if len(res.ActiveFindings) == 0 { + t.Error("expected active findings") + } + if len(res.Predictions) == 0 { + t.Error("expected predictions") + } + if len(res.Correlations) == 0 { + t.Error("expected correlations") + } + if len(res.Dependents) == 0 || len(res.Dependencies) == 0 { + t.Error("expected dependencies and dependents") + } + if len(res.Baselines) == 0 { + t.Error("expected baselines") + } + if len(res.RecentIncidents) == 0 { + t.Error("expected incidents") + } + if res.Knowledge == nil || res.NoteCount == 0 { + t.Error("expected knowledge") + } +} + +func TestIntelligence_GetResourceIntelligence_KnowledgeFallback(t *testing.T) { + intel := NewIntelligence(IntelligenceConfig{}) + knowledgeStore, _ := knowledge.NewStore(t.TempDir()) + knowledgeStore.SaveNote("vm-know", "knowledge-vm", "vm", "general", "Note", "Content") + + intel.SetSubsystems(nil, nil, nil, nil, nil, knowledgeStore, nil, nil) + res := intel.GetResourceIntelligence("vm-know") + if res.ResourceName != "knowledge-vm" { + t.Errorf("expected resource name from knowledge, got %q", res.ResourceName) + } + if res.ResourceType != "vm" { + t.Errorf("expected resource type from knowledge, got %q", res.ResourceType) + } +} diff --git a/internal/ai/memory/incidents.go b/internal/ai/memory/incidents.go index 0725c6334..16a503eff 100644 --- a/internal/ai/memory/incidents.go +++ b/internal/ai/memory/incidents.go @@ -167,9 +167,6 @@ func (s *IncidentStore) RecordAlertAcknowledged(alert *alerts.Alert, user string defer s.mu.Unlock() incident := s.ensureIncidentForAlertLocked(alert) - if incident == nil { - return - } incident.Acknowledged = true if alert.AckTime != nil { @@ -198,9 +195,6 @@ func (s *IncidentStore) RecordAlertUnacknowledged(alert *alerts.Alert, user stri defer s.mu.Unlock() incident := s.ensureIncidentForAlertLocked(alert) - if incident == nil { - return - } incident.Acknowledged = false incident.AckTime = nil diff --git a/internal/ai/memory/memory_coverage_test.go b/internal/ai/memory/memory_coverage_test.go new file mode 100644 index 000000000..cf8002948 --- /dev/null +++ b/internal/ai/memory/memory_coverage_test.go @@ -0,0 +1,1128 @@ +package memory + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/alerts" +) + +func TestChangeDetector_DefaultsAndHelpers(t *testing.T) { + detector := NewChangeDetector(ChangeDetectorConfig{}) + if detector.maxChanges != 1000 { + t.Fatalf("expected default maxChanges=1000, got %d", detector.maxChanges) + } + + if got := intToString(0); got != "0" { + t.Errorf("intToString(0) = %q", got) + } + if got := intToString(42); got != "42" { + t.Errorf("intToString(42) = %q", 
got) + } + if got := formatFloat(2.0); got != "2" { + t.Errorf("formatFloat(2.0) = %q", got) + } + if got := formatFloat(2.5); got != "2.5" { + t.Errorf("formatFloat(2.5) = %q", got) + } + + cpu := formatCPUChangeDescription("vm-1", 4, 2) + if !strings.Contains(cpu, "decreased") { + t.Errorf("expected cpu decrease description, got %q", cpu) + } + mem := formatMemoryChangeDescription("vm-1", 8<<30, 4<<30) + if !strings.Contains(mem, "decreased") { + t.Errorf("expected memory decrease description, got %q", mem) + } +} + +func TestNewChangeDetector_LoadsFromDisk(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_changes.json") + changes := []Change{ + {ID: "c1", DetectedAt: time.Now().Add(-2 * time.Hour)}, + } + data, err := json.Marshal(changes) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(path, data, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + detector := NewChangeDetector(ChangeDetectorConfig{ + MaxChanges: 10, + DataDir: tmpDir, + }) + if len(detector.changes) != 1 { + t.Fatalf("expected 1 change loaded, got %d", len(detector.changes)) + } +} + +func TestNewChangeDetector_LoadError(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_changes.json") + if err := os.WriteFile(path, []byte("{"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + detector := NewChangeDetector(ChangeDetectorConfig{ + MaxChanges: 10, + DataDir: tmpDir, + }) + if len(detector.changes) != 0 { + t.Fatalf("expected no changes after load error, got %d", len(detector.changes)) + } +} + +func TestChangeDetector_SaveToDisk_Scenarios(t *testing.T) { + t.Run("NoDataDir", func(t *testing.T) { + d := &ChangeDetector{} + if err := d.saveToDisk(); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("MissingDir", func(t *testing.T) { + tmpDir := t.TempDir() + missing := filepath.Join(tmpDir, "missing") + d := &ChangeDetector{ + dataDir: missing, + changes: []Change{{ID: "c1"}}, + } + if err := d.saveToDisk(); err == nil { + t.Fatal("expected error for missing directory") + } + }) + + t.Run("MarshalError", func(t *testing.T) { + tmpDir := t.TempDir() + d := &ChangeDetector{ + dataDir: tmpDir, + changes: []Change{{ID: "c1", Before: func() {}}}, + } + if err := d.saveToDisk(); err == nil { + t.Fatal("expected marshal error") + } + }) + + t.Run("RenameError", func(t *testing.T) { + tmpDir := t.TempDir() + destDir := filepath.Join(tmpDir, "ai_changes.json") + if err := os.MkdirAll(destDir, 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + d := &ChangeDetector{ + dataDir: tmpDir, + changes: []Change{{ID: "c1"}}, + } + if err := d.saveToDisk(); err == nil { + t.Fatal("expected rename error") + } + }) + + t.Run("Success", func(t *testing.T) { + tmpDir := t.TempDir() + d := &ChangeDetector{ + dataDir: tmpDir, + changes: []Change{{ID: "c1"}}, + } + if err := d.saveToDisk(); err != nil { + t.Fatalf("saveToDisk error: %v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "ai_changes.json")); err != nil { + t.Fatalf("expected file to exist: %v", err) + } + }) +} + +func TestChangeDetector_LoadFromDisk_Scenarios(t *testing.T) { + t.Run("NoFile", func(t *testing.T) { + d := &ChangeDetector{dataDir: t.TempDir()} + if err := d.loadFromDisk(); err != nil { + t.Fatalf("expected nil error for missing file, got %v", err) + } + }) + + t.Run("EmptyDataDir", func(t *testing.T) { + d := &ChangeDetector{dataDir: ""} + if err := d.loadFromDisk(); err != nil { + t.Fatalf("expected nil error for empty dataDir, 
got %v", err) + } + }) + + t.Run("ReadError", func(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "not-dir") + if err := os.WriteFile(filePath, []byte("x"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + d := &ChangeDetector{dataDir: filePath} + if err := d.loadFromDisk(); err == nil { + t.Fatalf("expected read error") + } + }) + + t.Run("InvalidJSON", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_changes.json") + if err := os.WriteFile(path, []byte("{"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + d := &ChangeDetector{dataDir: tmpDir} + if err := d.loadFromDisk(); err == nil { + t.Fatal("expected JSON error") + } + }) + + t.Run("TooLarge", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_changes.json") + file, err := os.Create(path) + if err != nil { + t.Fatalf("create: %v", err) + } + if err := file.Truncate(10<<20 + 1); err != nil { + t.Fatalf("truncate: %v", err) + } + if err := file.Close(); err != nil { + t.Fatalf("close: %v", err) + } + + d := &ChangeDetector{dataDir: tmpDir} + if err := d.loadFromDisk(); err == nil { + t.Fatal("expected size error") + } + }) + + t.Run("SortedAndTrimmed", func(t *testing.T) { + tmpDir := t.TempDir() + now := time.Now() + changes := []Change{ + {ID: "c2", DetectedAt: now.Add(-1 * time.Hour)}, + {ID: "c1", DetectedAt: now.Add(-2 * time.Hour)}, + {ID: "c3", DetectedAt: now.Add(-10 * time.Minute)}, + } + data, err := json.Marshal(changes) + if err != nil { + t.Fatalf("marshal: %v", err) + } + path := filepath.Join(tmpDir, "ai_changes.json") + if err := os.WriteFile(path, data, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + d := &ChangeDetector{ + dataDir: tmpDir, + maxChanges: 2, + } + if err := d.loadFromDisk(); err != nil { + t.Fatalf("loadFromDisk error: %v", err) + } + if len(d.changes) != 2 { + t.Fatalf("expected 2 changes after trim, got %d", len(d.changes)) + } + if !d.changes[0].DetectedAt.Equal(changes[0].DetectedAt) { + t.Fatalf("expected oldest remaining change to be c2") + } + }) +} + +func TestChangeDetector_DetectChanges_SaveError(t *testing.T) { + tmpDir := t.TempDir() + badDir := filepath.Join(tmpDir, "not-dir") + if err := os.WriteFile(badDir, []byte("x"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + detector := NewChangeDetector(ChangeDetectorConfig{DataDir: badDir}) + detector.DetectChanges([]ResourceSnapshot{ + {ID: "vm-1", Name: "vm-1", Type: "vm", Status: "running"}, + }) + time.Sleep(20 * time.Millisecond) +} + +func TestIncidentStore_DefaultsAndSummary(t *testing.T) { + store := NewIncidentStore(IncidentStoreConfig{}) + if store.maxIncidents != defaultIncidentMaxIncidents { + t.Fatalf("expected default max incidents, got %d", store.maxIncidents) + } + if store.maxEvents != defaultIncidentMaxEvents { + t.Fatalf("expected default max events, got %d", store.maxEvents) + } + if store.maxAge != time.Duration(defaultIncidentMaxAgeDays)*24*time.Hour { + t.Fatalf("expected default max age, got %v", store.maxAge) + } + + if got := formatAlertSummary(nil); got != "Alert triggered" { + t.Fatalf("unexpected nil summary %q", got) + } + noValue := formatAlertSummary(&alerts.Alert{Type: "cpu", Level: alerts.AlertLevelWarning}) + if !strings.Contains(noValue, "Alert triggered: cpu (warning)") { + t.Fatalf("unexpected summary %q", noValue) + } + withValue := formatAlertSummary(&alerts.Alert{ + Type: "cpu", + Level: alerts.AlertLevelCritical, + Value: 90, + Threshold: 80, + }) + if 
!strings.Contains(withValue, ">= 80.0") { + t.Fatalf("expected threshold summary, got %q", withValue) + } +} + +func TestNewIncidentStore_LoadsFromDisk(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, incidentFileName) + incidents := []*Incident{ + {ID: "inc-1", AlertID: "alert-1", Status: IncidentStatusOpen, OpenedAt: time.Now()}, + } + data, err := json.Marshal(incidents) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(path, data, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + store := NewIncidentStore(IncidentStoreConfig{DataDir: tmpDir}) + if len(store.incidents) != 1 { + t.Fatalf("expected incidents loaded, got %d", len(store.incidents)) + } +} + +func TestNewIncidentStore_LoadError(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, incidentFileName) + if err := os.WriteFile(path, []byte("{"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + store := NewIncidentStore(IncidentStoreConfig{DataDir: tmpDir}) + if len(store.incidents) != 0 { + t.Fatalf("expected no incidents after load error") + } +} + +func TestIncidentStore_RecordAlertFired_Existing(t *testing.T) { + store := NewIncidentStore(IncidentStoreConfig{}) + store.RecordAlertFired(nil) + + alert := &alerts.Alert{ + ID: "alert-fired", + Type: "cpu", + Level: alerts.AlertLevelWarning, + ResourceID: "res-1", + ResourceName: "vm-1", + Message: "original", + } + store.RecordAlertFired(alert) + alert.Message = "updated" + store.RecordAlertFired(alert) + + timeline := store.GetTimelineByAlertID(alert.ID) + if timeline == nil { + t.Fatalf("expected timeline") + } + if timeline.Message != "updated" { + t.Fatalf("expected updated message, got %q", timeline.Message) + } + if len(timeline.Events) != 1 { + t.Fatalf("expected 1 event, got %d", len(timeline.Events)) + } +} + +func TestIncidentStore_RecordAlertAcknowledged_WithAckTime(t *testing.T) { + store := NewIncidentStore(IncidentStoreConfig{}) + ackTime := time.Now().Add(-5 * time.Minute) + alert := &alerts.Alert{ + ID: "alert-ack", + Type: "memory", + Level: alerts.AlertLevelWarning, + AckTime: &ackTime, + } + store.RecordAlertAcknowledged(nil, "user") + store.RecordAlertAcknowledged(alert, "user") + + timeline := store.GetTimelineByAlertID(alert.ID) + if timeline == nil { + t.Fatalf("expected timeline") + } + if timeline.AckTime == nil || !timeline.AckTime.Equal(ackTime) { + t.Fatalf("expected ack time to match") + } + if timeline.AckUser != "user" { + t.Fatalf("expected ack user") + } +} + +func TestIncidentStore_RecordAlertResolved_ZeroTime(t *testing.T) { + store := NewIncidentStore(IncidentStoreConfig{}) + alert := &alerts.Alert{ + ID: "alert-resolved", + Type: "disk", + Level: alerts.AlertLevelCritical, + } + store.RecordAlertResolved(nil, time.Time{}) + store.RecordAlertResolved(alert, time.Time{}) + + timeline := store.GetTimelineByAlertID(alert.ID) + if timeline == nil || timeline.ClosedAt == nil { + t.Fatalf("expected closed incident") + } +} + +func TestIncidentStore_RecordAnalysis_CommandDetails(t *testing.T) { + store := NewIncidentStore(IncidentStoreConfig{}) + store.RecordAnalysis("", "", nil) + store.RecordAnalysis("alert-analysis", "", nil) + + timeline := store.GetTimelineByAlertID("alert-analysis") + if timeline == nil { + t.Fatalf("expected timeline") + } + if len(timeline.Events) == 0 || timeline.Events[0].Summary != "AI analysis completed" { + t.Fatalf("expected default analysis summary") + } + + store.RecordCommand("", "", false, "", nil) + 
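// Empty command output should leave the event without an output_excerpt detail; the check below asserts its absence. +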
store.RecordCommand("alert-cmd", "echo test", false, "", nil) + + cmdTimeline := store.GetTimelineByAlertID("alert-cmd") + if cmdTimeline == nil || len(cmdTimeline.Events) == 0 { + t.Fatalf("expected command event") + } + if _, ok := cmdTimeline.Events[0].Details["output_excerpt"]; ok { + t.Fatalf("did not expect output_excerpt for empty output") + } +} + +func TestIncidentStore_Timelines_EmptyAndZeroTime(t *testing.T) { + store := NewIncidentStore(IncidentStoreConfig{}) + if got := store.GetTimelineByAlertID(""); got != nil { + t.Fatalf("expected nil for empty alert ID") + } + if got := store.GetTimelineByAlertAt("", time.Now()); got != nil { + t.Fatalf("expected nil for empty alert ID in GetTimelineByAlertAt") + } + if got := store.GetTimelineByAlertAt("missing-alert", time.Now()); got != nil { + t.Fatalf("expected nil for missing alert") + } + + alert := &alerts.Alert{ + ID: "alert-zero", + Type: "cpu", + Level: alerts.AlertLevelWarning, + ResourceName: "vm-1", + StartTime: time.Now().Add(-10 * time.Minute), + } + store.RecordAlertFired(alert) + + timeline := store.GetTimelineByAlertAt(alert.ID, time.Time{}) + if timeline == nil || timeline.AlertID != alert.ID { + t.Fatalf("expected timeline for zero start time") + } + timeline = store.GetTimelineByAlertAt(alert.ID, alert.StartTime.Add(5*time.Minute)) + if timeline == nil || timeline.AlertID != alert.ID { + t.Fatalf("expected timeline for later start time") + } +} + +func TestIncidentStore_GetTimelineByAlertAt_SkipsMismatched(t *testing.T) { + store := &IncidentStore{ + incidents: []*Incident{ + nil, + {ID: "inc-a", AlertID: "alert-a", OpenedAt: time.Now().Add(-10 * time.Minute)}, + {ID: "inc-b", AlertID: "alert-b", OpenedAt: time.Now().Add(-5 * time.Minute)}, + }, + } + + timeline := store.GetTimelineByAlertAt("alert-b", time.Now()) + if timeline == nil || timeline.AlertID != "alert-b" { + t.Fatalf("expected timeline for alert-b") + } +} + +func TestIncidentStore_FormatForPatrol_MessageFallback(t *testing.T) { + store := NewIncidentStore(IncidentStoreConfig{}) + store.incidents = append(store.incidents, &Incident{ + ID: "inc-message", + AlertID: "alert-message", + Status: IncidentStatusOpen, + OpenedAt: time.Now(), + Message: "fallback message", + }) + store.incidents = append(store.incidents, &Incident{ + ID: "inc-ack", + AlertID: "alert-ack", + Status: IncidentStatusOpen, + OpenedAt: time.Now().Add(1 * time.Minute), + Acknowledged: true, + ResourceName: "vm-ack", + AlertType: "cpu", + }) + store.incidents = append(store.incidents, nil) + + result := store.FormatForPatrol(2) + if !strings.Contains(result, "fallback message") { + t.Fatalf("expected message fallback in patrol output") + } + if !strings.Contains(result, "acknowledged") { + t.Fatalf("expected acknowledged status in patrol output") + } +} + +func TestIncidentStore_HelperPaths(t *testing.T) { + store := &IncidentStore{ + incidents: make([]*Incident, 0), + maxEvents: 1, + maxIncidents: 1, + maxAge: 30 * time.Minute, + } + + alert := &alerts.Alert{ + ID: "alert-helper", + Type: "cpu", + Level: alerts.AlertLevelWarning, + } + + incident := store.ensureIncidentForAlertLocked(alert) + if incident == nil || len(store.incidents) != 1 { + t.Fatalf("expected incident created") + } + store.ensureIncidentForAlertLocked(alert) + if len(store.incidents) != 1 { + t.Fatalf("expected same incident to be reused") + } + + updateIncidentFromAlert(nil, alert) + updateIncidentFromAlert(incident, nil) + + store.addEventLocked(nil, IncidentEventAnalysis, "", nil) + store.addEventLocked(incident, 
IncidentEventAnalysis, "", nil) + store.addEventLocked(incident, IncidentEventNote, "note", nil) + if len(incident.Events) != 1 || incident.Events[0].Type != IncidentEventNote { + t.Fatalf("expected events trimmed to last entry") + } + if incident.Events[0].Summary == "" { + t.Fatalf("expected summary to be set") + } + + store.incidents = append([]*Incident{nil}, store.incidents...) + if store.findOpenIncidentByAlertIDLocked("") != nil { + t.Fatalf("expected nil for empty alert ID") + } + if store.findLatestIncidentByAlertIDLocked("") != nil { + t.Fatalf("expected nil for empty alert ID") + } + if store.findIncidentByIDLocked("") != nil { + t.Fatalf("expected nil for empty incident ID") + } + + oldClosed := time.Now().Add(-2 * time.Hour) + store.incidents = []*Incident{ + nil, + {ID: "old-open", AlertID: "old", Status: IncidentStatusOpen, OpenedAt: time.Now().Add(-2 * time.Hour)}, + {ID: "old-closed", AlertID: "oldc", Status: IncidentStatusResolved, OpenedAt: time.Now().Add(-3 * time.Hour), ClosedAt: &oldClosed}, + {ID: "recent", AlertID: "recent", Status: IncidentStatusOpen, OpenedAt: time.Now().Add(-5 * time.Minute)}, + {ID: "recent2", AlertID: "recent2", Status: IncidentStatusOpen, OpenedAt: time.Now().Add(-4 * time.Minute)}, + } + store.trimLocked() + if len(store.incidents) != 1 || store.incidents[0].ID != "recent2" { + t.Fatalf("expected trim to keep most recent incident") + } +} + +func TestIncidentStore_SaveAsyncAndPersistence(t *testing.T) { + tmpDir := t.TempDir() + store := &IncidentStore{ + incidents: []*Incident{ + {ID: "inc-1", AlertID: "alert-1", Status: IncidentStatusOpen, OpenedAt: time.Now()}, + }, + dataDir: tmpDir, + filePath: filepath.Join(tmpDir, incidentFileName), + } + + store.saveAsync() + + deadline := time.Now().Add(500 * time.Millisecond) + for { + if _, err := os.Stat(store.filePath); err == nil { + break + } + if time.Now().After(deadline) { + t.Fatalf("expected saveAsync to create file") + } + time.Sleep(10 * time.Millisecond) + } +} + +func TestIncidentStore_SaveAsync_Error(t *testing.T) { + tmpDir := t.TempDir() + badDir := filepath.Join(tmpDir, "not-dir") + if err := os.WriteFile(badDir, []byte("x"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + store := &IncidentStore{ + incidents: []*Incident{ + {ID: "inc-err", AlertID: "alert-err", Status: IncidentStatusOpen, OpenedAt: time.Now()}, + }, + dataDir: badDir, + filePath: filepath.Join(badDir, incidentFileName), + } + store.saveAsync() + time.Sleep(20 * time.Millisecond) +} + +func TestIncidentStore_SaveToDisk_Scenarios(t *testing.T) { + t.Run("NoDataDir", func(t *testing.T) { + store := &IncidentStore{} + if err := store.saveToDisk(); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("MkdirError", func(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "not-a-dir") + if err := os.WriteFile(filePath, []byte("x"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + store := &IncidentStore{ + dataDir: filePath, + filePath: filepath.Join(filePath, incidentFileName), + } + if err := store.saveToDisk(); err == nil { + t.Fatalf("expected mkdir error") + } + }) + + t.Run("WriteError", func(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, incidentFileName) + tmpFile := filePath + ".tmp" + if err := os.MkdirAll(tmpFile, 0755); err != nil { + t.Fatalf("mkdir tmp: %v", err) + } + store := &IncidentStore{ + dataDir: tmpDir, + filePath: filePath, + } + if err := store.saveToDisk(); err == nil { + t.Fatalf("expected 
write error") + } + }) + + t.Run("MarshalError", func(t *testing.T) { + tmpDir := t.TempDir() + store := &IncidentStore{ + incidents: []*Incident{ + { + ID: "inc-1", + AlertID: "alert-1", + Status: IncidentStatusOpen, + OpenedAt: time.Now(), + Events: []IncidentEvent{ + { + ID: "evt-1", + Type: IncidentEventNote, + Timestamp: time.Now(), + Summary: "note", + Details: map[string]interface{}{"bad": make(chan int)}, + }, + }, + }, + }, + dataDir: tmpDir, + filePath: filepath.Join(tmpDir, incidentFileName), + } + if err := store.saveToDisk(); err == nil { + t.Fatalf("expected marshal error") + } + }) + + t.Run("RenameError", func(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, incidentFileName) + if err := os.MkdirAll(filePath, 0755); err != nil { + t.Fatalf("mkdir file path: %v", err) + } + store := &IncidentStore{ + incidents: []*Incident{ + {ID: "inc-1", AlertID: "alert-1", Status: IncidentStatusOpen, OpenedAt: time.Now()}, + }, + dataDir: tmpDir, + filePath: filePath, + } + if err := store.saveToDisk(); err == nil { + t.Fatalf("expected rename error") + } + }) + + t.Run("Success", func(t *testing.T) { + tmpDir := t.TempDir() + store := &IncidentStore{ + incidents: []*Incident{ + { + ID: "inc-1", + AlertID: "alert-1", + Status: IncidentStatusOpen, + OpenedAt: time.Now(), + Events: []IncidentEvent{ + {ID: "evt-1", Type: IncidentEventNote, Timestamp: time.Now(), Summary: "note", Details: map[string]interface{}{"k": "v"}}, + }, + }, + }, + dataDir: tmpDir, + filePath: filepath.Join(tmpDir, incidentFileName), + } + if err := store.saveToDisk(); err != nil { + t.Fatalf("saveToDisk error: %v", err) + } + if _, err := os.Stat(store.filePath); err != nil { + t.Fatalf("expected file to exist: %v", err) + } + }) +} + +func TestIncidentStore_LoadFromDisk_Scenarios(t *testing.T) { + t.Run("EmptyFilePath", func(t *testing.T) { + store := &IncidentStore{filePath: ""} + if err := store.loadFromDisk(); err != nil { + t.Fatalf("expected nil error for empty file path, got %v", err) + } + }) + + t.Run("NoFile", func(t *testing.T) { + tmpDir := t.TempDir() + store := &IncidentStore{ + filePath: filepath.Join(tmpDir, incidentFileName), + } + if err := store.loadFromDisk(); err != nil { + t.Fatalf("expected nil error for missing file, got %v", err) + } + }) + + t.Run("ReadError", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, incidentFileName) + if err := os.MkdirAll(path, 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + store := &IncidentStore{filePath: path} + if err := store.loadFromDisk(); err == nil { + t.Fatalf("expected read error") + } + }) + + t.Run("StatError", func(t *testing.T) { + tmpDir := t.TempDir() + notDir := filepath.Join(tmpDir, "not-dir") + if err := os.WriteFile(notDir, []byte("x"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + store := &IncidentStore{ + filePath: filepath.Join(notDir, incidentFileName), + } + if err := store.loadFromDisk(); err == nil { + t.Fatalf("expected stat error") + } + }) + + t.Run("FileTooLarge", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, incidentFileName) + file, err := os.Create(path) + if err != nil { + t.Fatalf("create: %v", err) + } + if err := file.Truncate(maxIncidentFileSize + 1); err != nil { + t.Fatalf("truncate: %v", err) + } + if err := file.Close(); err != nil { + t.Fatalf("close: %v", err) + } + store := &IncidentStore{filePath: path} + if err := store.loadFromDisk(); err == nil { + t.Fatalf("expected size error") + } + }) + + 
t.Run("InvalidJSON", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, incidentFileName) + if err := os.WriteFile(path, []byte("{"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + store := &IncidentStore{filePath: path} + if err := store.loadFromDisk(); err == nil { + t.Fatalf("expected JSON error") + } + }) + + t.Run("Success", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, incidentFileName) + closed := time.Now().Add(-90 * time.Minute) + incidents := []*Incident{ + { + ID: "inc-a", + AlertID: "alert-a", + Status: IncidentStatusResolved, + OpenedAt: time.Now().Add(-2 * time.Hour), + ClosedAt: &closed, + }, + { + ID: "inc-b", + AlertID: "alert-b", + Status: IncidentStatusOpen, + OpenedAt: time.Now().Add(-10 * time.Minute), + }, + } + data, err := json.Marshal(incidents) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(path, data, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + store := &IncidentStore{ + filePath: path, + maxIncidents: 1, + maxAge: 24 * time.Hour, + } + if err := store.loadFromDisk(); err != nil { + t.Fatalf("loadFromDisk error: %v", err) + } + if len(store.incidents) != 1 { + t.Fatalf("expected trimmed incidents, got %d", len(store.incidents)) + } + }) +} + +func TestCloneIncident(t *testing.T) { + if cloneIncident(nil) != nil { + t.Fatalf("expected nil clone") + } + now := time.Now() + ack := now.Add(-5 * time.Minute) + closed := now.Add(-2 * time.Minute) + incident := &Incident{ + ID: "inc-1", + AlertID: "alert-1", + Status: IncidentStatusResolved, + OpenedAt: now.Add(-10 * time.Minute), + AckTime: &ack, + ClosedAt: &closed, + Events: []IncidentEvent{ + { + ID: "evt-1", + Type: IncidentEventNote, + Timestamp: now, + Summary: "note", + Details: map[string]interface{}{"key": "value"}, + }, + { + ID: "evt-2", + Type: IncidentEventAnalysis, + Timestamp: now, + Summary: "analysis", + }, + }, + } + + clone := cloneIncident(incident) + if clone == nil || clone.AckTime == nil || clone.ClosedAt == nil { + t.Fatalf("expected clone with ack and close time") + } + clone.Events[0].Details["key"] = "changed" + if incident.Events[0].Details["key"] == "changed" { + t.Fatalf("expected deep copy of details") + } +} + +func TestRemediationLog_DefaultsAndLog(t *testing.T) { + log := NewRemediationLog(RemediationLogConfig{}) + if log.maxRecords != 500 { + t.Fatalf("expected default max records, got %d", log.maxRecords) + } + + if err := log.Log(RemediationRecord{Problem: "p", Action: "a"}); err != nil { + t.Fatalf("log error: %v", err) + } + if len(log.records) != 1 { + t.Fatalf("expected record logged") + } + if log.records[0].ID == "" || log.records[0].Timestamp.IsZero() { + t.Fatalf("expected ID and Timestamp to be set") + } +} + +func TestRemediationLog_Log_SaveError(t *testing.T) { + tmpDir := t.TempDir() + badDir := filepath.Join(tmpDir, "not-dir") + if err := os.WriteFile(badDir, []byte("x"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + log := &RemediationLog{ + dataDir: badDir, + maxRecords: 1, + } + if err := log.Log(RemediationRecord{Problem: "p", Action: "a"}); err != nil { + t.Fatalf("log error: %v", err) + } + time.Sleep(20 * time.Millisecond) +} + +func TestNewRemediationLog_LoadsFromDisk(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_remediations.json") + records := []RemediationRecord{{ID: "r1", Problem: "p", Action: "a"}} + data, err := json.Marshal(records) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err 
:= os.WriteFile(path, data, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + log := NewRemediationLog(RemediationLogConfig{DataDir: tmpDir}) + if len(log.records) != 1 { + t.Fatalf("expected records loaded") + } +} + +func TestNewRemediationLog_LoadError(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_remediations.json") + if err := os.WriteFile(path, []byte("{"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + log := NewRemediationLog(RemediationLogConfig{DataDir: tmpDir}) + if len(log.records) != 0 { + t.Fatalf("expected no records after load error") + } +} + +func TestRemediationLog_SimilarAndStatsBranches(t *testing.T) { + log := NewRemediationLog(RemediationLogConfig{}) + if matches := log.GetSimilar("a b c", 5); matches != nil { + t.Fatalf("expected nil for no keywords") + } + + log.Log(RemediationRecord{Problem: "memory issue", Action: "a1", Outcome: OutcomePartial}) + log.Log(RemediationRecord{Problem: "memory issue", Action: "a2", Outcome: OutcomeFailed}) + + success := log.GetSuccessfulRemediations("memory issue", 5) + if len(success) != 1 || success[0].Outcome != OutcomePartial { + t.Fatalf("expected partial to be included") + } + + log.Log(RemediationRecord{Problem: "unknown", Action: "a3", Outcome: OutcomeUnknown}) + stats := log.GetRecentRemediationStats(time.Now().Add(-1 * time.Hour)) + if stats["unknown"] == 0 { + t.Fatalf("expected unknown outcome to be counted") + } +} + +func TestRemediationLog_GetSuccessfulRemediations_Limit(t *testing.T) { + log := NewRemediationLog(RemediationLogConfig{}) + log.Log(RemediationRecord{Problem: "disk full", Action: "a1", Outcome: OutcomeResolved}) + log.Log(RemediationRecord{Problem: "disk full", Action: "a2", Outcome: OutcomePartial}) + + results := log.GetSuccessfulRemediations("disk full", 1) + if len(results) != 1 { + t.Fatalf("expected limited results, got %d", len(results)) + } +} + +func TestRemediationLog_FormatAndStats(t *testing.T) { + log := NewRemediationLog(RemediationLogConfig{}) + log.Log(RemediationRecord{ + ResourceID: "res-1", + Problem: "issue", + Action: "action", + Outcome: OutcomeUnknown, + Note: "note", + }) + log.Log(RemediationRecord{ + ResourceID: "res-1", + Problem: "issue", + Action: "action", + Outcome: OutcomePartial, + }) + + formatted := log.FormatForContext("res-1", 5) + if !strings.Contains(formatted, "Note: note") { + t.Fatalf("expected note in formatted context") + } + + stats := log.GetRemediationStats() + if stats["unknown"] != 1 { + t.Fatalf("expected unknown count") + } + if stats["partial"] != 1 { + t.Fatalf("expected partial count") + } +} + +func TestRemediationLog_SaveLoad_Scenarios(t *testing.T) { + t.Run("SaveNoDataDir", func(t *testing.T) { + log := &RemediationLog{} + if err := log.saveToDisk(); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("SaveMissingDir", func(t *testing.T) { + tmpDir := t.TempDir() + missing := filepath.Join(tmpDir, "missing") + log := &RemediationLog{ + dataDir: missing, + records: []RemediationRecord{{ID: "r1", Problem: "p", Action: "a"}}, + } + if err := log.saveToDisk(); err == nil { + t.Fatalf("expected error for missing directory") + } + }) + + t.Run("SaveRenameError", func(t *testing.T) { + tmpDir := t.TempDir() + destDir := filepath.Join(tmpDir, "ai_remediations.json") + if err := os.MkdirAll(destDir, 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + log := &RemediationLog{ + dataDir: tmpDir, + records: []RemediationRecord{{ID: "r1", Problem: "p", Action: "a"}}, + } + if 
err := log.saveToDisk(); err == nil { + t.Fatalf("expected rename error") + } + }) + + t.Run("SaveSuccess", func(t *testing.T) { + tmpDir := t.TempDir() + log := &RemediationLog{ + dataDir: tmpDir, + records: []RemediationRecord{{ID: "r1", Problem: "p", Action: "a"}}, + } + if err := log.saveToDisk(); err != nil { + t.Fatalf("saveToDisk error: %v", err) + } + if _, err := os.Stat(filepath.Join(tmpDir, "ai_remediations.json")); err != nil { + t.Fatalf("expected file to exist: %v", err) + } + }) + + t.Run("SaveMarshalError", func(t *testing.T) { + tmpDir := t.TempDir() + log := &RemediationLog{ + dataDir: tmpDir, + records: []RemediationRecord{ + { + ID: "r1", + Problem: "p", + Action: "a", + Timestamp: time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC), + }, + }, + } + if err := log.saveToDisk(); err == nil { + t.Fatalf("expected marshal error") + } + }) + + t.Run("LoadEmptyDataDir", func(t *testing.T) { + log := &RemediationLog{dataDir: ""} + if err := log.loadFromDisk(); err != nil { + t.Fatalf("expected nil error for empty dataDir, got %v", err) + } + }) + + t.Run("LoadNoFile", func(t *testing.T) { + log := &RemediationLog{dataDir: t.TempDir()} + if err := log.loadFromDisk(); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("LoadReadError", func(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "not-dir") + if err := os.WriteFile(filePath, []byte("x"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + log := &RemediationLog{dataDir: filePath} + if err := log.loadFromDisk(); err == nil { + t.Fatalf("expected read error") + } + }) + + t.Run("LoadTooLarge", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_remediations.json") + file, err := os.Create(path) + if err != nil { + t.Fatalf("create: %v", err) + } + if err := file.Truncate(10<<20 + 1); err != nil { + t.Fatalf("truncate: %v", err) + } + if err := file.Close(); err != nil { + t.Fatalf("close: %v", err) + } + + log := &RemediationLog{dataDir: tmpDir} + if err := log.loadFromDisk(); err == nil { + t.Fatalf("expected size error") + } + }) + + t.Run("LoadInvalidJSON", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_remediations.json") + if err := os.WriteFile(path, []byte("{"), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + log := &RemediationLog{dataDir: tmpDir} + if err := log.loadFromDisk(); err == nil { + t.Fatalf("expected JSON error") + } + }) + + t.Run("LoadSuccess", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ai_remediations.json") + records := []RemediationRecord{ + {ID: "r2", Problem: "p2", Action: "a2", Timestamp: time.Now().Add(-1 * time.Hour)}, + {ID: "r1", Problem: "p1", Action: "a1", Timestamp: time.Now().Add(-2 * time.Hour)}, + } + data, err := json.Marshal(records) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(path, data, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + log := &RemediationLog{ + dataDir: tmpDir, + maxRecords: 1, + } + if err := log.loadFromDisk(); err != nil { + t.Fatalf("loadFromDisk error: %v", err) + } + if len(log.records) != 1 { + t.Fatalf("expected trimmed records, got %d", len(log.records)) + } + }) +} diff --git a/internal/ai/mock_test.go b/internal/ai/mock_test.go index 69f2bea7b..4cda8329d 100644 --- a/internal/ai/mock_test.go +++ b/internal/ai/mock_test.go @@ -52,8 +52,8 @@ type mockThresholdProvider struct { storage float64 } -func (m *mockThresholdProvider) GetNodeCPUThreshold() 
float64 { return m.nodeCPU } -func (m *mockThresholdProvider) GetNodeMemoryThreshold() float64 { return m.nodeMemory } +func (m *mockThresholdProvider) GetNodeCPUThreshold() float64 { return m.nodeCPU } +func (m *mockThresholdProvider) GetNodeMemoryThreshold() float64 { return m.nodeMemory } func (m *mockThresholdProvider) GetGuestCPUThreshold() float64 { return 0 } func (m *mockThresholdProvider) GetGuestMemoryThreshold() float64 { return m.guestMem } func (m *mockThresholdProvider) GetGuestDiskThreshold() float64 { return m.guestDisk } @@ -61,11 +61,17 @@ func (m *mockThresholdProvider) GetStorageThreshold() float64 { return m.sto type mockResourceProvider struct { ResourceProvider - getAllFunc func() []resources.Resource - getStatsFunc func() resources.StoreStats - getSummaryFunc func() resources.ResourceSummary + getAllFunc func() []resources.Resource + getStatsFunc func() resources.StoreStats + getSummaryFunc func() resources.ResourceSummary getInfrastructureFunc func() []resources.Resource - getWorkloadsFunc func() []resources.Resource + getWorkloadsFunc func() []resources.Resource + getByTypeFunc func(t resources.ResourceType) []resources.Resource + getTopCPUFunc func(limit int, types []resources.ResourceType) []resources.Resource + getTopMemoryFunc func(limit int, types []resources.ResourceType) []resources.Resource + getTopDiskFunc func(limit int, types []resources.ResourceType) []resources.Resource + getRelatedFunc func(resourceID string) map[string][]resources.Resource + findContainerHostFunc func(containerNameOrID string) string } func (m *mockResourceProvider) GetAll() []resources.Resource { @@ -99,23 +105,46 @@ func (m *mockResourceProvider) GetWorkloads() []resources.Resource { return nil } func (m *mockResourceProvider) GetType(t resources.ResourceType) []resources.Resource { return nil } +func (m *mockResourceProvider) GetByType(t resources.ResourceType) []resources.Resource { + if m.getByTypeFunc != nil { + return m.getByTypeFunc(t) + } + return nil +} func (m *mockResourceProvider) GetTopByCPU(limit int, types []resources.ResourceType) []resources.Resource { + if m.getTopCPUFunc != nil { + return m.getTopCPUFunc(limit, types) + } return nil } func (m *mockResourceProvider) GetTopByMemory(limit int, types []resources.ResourceType) []resources.Resource { + if m.getTopMemoryFunc != nil { + return m.getTopMemoryFunc(limit, types) + } return nil } func (m *mockResourceProvider) GetTopByDisk(limit int, types []resources.ResourceType) []resources.Resource { + if m.getTopDiskFunc != nil { + return m.getTopDiskFunc(limit, types) + } return nil } func (m *mockResourceProvider) GetRelated(resourceID string) map[string][]resources.Resource { + if m.getRelatedFunc != nil { + return m.getRelatedFunc(resourceID) + } return nil } -func (m *mockResourceProvider) FindContainerHost(containerNameOrID string) string { return "" } +func (m *mockResourceProvider) FindContainerHost(containerNameOrID string) string { + if m.findContainerHostFunc != nil { + return m.findContainerHostFunc(containerNameOrID) + } + return "" +} type mockAgentServer struct { - agents []agentexec.ConnectedAgent - executeFunc func(ctx context.Context, agentID string, cmd agentexec.ExecuteCommandPayload) (*agentexec.CommandResultPayload, error) + agents []agentexec.ConnectedAgent + executeFunc func(ctx context.Context, agentID string, cmd agentexec.ExecuteCommandPayload) (*agentexec.CommandResultPayload, error) } func (m *mockAgentServer) GetConnectedAgents() []agentexec.ConnectedAgent { diff --git 
a/internal/ai/resource_context.go b/internal/ai/resource_context.go index 9e415c065..91aa83160 100644 --- a/internal/ai/resource_context.go +++ b/internal/ai/resource_context.go @@ -284,10 +284,6 @@ func (s *Service) buildUnifiedResourceContext() string { } } - if len(sections) == 0 { - return "" - } - result := "\n\n" + strings.Join(sections, "\n") // Limit context size diff --git a/internal/ai/resource_context_test.go b/internal/ai/resource_context_test.go new file mode 100644 index 000000000..d833d5989 --- /dev/null +++ b/internal/ai/resource_context_test.go @@ -0,0 +1,263 @@ +package ai + +import ( + "strings" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/agentexec" + "github.com/rcourtman/pulse-go-rewrite/internal/resources" +) + +func TestBuildUnifiedResourceContext_NilProvider(t *testing.T) { + s := &Service{} + if got := s.buildUnifiedResourceContext(); got != "" { + t.Errorf("expected empty context, got %q", got) + } +} + +func TestBuildUnifiedResourceContext_FullContext(t *testing.T) { + nodeWithAgent := resources.Resource{ + ID: "node-1", + Name: "delly", + Type: resources.ResourceTypeNode, + PlatformType: resources.PlatformProxmoxPVE, + ClusterID: "cluster-a", + Status: resources.StatusOnline, + CPU: metricValue(12.3), + Memory: metricValue(45.6), + } + nodeNoAgent := resources.Resource{ + ID: "node-2", + Name: "minipc", + Type: resources.ResourceTypeNode, + PlatformType: resources.PlatformProxmoxPVE, + Status: resources.StatusDegraded, + } + dockerNode := resources.Resource{ + ID: "dock-node", + Name: "dock-node", + Type: resources.ResourceTypeNode, + PlatformType: resources.PlatformDocker, + Status: resources.StatusOnline, + } + host := resources.Resource{ + ID: "host-1", + Name: "barehost", + Type: resources.ResourceTypeHost, + PlatformType: resources.PlatformHostAgent, + Status: resources.StatusOnline, + Identity: &resources.ResourceIdentity{IPs: []string{"192.168.1.10"}}, + CPU: metricValue(5.0), + Memory: metricValue(10.0), + } + dockerHost := resources.Resource{ + ID: "docker-1", + Name: "dockhost", + Type: resources.ResourceTypeDockerHost, + PlatformType: resources.PlatformDocker, + Status: resources.StatusRunning, + } + + vm := resources.Resource{ + ID: "vm-100", + Name: "web-vm", + Type: resources.ResourceTypeVM, + ParentID: "node-1", + PlatformID: "100", + Status: resources.StatusRunning, + Identity: &resources.ResourceIdentity{IPs: []string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}}, + CPU: metricValue(65.4), + Memory: metricValue(70.2), + } + vm.Alerts = []resources.ResourceAlert{ + {ID: "alert-1", Message: "CPU high", Level: "critical"}, + } + + ct := resources.Resource{ + ID: "ct-200", + Name: "db-ct", + Type: resources.ResourceTypeContainer, + ParentID: "node-1", + PlatformID: "200", + Status: resources.StatusStopped, + } + dockerContainer := resources.Resource{ + ID: "dock-300", + Name: "redis", + Type: resources.ResourceTypeDockerContainer, + ParentID: "docker-1", + Status: resources.StatusRunning, + Disk: metricValue(70.0), + } + dockerStopped := resources.Resource{ + ID: "dock-301", + Name: "cache", + Type: resources.ResourceTypeDockerContainer, + ParentID: "docker-1", + Status: resources.StatusStopped, + } + unknownParent := resources.Resource{ + ID: "vm-999", + Name: "mystery", + Type: resources.ResourceTypeVM, + ParentID: "unknown-parent", + PlatformID: "999", + Status: resources.StatusRunning, + } + orphan := resources.Resource{ + ID: "orphan-1", + Name: "orphan", + Type: resources.ResourceTypeContainer, + Status: resources.StatusRunning, + 
Identity: &resources.ResourceIdentity{IPs: []string{"172.16.0.5"}}, + } + + infrastructure := []resources.Resource{nodeWithAgent, nodeNoAgent, host, dockerHost, dockerNode} + workloads := []resources.Resource{vm, ct, dockerContainer, dockerStopped, unknownParent, orphan} + all := append(append([]resources.Resource{}, infrastructure...), workloads...) + + stats := resources.StoreStats{ + TotalResources: len(all), + ByType: map[resources.ResourceType]int{ + resources.ResourceTypeNode: 3, + resources.ResourceTypeHost: 1, + resources.ResourceTypeDockerHost: 1, + resources.ResourceTypeVM: 2, + resources.ResourceTypeContainer: 2, + resources.ResourceTypeDockerContainer: 2, + }, + } + + summary := resources.ResourceSummary{ + TotalResources: len(all), + Healthy: 6, + Degraded: 1, + Offline: 4, + WithAlerts: 1, + ByType: map[resources.ResourceType]resources.TypeSummary{ + resources.ResourceTypeVM: {Count: 2, AvgCPUPercent: 40.0, AvgMemoryPercent: 55.0}, + resources.ResourceTypeContainer: {Count: 2}, + }, + } + + mockRP := &mockResourceProvider{ + getStatsFunc: func() resources.StoreStats { + return stats + }, + getInfrastructureFunc: func() []resources.Resource { + return infrastructure + }, + getWorkloadsFunc: func() []resources.Resource { + return workloads + }, + getAllFunc: func() []resources.Resource { + return all + }, + getSummaryFunc: func() resources.ResourceSummary { + return summary + }, + getTopCPUFunc: func(limit int, types []resources.ResourceType) []resources.Resource { + return []resources.Resource{vm} + }, + getTopMemoryFunc: func(limit int, types []resources.ResourceType) []resources.Resource { + return []resources.Resource{host} + }, + getTopDiskFunc: func(limit int, types []resources.ResourceType) []resources.Resource { + return []resources.Resource{dockerContainer} + }, + } + + s := &Service{resourceProvider: mockRP} + s.agentServer = &mockAgentServer{ + agents: []agentexec.ConnectedAgent{ + {AgentID: "agent-1", Hostname: "delly"}, + }, + } + + got := s.buildUnifiedResourceContext() + if got == "" { + t.Fatal("expected non-empty context") + } + + assertContains := func(substr string) { + t.Helper() + if !strings.Contains(got, substr) { + t.Fatalf("expected context to contain %q", substr) + } + } + + assertContains("## Unified Infrastructure View") + assertContains("Total resources: 11 (Infrastructure: 5, Workloads: 6)") + assertContains("Proxmox VE Nodes") + assertContains("HAS AGENT") + assertContains("NO AGENT") + assertContains("cluster: cluster-a") + assertContains("Standalone Hosts") + assertContains("192.168.1.10") + assertContains("Docker/Podman Hosts") + assertContains("1/2 containers running") + assertContains("Workloads (VMs & Containers)") + assertContains("On delly") + assertContains("On unknown-parent") + assertContains("Other workloads") + assertContains("10.0.0.1, 10.0.0.2") + assertContains("Resources with Active Alerts") + assertContains("CPU high") + assertContains("Infrastructure Summary") + assertContains("Resources with alerts: 1") + assertContains("Average utilization by type") + assertContains("Top CPU Consumers") + assertContains("Top Memory Consumers") + assertContains("Top Disk Usage") +} + +func TestBuildUnifiedResourceContext_TruncatesLargeContext(t *testing.T) { + largeName := strings.Repeat("a", 60000) + + node := resources.Resource{ + ID: "node-1", + Name: "node-1", + DisplayName: largeName, + Type: resources.ResourceTypeNode, + PlatformType: resources.PlatformProxmoxPVE, + Status: resources.StatusOnline, + } + + stats := resources.StoreStats{ + 
TotalResources: 1, + ByType: map[resources.ResourceType]int{ + resources.ResourceTypeNode: 1, + }, + } + + mockRP := &mockResourceProvider{ + getStatsFunc: func() resources.StoreStats { + return stats + }, + getInfrastructureFunc: func() []resources.Resource { + return []resources.Resource{node} + }, + getWorkloadsFunc: func() []resources.Resource { + return nil + }, + getAllFunc: func() []resources.Resource { + return []resources.Resource{node} + }, + getSummaryFunc: func() resources.ResourceSummary { + return resources.ResourceSummary{TotalResources: 1} + }, + } + + s := &Service{resourceProvider: mockRP} + got := s.buildUnifiedResourceContext() + if !strings.Contains(got, "[... Context truncated ...]") { + t.Fatal("expected context to be truncated") + } + if len(got) <= 50000 { + t.Fatalf("expected truncated context length > 50000, got %d", len(got)) + } +} + +func metricValue(current float64) *resources.MetricValue { + return &resources.MetricValue{Current: current} +} diff --git a/internal/ai/routing_test.go b/internal/ai/routing_test.go index 188238723..de89676d1 100644 --- a/internal/ai/routing_test.go +++ b/internal/ai/routing_test.go @@ -1,9 +1,14 @@ package ai import ( + "os" + "path/filepath" + "strings" "testing" "github.com/rcourtman/pulse-go-rewrite/internal/agentexec" + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/models" ) func TestExtractVMIDFromTargetID(t *testing.T) { @@ -16,21 +21,21 @@ func TestExtractVMIDFromTargetID(t *testing.T) { {"plain vmid", "106", 106}, {"node-vmid", "minipc-106", 106}, {"instance-node-vmid", "delly-minipc-106", 106}, - + // Edge cases with hyphenated names {"hyphenated-node-vmid", "pve-node-01-106", 106}, {"hyphenated-instance-node-vmid", "my-cluster-pve-node-01-106", 106}, - + // Type prefixes {"lxc prefix", "lxc-106", 106}, {"vm prefix", "vm-106", 106}, {"ct prefix", "ct-106", 106}, - + // Non-numeric - should return 0 {"non-numeric", "mycontainer", 0}, {"no-vmid", "node-name", 0}, {"empty", "", 0}, - + // Large VMIDs (Proxmox uses up to 999999999) {"large vmid", "node-999999", 999999}, } @@ -53,18 +58,18 @@ func TestRoutingError(t *testing.T) { Reason: "No agent connected to node \"minipc\"", Suggestion: "Install pulse-agent on minipc", } - + want := "No agent connected to node \"minipc\". 
Install pulse-agent on minipc" if err.Error() != want { t.Errorf("Error() = %q, want %q", err.Error(), want) } }) - + t.Run("without suggestion", func(t *testing.T) { err := &RoutingError{ Reason: "No agents connected", } - + want := "No agents connected" if err.Error() != want { t.Errorf("Error() = %q, want %q", err.Error(), want) @@ -74,22 +79,22 @@ func TestRoutingError(t *testing.T) { func TestRouteToAgent_NoAgents(t *testing.T) { s := &Service{} - + req := ExecuteRequest{ TargetType: "container", TargetID: "minipc-106", } - + _, err := s.routeToAgent(req, "pct exec 106 -- hostname", nil) if err == nil { t.Error("expected error for no agents, got nil") } - + routingErr, ok := err.(*RoutingError) if !ok { t.Fatalf("expected RoutingError, got %T", err) } - + if routingErr.Suggestion == "" { t.Error("expected suggestion in error") } @@ -97,19 +102,19 @@ func TestRouteToAgent_NoAgents(t *testing.T) { func TestRouteToAgent_ExactMatch(t *testing.T) { s := &Service{} - + agents := []agentexec.ConnectedAgent{ {AgentID: "agent-1", Hostname: "delly"}, {AgentID: "agent-2", Hostname: "minipc"}, {AgentID: "agent-3", Hostname: "pimox"}, } - + tests := []struct { - name string - req ExecuteRequest - command string - wantAgentID string - wantHostname string + name string + req ExecuteRequest + command string + wantAgentID string + wantHostname string }{ { name: "route by context node", @@ -151,11 +156,11 @@ func TestRouteToAgent_ExactMatch(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - + if result.AgentID != tt.wantAgentID { t.Errorf("AgentID = %q, want %q", result.AgentID, tt.wantAgentID) } - + if result.AgentHostname != tt.wantHostname { t.Errorf("AgentHostname = %q, want %q", result.AgentHostname, tt.wantHostname) } @@ -165,28 +170,28 @@ func TestRouteToAgent_ExactMatch(t *testing.T) { func TestRouteToAgent_NoSubstringMatching(t *testing.T) { s := &Service{} - + // Agent "mini" should NOT match node "minipc" agents := []agentexec.ConnectedAgent{ {AgentID: "agent-1", Hostname: "mini"}, {AgentID: "agent-2", Hostname: "pc"}, } - + req := ExecuteRequest{ TargetType: "container", Context: map[string]interface{}{"node": "minipc"}, } - + _, err := s.routeToAgent(req, "hostname", agents) if err == nil { t.Error("expected error when no exact match, got nil (substring matching may be occurring)") } - + routingErr, ok := err.(*RoutingError) if !ok { t.Fatalf("expected RoutingError, got %T", err) } - + if routingErr.TargetNode != "minipc" { t.Errorf("TargetNode = %q, want %q", routingErr.TargetNode, "minipc") } @@ -194,21 +199,21 @@ func TestRouteToAgent_NoSubstringMatching(t *testing.T) { func TestRouteToAgent_CaseInsensitive(t *testing.T) { s := &Service{} - + agents := []agentexec.ConnectedAgent{ {AgentID: "agent-1", Hostname: "MiniPC"}, } - + req := ExecuteRequest{ TargetType: "container", Context: map[string]interface{}{"node": "minipc"}, // lowercase } - + result, err := s.routeToAgent(req, "hostname", agents) if err != nil { t.Fatalf("expected case-insensitive match, got error: %v", err) } - + if result.AgentID != "agent-1" { t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1") } @@ -216,22 +221,22 @@ func TestRouteToAgent_CaseInsensitive(t *testing.T) { func TestRouteToAgent_HyphenatedNodeNames(t *testing.T) { s := &Service{} - + agents := []agentexec.ConnectedAgent{ {AgentID: "agent-1", Hostname: "pve-node-01"}, {AgentID: "agent-2", Hostname: "pve-node-02"}, } - + req := ExecuteRequest{ TargetType: "container", Context: map[string]interface{}{"node": "pve-node-02"}, 
} - + result, err := s.routeToAgent(req, "hostname", agents) if err != nil { t.Fatalf("unexpected error for hyphenated node names: %v", err) } - + if result.AgentID != "agent-2" { t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-2") } @@ -239,38 +244,441 @@ func TestRouteToAgent_HyphenatedNodeNames(t *testing.T) { func TestRouteToAgent_ActionableErrorMessages(t *testing.T) { s := &Service{} - + agents := []agentexec.ConnectedAgent{ {AgentID: "agent-1", Hostname: "delly"}, } - + req := ExecuteRequest{ TargetType: "container", Context: map[string]interface{}{"node": "minipc"}, } - + _, err := s.routeToAgent(req, "hostname", agents) if err == nil { t.Fatal("expected error, got nil") } - + routingErr, ok := err.(*RoutingError) if !ok { t.Fatalf("expected RoutingError, got %T", err) } - + // Error should mention the target node if routingErr.TargetNode != "minipc" { t.Errorf("TargetNode = %q, want %q", routingErr.TargetNode, "minipc") } - + // Error should list available agents if len(routingErr.AvailableAgents) == 0 { t.Error("expected available agents in error") } - + // Error should have actionable suggestion if routingErr.Suggestion == "" { t.Error("expected suggestion in error message") } } + +func TestRoutingError_ForAI(t *testing.T) { + t.Run("clarification", func(t *testing.T) { + err := &RoutingError{ + Reason: "Cannot determine which host should execute this command", + AvailableAgents: []string{"delly", "pimox"}, + AskForClarification: true, + } + + msg := err.ForAI() + if !strings.Contains(msg, "ROUTING_CLARIFICATION_NEEDED") { + t.Errorf("expected clarification marker, got %q", msg) + } + if !strings.Contains(msg, "delly, pimox") { + t.Errorf("expected available hosts list, got %q", msg) + } + if !strings.Contains(msg, err.Reason) { + t.Errorf("expected reason to be included, got %q", msg) + } + }) + + t.Run("fallback", func(t *testing.T) { + err := &RoutingError{ + Reason: "No agents connected", + AskForClarification: true, + } + if err.ForAI() != err.Error() { + t.Errorf("expected ForAI to fall back to Error, got %q", err.ForAI()) + } + }) +} + +func TestRouteToAgent_VMIDRoutingWithInstance(t *testing.T) { + stateProvider := &mockStateProvider{ + state: models.StateSnapshot{ + Containers: []models.Container{ + {VMID: 106, Node: "node-b", Name: "ct-b", Instance: "cluster-b"}, + {VMID: 106, Node: "node-a", Name: "ct-a", Instance: "cluster-b"}, + }, + }, + } + + s := &Service{stateProvider: stateProvider} + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-a", Hostname: "node-a"}, + {AgentID: "agent-b", Hostname: "node-b"}, + } + req := ExecuteRequest{ + TargetType: "container", + Context: map[string]interface{}{"instance": "cluster-b"}, + } + + result, err := s.routeToAgent(req, "pct exec 106 -- hostname", agents) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.AgentID != "agent-b" { + t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-b") + } + if result.RoutingMethod != "vmid_lookup_with_instance" { + t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "vmid_lookup_with_instance") + } +} + +func TestRouteToAgent_VMIDCollision(t *testing.T) { + stateProvider := &mockStateProvider{ + state: models.StateSnapshot{ + VMs: []models.VM{ + {VMID: 200, Node: "node-a", Name: "vm-a", Instance: "cluster-a"}, + {VMID: 200, Node: "node-b", Name: "vm-b", Instance: "cluster-b"}, + }, + }, + } + + s := &Service{stateProvider: stateProvider} + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-a", Hostname: "node-a"}, + {AgentID: 
"agent-b", Hostname: "node-b"}, + } + req := ExecuteRequest{ + TargetType: "vm", + } + + _, err := s.routeToAgent(req, "pct exec 200 -- hostname", agents) + if err == nil { + t.Fatal("expected error for VMID collision, got nil") + } + routingErr, ok := err.(*RoutingError) + if !ok { + t.Fatalf("expected RoutingError, got %T", err) + } + if routingErr.TargetVMID != 200 { + t.Errorf("TargetVMID = %d, want %d", routingErr.TargetVMID, 200) + } + if !strings.Contains(routingErr.Reason, "exists on multiple nodes") { + t.Errorf("expected collision reason, got %q", routingErr.Reason) + } +} + +func TestRouteToAgent_VMIDNotFoundFallsBackToContext(t *testing.T) { + stateProvider := &mockStateProvider{} + s := &Service{stateProvider: stateProvider} + + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-1", Hostname: "minipc"}, + } + req := ExecuteRequest{ + TargetType: "container", + Context: map[string]interface{}{"node": "minipc"}, + } + + result, err := s.routeToAgent(req, "pct exec 106 -- hostname", agents) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RoutingMethod != "context_node" { + t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "context_node") + } + if len(result.Warnings) != 1 || !strings.Contains(result.Warnings[0], "VMID 106 not found") { + t.Errorf("expected warning about missing VMID, got %v", result.Warnings) + } +} + +func TestRouteToAgent_ResourceProviderContexts(t *testing.T) { + s := &Service{} + mockRP := &mockResourceProvider{ + findContainerHostFunc: func(containerNameOrID string) string { + if containerNameOrID == "" { + return "" + } + return "rp-host" + }, + } + s.resourceProvider = mockRP + + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-1", Hostname: "rp-host"}, + } + + tests := []struct { + name string + key string + }{ + {name: "containerName", key: "containerName"}, + {name: "name", key: "name"}, + {name: "guestName", key: "guestName"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := ExecuteRequest{ + TargetType: "container", + Context: map[string]interface{}{tt.key: "workload"}, + } + + result, err := s.routeToAgent(req, "docker ps", agents) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RoutingMethod != "resource_provider_lookup" { + t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "resource_provider_lookup") + } + }) + } +} + +func TestRouteToAgent_TargetIDLookup(t *testing.T) { + stateProvider := &mockStateProvider{ + state: models.StateSnapshot{ + VMs: []models.VM{ + {VMID: 222, Node: "node-vm", Name: "vm-222", Instance: "cluster-a"}, + }, + }, + } + + s := &Service{stateProvider: stateProvider} + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-1", Hostname: "node-vm"}, + } + req := ExecuteRequest{ + TargetType: "vm", + TargetID: "vm-222", + Context: map[string]interface{}{"instance": "cluster-a"}, + } + + result, err := s.routeToAgent(req, "status", agents) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RoutingMethod != "target_id_vmid_lookup" { + t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "target_id_vmid_lookup") + } + if result.AgentID != "agent-1" { + t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1") + } +} + +func TestRouteToAgent_VMIDLookupSingleMatch(t *testing.T) { + stateProvider := &mockStateProvider{ + state: models.StateSnapshot{ + VMs: []models.VM{ + {VMID: 101, Node: "node-1", Name: "vm-one"}, + }, + }, + } + + s := &Service{stateProvider: stateProvider} 
+ agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-1", Hostname: "node-1"}, + } + req := ExecuteRequest{ + TargetType: "vm", + } + + result, err := s.routeToAgent(req, "qm start 101", agents) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RoutingMethod != "vmid_lookup" { + t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "vmid_lookup") + } + if result.AgentID != "agent-1" { + t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-1") + } +} + +func TestRouteToAgent_ClusterPeer(t *testing.T) { + tmp := t.TempDir() + persistence := config.NewConfigPersistence(tmp) + err := persistence.SaveNodesConfig([]config.PVEInstance{ + { + Name: "node-a", + IsCluster: true, + ClusterName: "cluster-a", + ClusterEndpoints: []config.ClusterEndpoint{ + {NodeName: "node-a"}, + {NodeName: "node-b"}, + }, + }, + }, nil, nil) + if err != nil { + t.Fatalf("SaveNodesConfig: %v", err) + } + + s := &Service{persistence: persistence} + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-b", Hostname: "node-b"}, + } + req := ExecuteRequest{ + TargetType: "vm", + Context: map[string]interface{}{"node": "node-a"}, + } + + result, err := s.routeToAgent(req, "hostname", agents) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.ClusterPeer { + t.Error("expected cluster peer routing") + } + if result.AgentID != "agent-b" { + t.Errorf("AgentID = %q, want %q", result.AgentID, "agent-b") + } +} + +func TestRouteToAgent_SingleAgentFallbackForHost(t *testing.T) { + s := &Service{} + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-1", Hostname: "only"}, + } + req := ExecuteRequest{ + TargetType: "host", + } + + result, err := s.routeToAgent(req, "uptime", agents) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RoutingMethod != "single_agent_fallback" { + t.Errorf("RoutingMethod = %q, want %q", result.RoutingMethod, "single_agent_fallback") + } + if len(result.Warnings) != 1 { + t.Errorf("expected warning for single agent fallback, got %v", result.Warnings) + } +} + +func TestRouteToAgent_AsksForClarification(t *testing.T) { + s := &Service{} + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-1", Hostname: "delly"}, + {AgentID: "agent-2", Hostname: "pimox"}, + } + + req := ExecuteRequest{ + TargetType: "vm", + } + + _, err := s.routeToAgent(req, "hostname", agents) + if err == nil { + t.Fatal("expected error, got nil") + } + routingErr, ok := err.(*RoutingError) + if !ok { + t.Fatalf("expected RoutingError, got %T", err) + } + if !routingErr.AskForClarification { + t.Error("expected AskForClarification to be true") + } + if len(routingErr.AvailableAgents) != 2 { + t.Errorf("expected available agents, got %v", routingErr.AvailableAgents) + } +} + +func TestFindClusterPeerAgent_NoPersistence(t *testing.T) { + s := &Service{} + if got := s.findClusterPeerAgent("node-a", nil); got != "" { + t.Errorf("expected empty result, got %q", got) + } +} + +func TestFindClusterPeerAgent_LoadError(t *testing.T) { + tmp := t.TempDir() + persistence := config.NewConfigPersistence(tmp) + nodesPath := filepath.Join(tmp, "nodes.enc") + if err := os.Mkdir(nodesPath, 0700); err != nil { + t.Fatalf("Mkdir nodes.enc: %v", err) + } + + s := &Service{persistence: persistence} + if got := s.findClusterPeerAgent("node-a", nil); got != "" { + t.Errorf("expected empty result, got %q", got) + } +} + +func TestFindClusterPeerAgent_NotCluster(t *testing.T) { + tmp := t.TempDir() + persistence := config.NewConfigPersistence(tmp) + err := 
persistence.SaveNodesConfig([]config.PVEInstance{ + {Name: "node-a"}, + }, nil, nil) + if err != nil { + t.Fatalf("SaveNodesConfig: %v", err) + } + + s := &Service{persistence: persistence} + if got := s.findClusterPeerAgent("node-a", nil); got != "" { + t.Errorf("expected empty result, got %q", got) + } +} + +func TestFindClusterPeerAgent_NoAgentMatch(t *testing.T) { + tmp := t.TempDir() + persistence := config.NewConfigPersistence(tmp) + err := persistence.SaveNodesConfig([]config.PVEInstance{ + { + Name: "node-a", + IsCluster: true, + ClusterName: "cluster-a", + ClusterEndpoints: []config.ClusterEndpoint{ + {NodeName: "node-a"}, + {NodeName: "node-b"}, + }, + }, + }, nil, nil) + if err != nil { + t.Fatalf("SaveNodesConfig: %v", err) + } + + s := &Service{persistence: persistence} + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-1", Hostname: "node-c"}, + } + if got := s.findClusterPeerAgent("node-a", agents); got != "" { + t.Errorf("expected empty result, got %q", got) + } +} + +func TestFindClusterPeerAgent_EndpointMatch(t *testing.T) { + tmp := t.TempDir() + persistence := config.NewConfigPersistence(tmp) + err := persistence.SaveNodesConfig([]config.PVEInstance{ + { + Name: "cluster-master", + IsCluster: true, + ClusterName: "cluster-a", + ClusterEndpoints: []config.ClusterEndpoint{ + {NodeName: "node-a"}, + {NodeName: "node-b"}, + }, + }, + }, nil, nil) + if err != nil { + t.Fatalf("SaveNodesConfig: %v", err) + } + + s := &Service{persistence: persistence} + agents := []agentexec.ConnectedAgent{ + {AgentID: "agent-b", Hostname: "node-b"}, + } + if got := s.findClusterPeerAgent("node-a", agents); got != "agent-b" { + t.Errorf("expected peer agent, got %q", got) + } +} diff --git a/internal/ceph/collector.go b/internal/ceph/collector.go index d6398b7e7..13df19103 100644 --- a/internal/ceph/collector.go +++ b/internal/ceph/collector.go @@ -12,24 +12,33 @@ import ( "time" ) +var commandRunner = func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + cmd := exec.CommandContext(ctx, name, args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + return stdout.Bytes(), stderr.Bytes(), err +} + // ClusterStatus represents the complete Ceph cluster status as collected by the agent. type ClusterStatus struct { - FSID string `json:"fsid"` - Health HealthStatus `json:"health"` - MonMap MonitorMap `json:"monMap,omitempty"` - MgrMap ManagerMap `json:"mgrMap,omitempty"` - OSDMap OSDMap `json:"osdMap"` - PGMap PGMap `json:"pgMap"` - Pools []Pool `json:"pools,omitempty"` - Services []ServiceInfo `json:"services,omitempty"` - CollectedAt time.Time `json:"collectedAt"` + FSID string `json:"fsid"` + Health HealthStatus `json:"health"` + MonMap MonitorMap `json:"monMap,omitempty"` + MgrMap ManagerMap `json:"mgrMap,omitempty"` + OSDMap OSDMap `json:"osdMap"` + PGMap PGMap `json:"pgMap"` + Pools []Pool `json:"pools,omitempty"` + Services []ServiceInfo `json:"services,omitempty"` + CollectedAt time.Time `json:"collectedAt"` } // HealthStatus represents Ceph cluster health. type HealthStatus struct { - Status string `json:"status"` // HEALTH_OK, HEALTH_WARN, HEALTH_ERR - Checks map[string]Check `json:"checks,omitempty"` - Summary []HealthSummary `json:"summary,omitempty"` + Status string `json:"status"` // HEALTH_OK, HEALTH_WARN, HEALTH_ERR + Checks map[string]Check `json:"checks,omitempty"` + Summary []HealthSummary `json:"summary,omitempty"` } // Check represents a health check detail. 
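A quick illustration of the wire shape those camelCase tags produce — a minimal standalone sketch using a pared-down stand-in for HealthStatus, not code from this change:

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// healthStatus is a simplified stand-in for the HealthStatus type above;
	// Summary is reduced to []string purely for illustration.
	type healthStatus struct {
		Status  string   `json:"status"`
		Summary []string `json:"summary,omitempty"`
	}

	func main() {
		b, err := json.Marshal(healthStatus{Status: "HEALTH_WARN"})
		if err != nil {
			panic(err)
		}
		// omitempty drops the empty Summary field from the output.
		fmt.Println(string(b)) // {"status":"HEALTH_WARN"}
	}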
@@ -70,28 +79,28 @@ type ManagerMap struct { // OSDMap represents OSD status summary. type OSDMap struct { - Epoch int `json:"epoch"` - NumOSDs int `json:"numOsds"` - NumUp int `json:"numUp"` - NumIn int `json:"numIn"` - NumDown int `json:"numDown,omitempty"` - NumOut int `json:"numOut,omitempty"` + Epoch int `json:"epoch"` + NumOSDs int `json:"numOsds"` + NumUp int `json:"numUp"` + NumIn int `json:"numIn"` + NumDown int `json:"numDown,omitempty"` + NumOut int `json:"numOut,omitempty"` } // PGMap represents placement group statistics. type PGMap struct { - NumPGs int `json:"numPgs"` - BytesTotal uint64 `json:"bytesTotal"` - BytesUsed uint64 `json:"bytesUsed"` - BytesAvailable uint64 `json:"bytesAvailable"` - DataBytes uint64 `json:"dataBytes,omitempty"` - UsagePercent float64 `json:"usagePercent"` - DegradedRatio float64 `json:"degradedRatio,omitempty"` - MisplacedRatio float64 `json:"misplacedRatio,omitempty"` - ReadBytesPerSec uint64 `json:"readBytesPerSec,omitempty"` - WriteBytesPerSec uint64 `json:"writeBytesPerSec,omitempty"` - ReadOpsPerSec uint64 `json:"readOpsPerSec,omitempty"` - WriteOpsPerSec uint64 `json:"writeOpsPerSec,omitempty"` + NumPGs int `json:"numPgs"` + BytesTotal uint64 `json:"bytesTotal"` + BytesUsed uint64 `json:"bytesUsed"` + BytesAvailable uint64 `json:"bytesAvailable"` + DataBytes uint64 `json:"dataBytes,omitempty"` + UsagePercent float64 `json:"usagePercent"` + DegradedRatio float64 `json:"degradedRatio,omitempty"` + MisplacedRatio float64 `json:"misplacedRatio,omitempty"` + ReadBytesPerSec uint64 `json:"readBytesPerSec,omitempty"` + WriteBytesPerSec uint64 `json:"writeBytesPerSec,omitempty"` + ReadOpsPerSec uint64 `json:"readOpsPerSec,omitempty"` + WriteOpsPerSec uint64 `json:"writeOpsPerSec,omitempty"` } // Pool represents a Ceph pool. @@ -106,16 +115,16 @@ type Pool struct { // ServiceInfo represents a Ceph service summary. type ServiceInfo struct { - Type string `json:"type"` // mon, mgr, osd, mds, rgw - Running int `json:"running"` - Total int `json:"total"` - Daemons []string `json:"daemons,omitempty"` + Type string `json:"type"` // mon, mgr, osd, mds, rgw + Running int `json:"running"` + Total int `json:"total"` + Daemons []string `json:"daemons,omitempty"` } // IsAvailable checks if the ceph CLI is available on the system. func IsAvailable(ctx context.Context) bool { - cmd := exec.CommandContext(ctx, "which", "ceph") - return cmd.Run() == nil + _, _, err := commandRunner(ctx, "which", "ceph") + return err == nil } // Collect gathers Ceph cluster status using the ceph CLI. @@ -160,17 +169,13 @@ func runCephCommand(ctx context.Context, args ...string) ([]byte, error) { cmdCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - cmd := exec.CommandContext(cmdCtx, "ceph", args...) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("ceph %s failed: %w (stderr: %s)", - strings.Join(args, " "), err, stderr.String()) + stdout, stderr, err := commandRunner(cmdCtx, "ceph", args...) + if err != nil { + return nil, fmt.Errorf("ceph %s failed: %w (stderr: %s)", + strings.Join(args, " "), err, string(stderr)) } - return stdout.Bytes(), nil + return stdout, nil } // parseStatus parses the output of `ceph status --format json`. 
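Because runCephCommand wraps the failure with %w, callers can still match the underlying exec error; a minimal sketch with synthetic error values (not the real exec.ExitError):

	package main

	import (
		"errors"
		"fmt"
	)

	func main() {
		// Stand-in for the error returned by the command runner.
		base := errors.New("exit status 1")
		wrapped := fmt.Errorf("ceph status failed: %w (stderr: %s)", base, "boom")
		// %w preserves the chain, so errors.Is can still find the cause.
		fmt.Println(errors.Is(wrapped, base)) // true
	}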
@@ -178,8 +183,8 @@ func parseStatus(data []byte) (*ClusterStatus, error) { var raw struct { FSID string `json:"fsid"` Health struct { - Status string `json:"status"` - Checks map[string]struct { + Status string `json:"status"` + Checks map[string]struct { Severity string `json:"severity"` Summary struct { Message string `json:"message"` @@ -198,10 +203,10 @@ func parseStatus(data []byte) (*ClusterStatus, error) { } `json:"mons"` } `json:"monmap"` MgrMap struct { - Available bool `json:"available"` - NumActive int `json:"num_active_name,omitempty"` + Available bool `json:"available"` + NumActive int `json:"num_active_name,omitempty"` ActiveName string `json:"active_name"` - Standbys []struct { + Standbys []struct { Name string `json:"name"` } `json:"standbys"` } `json:"mgrmap"` @@ -212,17 +217,17 @@ func parseStatus(data []byte) (*ClusterStatus, error) { NumIn int `json:"num_in_osds"` } `json:"osdmap"` PGMap struct { - NumPGs int `json:"num_pgs"` - BytesTotal uint64 `json:"bytes_total"` - BytesUsed uint64 `json:"bytes_used"` - BytesAvail uint64 `json:"bytes_avail"` - DataBytes uint64 `json:"data_bytes"` - DegradedRatio float64 `json:"degraded_ratio"` - MisplacedRatio float64 `json:"misplaced_ratio"` - ReadBytesPerSec uint64 `json:"read_bytes_sec"` - WriteBytesPerSec uint64 `json:"write_bytes_sec"` - ReadOpsPerSec uint64 `json:"read_op_per_sec"` - WriteOpsPerSec uint64 `json:"write_op_per_sec"` + NumPGs int `json:"num_pgs"` + BytesTotal uint64 `json:"bytes_total"` + BytesUsed uint64 `json:"bytes_used"` + BytesAvail uint64 `json:"bytes_avail"` + DataBytes uint64 `json:"data_bytes"` + DegradedRatio float64 `json:"degraded_ratio"` + MisplacedRatio float64 `json:"misplaced_ratio"` + ReadBytesPerSec uint64 `json:"read_bytes_sec"` + WriteBytesPerSec uint64 `json:"write_bytes_sec"` + ReadOpsPerSec uint64 `json:"read_op_per_sec"` + WriteOpsPerSec uint64 `json:"write_op_per_sec"` } `json:"pgmap"` } @@ -255,13 +260,13 @@ func parseStatus(data []byte) (*ClusterStatus, error) { NumOut: raw.OSDMap.NumOSD - raw.OSDMap.NumIn, }, PGMap: PGMap{ - NumPGs: raw.PGMap.NumPGs, - BytesTotal: raw.PGMap.BytesTotal, - BytesUsed: raw.PGMap.BytesUsed, - BytesAvailable: raw.PGMap.BytesAvail, - DataBytes: raw.PGMap.DataBytes, - DegradedRatio: raw.PGMap.DegradedRatio, - MisplacedRatio: raw.PGMap.MisplacedRatio, + NumPGs: raw.PGMap.NumPGs, + BytesTotal: raw.PGMap.BytesTotal, + BytesUsed: raw.PGMap.BytesUsed, + BytesAvailable: raw.PGMap.BytesAvail, + DataBytes: raw.PGMap.DataBytes, + DegradedRatio: raw.PGMap.DegradedRatio, + MisplacedRatio: raw.PGMap.MisplacedRatio, ReadBytesPerSec: raw.PGMap.ReadBytesPerSec, WriteBytesPerSec: raw.PGMap.WriteBytesPerSec, ReadOpsPerSec: raw.PGMap.ReadOpsPerSec, diff --git a/internal/ceph/collector_test.go b/internal/ceph/collector_test.go index 9c9485e0e..5afb8d0b2 100644 --- a/internal/ceph/collector_test.go +++ b/internal/ceph/collector_test.go @@ -1,9 +1,32 @@ package ceph import ( + "context" + "errors" + "strings" "testing" ) +func withCommandRunner(t *testing.T, fn func(ctx context.Context, name string, args ...string) ([]byte, []byte, error)) { + t.Helper() + orig := commandRunner + commandRunner = fn + t.Cleanup(func() { commandRunner = orig }) +} + +func TestCommandRunner_Default(t *testing.T) { + stdout, stderr, err := commandRunner(context.Background(), "sh", "-c", "true") + if err != nil { + t.Fatalf("commandRunner error: %v", err) + } + if len(stdout) != 0 { + t.Fatalf("unexpected stdout: %q", string(stdout)) + } + if len(stderr) != 0 { + t.Fatalf("unexpected stderr: %q", 
string(stderr)) + } +} + func TestParseStatus(t *testing.T) { data := []byte(`{ "fsid":"fsid-123", @@ -70,6 +93,68 @@ func TestParseStatus(t *testing.T) { } } +func TestIsAvailable(t *testing.T) { + t.Run("available", func(t *testing.T) { + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + if name != "which" || len(args) != 1 || args[0] != "ceph" { + t.Fatalf("unexpected command: %s %v", name, args) + } + return nil, nil, nil + }) + + if !IsAvailable(context.Background()) { + t.Fatalf("expected available") + } + }) + + t.Run("missing", func(t *testing.T) { + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + return nil, nil, errors.New("missing") + }) + + if IsAvailable(context.Background()) { + t.Fatalf("expected unavailable") + } + }) +} + +func TestRunCephCommand(t *testing.T) { + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + if name != "ceph" { + t.Fatalf("unexpected command name: %s", name) + } + if len(args) < 1 || args[0] != "status" { + t.Fatalf("unexpected args: %v", args) + } + return []byte(`{"ok":true}`), nil, nil + }) + + out, err := runCephCommand(context.Background(), "status", "--format", "json") + if err != nil { + t.Fatalf("runCephCommand error: %v", err) + } + if string(out) != `{"ok":true}` { + t.Fatalf("unexpected output: %s", string(out)) + } +} + +func TestRunCephCommandError(t *testing.T) { + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + return nil, []byte("bad"), errors.New("boom") + }) + + _, err := runCephCommand(context.Background(), "status", "--format", "json") + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "ceph status --format json failed") { + t.Fatalf("unexpected error message: %v", err) + } + if !strings.Contains(err.Error(), "stderr: bad") { + t.Fatalf("expected stderr in error: %v", err) + } +} + func TestParseStatusInvalidJSON(t *testing.T) { _, err := parseStatus([]byte(`{not-json}`)) if err == nil { @@ -104,6 +189,181 @@ func TestParseDFInvalidJSON(t *testing.T) { } } +func TestCollect_NotAvailable(t *testing.T) { + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + if name == "which" { + return nil, nil, errors.New("missing") + } + t.Fatalf("unexpected command: %s %v", name, args) + return nil, nil, nil + }) + + status, err := Collect(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if status != nil { + t.Fatalf("expected nil status") + } +} + +func TestCollect_StatusError(t *testing.T) { + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + if name == "which" { + return nil, nil, nil + } + if name == "ceph" && len(args) > 0 && args[0] == "status" { + return nil, []byte("boom"), errors.New("status failed") + } + t.Fatalf("unexpected command: %s %v", name, args) + return nil, nil, nil + }) + + status, err := Collect(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if status != nil { + t.Fatalf("expected nil status") + } +} + +func TestCollect_ParseStatusError(t *testing.T) { + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + if name == "which" { + return nil, nil, nil + } + if name == "ceph" && len(args) > 0 && args[0] == "status" { + return []byte(`{not-json}`), 
nil, nil + } + t.Fatalf("unexpected command: %s %v", name, args) + return nil, nil, nil + }) + + _, err := Collect(context.Background()) + if err == nil { + t.Fatalf("expected parse error") + } +} + +func TestCollect_DFCommandError(t *testing.T) { + statusJSON := []byte(`{ + "fsid":"fsid-1", + "health":{"status":"HEALTH_OK","checks":{}}, + "monmap":{"epoch":1,"mons":[]}, + "mgrmap":{"available":false,"active_name":"","standbys":[]}, + "osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0}, + "pgmap":{"num_pgs":0,"bytes_total":100,"bytes_used":50,"bytes_avail":50, + "data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0, + "read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0} + }`) + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + if name == "which" { + return nil, nil, nil + } + if name == "ceph" && len(args) > 0 && args[0] == "status" { + return statusJSON, nil, nil + } + if name == "ceph" && len(args) > 0 && args[0] == "df" { + return nil, []byte("df failed"), errors.New("df error") + } + t.Fatalf("unexpected command: %s %v", name, args) + return nil, nil, nil + }) + + status, err := Collect(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if status == nil { + t.Fatalf("expected status") + } + if status.PGMap.UsagePercent == 0 { + t.Fatalf("expected usage percent from status") + } +} + +func TestCollect_DFParseError(t *testing.T) { + statusJSON := []byte(`{ + "fsid":"fsid-1", + "health":{"status":"HEALTH_OK","checks":{}}, + "monmap":{"epoch":1,"mons":[]}, + "mgrmap":{"available":false,"active_name":"","standbys":[]}, + "osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0}, + "pgmap":{"num_pgs":0,"bytes_total":0,"bytes_used":0,"bytes_avail":0, + "data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0, + "read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0} + }`) + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + if name == "which" { + return nil, nil, nil + } + if name == "ceph" && len(args) > 0 && args[0] == "status" { + return statusJSON, nil, nil + } + if name == "ceph" && len(args) > 0 && args[0] == "df" { + return []byte(`{not-json}`), nil, nil + } + t.Fatalf("unexpected command: %s %v", name, args) + return nil, nil, nil + }) + + status, err := Collect(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if status == nil { + t.Fatalf("expected status") + } + if status.PGMap.UsagePercent != 0 { + t.Fatalf("expected usage percent to remain 0, got %v", status.PGMap.UsagePercent) + } +} + +func TestCollect_UsagePercentFromDF(t *testing.T) { + statusJSON := []byte(`{ + "fsid":"fsid-1", + "health":{"status":"HEALTH_OK","checks":{}}, + "monmap":{"epoch":1,"mons":[]}, + "mgrmap":{"available":false,"active_name":"","standbys":[]}, + "osdmap":{"epoch":1,"num_osds":0,"num_up_osds":0,"num_in_osds":0}, + "pgmap":{"num_pgs":0,"bytes_total":0,"bytes_used":0,"bytes_avail":0, + "data_bytes":0,"degraded_ratio":0,"misplaced_ratio":0, + "read_bytes_sec":0,"write_bytes_sec":0,"read_op_per_sec":0,"write_op_per_sec":0} + }`) + dfJSON := []byte(`{ + "stats":{"total_bytes":1000,"total_used_bytes":500,"percent_used":0.5}, + "pools":[] + }`) + withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { + if name == "which" { + return nil, nil, nil + } + if name == "ceph" && len(args) > 0 && args[0] == "status" { + 
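+			// statusJSON reports zero pgmap totals, so Collect should take
+			// UsagePercent from the df output below instead.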
return statusJSON, nil, nil + } + if name == "ceph" && len(args) > 0 && args[0] == "df" { + return dfJSON, nil, nil + } + t.Fatalf("unexpected command: %s %v", name, args) + return nil, nil, nil + }) + + status, err := Collect(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if status == nil { + t.Fatalf("expected status") + } + if status.PGMap.UsagePercent != 50 { + t.Fatalf("expected usage percent from df, got %v", status.PGMap.UsagePercent) + } + if status.CollectedAt.IsZero() { + t.Fatalf("expected collected timestamp set") + } +} + func TestBoolToInt(t *testing.T) { if boolToInt(true) != 1 { t.Fatalf("expected boolToInt(true)=1") diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index 3917dcd96..c9a2880a5 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -14,6 +14,16 @@ import ( "github.com/rs/zerolog/log" ) +var defaultDataDirFn = utils.GetDataDir + +var legacyKeyPath = "/etc/pulse/.encryption.key" + +var randReader = rand.Reader + +var newCipher = aes.NewCipher + +var newGCM = cipher.NewGCM + // CryptoManager handles encryption/decryption of sensitive data type CryptoManager struct { key []byte @@ -23,7 +33,7 @@ type CryptoManager struct { // NewCryptoManagerAt creates a new crypto manager with an explicit data directory override. func NewCryptoManagerAt(dataDir string) (*CryptoManager, error) { if dataDir == "" { - dataDir = utils.GetDataDir() + dataDir = defaultDataDirFn() } keyPath := filepath.Join(dataDir, ".encryption.key") @@ -41,11 +51,12 @@ func NewCryptoManagerAt(dataDir string) (*CryptoManager, error) { // getOrCreateKeyAt gets the encryption key or creates one if it doesn't exist func getOrCreateKeyAt(dataDir string) ([]byte, error) { if dataDir == "" { - dataDir = utils.GetDataDir() + dataDir = defaultDataDirFn() } keyPath := filepath.Join(dataDir, ".encryption.key") - oldKeyPath := "/etc/pulse/.encryption.key" + oldKeyPath := legacyKeyPath + oldKeyDir := filepath.Dir(oldKeyPath) log.Debug(). Str("dataDir", dataDir). @@ -78,7 +89,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) { // Check for key in old location and migrate if found (only if paths differ) // CRITICAL: This code deletes the encryption key at oldKeyPath after migrating it. // Adding extensive logging to diagnose recurring key deletion bug. - if dataDir != "/etc/pulse" && keyPath != oldKeyPath { + if dataDir != oldKeyDir && keyPath != oldKeyPath { log.Warn(). Str("dataDir", dataDir). Str("keyPath", keyPath). @@ -135,7 +146,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) { log.Debug(). Str("dataDir", dataDir). Str("keyPath", keyPath). - Bool("sameAsOldPath", dataDir == "/etc/pulse"). + Bool("sameAsOldPath", dataDir == oldKeyDir). 
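+			// oldKeyDir is derived from legacyKeyPath, so tests can point the
+			// legacy location somewhere other than /etc/pulse.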
Msg("Skipping key migration check (dataDir is /etc/pulse or paths match)") } @@ -171,7 +182,7 @@ func getOrCreateKeyAt(dataDir string) ([]byte, error) { // Generate new key (only if no encrypted data exists) key := make([]byte, 32) // AES-256 - if _, err := io.ReadFull(rand.Reader, key); err != nil { + if _, err := io.ReadFull(randReader, key); err != nil { return nil, fmt.Errorf("failed to generate key: %w", err) } @@ -205,18 +216,18 @@ func (c *CryptoManager) Encrypt(plaintext []byte) ([]byte, error) { } } - block, err := aes.NewCipher(c.key) + block, err := newCipher(c.key) if err != nil { return nil, err } - gcm, err := cipher.NewGCM(block) + gcm, err := newGCM(block) if err != nil { return nil, err } nonce := make([]byte, gcm.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + if _, err := io.ReadFull(randReader, nonce); err != nil { return nil, err } @@ -226,12 +237,12 @@ func (c *CryptoManager) Encrypt(plaintext []byte) ([]byte, error) { // Decrypt decrypts data using AES-GCM func (c *CryptoManager) Decrypt(ciphertext []byte) ([]byte, error) { - block, err := aes.NewCipher(c.key) + block, err := newCipher(c.key) if err != nil { return nil, err } - gcm, err := cipher.NewGCM(block) + gcm, err := newGCM(block) if err != nil { return nil, err } diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go index d32c091f4..047c8ad8f 100644 --- a/internal/crypto/crypto_test.go +++ b/internal/crypto/crypto_test.go @@ -2,11 +2,51 @@ package crypto import ( "bytes" + "crypto/cipher" + "encoding/base64" + "errors" + "io" "os" "path/filepath" "testing" ) +type errReader struct { + err error +} + +func (e errReader) Read(p []byte) (int, error) { + return 0, e.err +} + +func withDefaultDataDir(t *testing.T, dir string) { + t.Helper() + orig := defaultDataDirFn + defaultDataDirFn = func() string { return dir } + t.Cleanup(func() { defaultDataDirFn = orig }) +} + +func withLegacyKeyPath(t *testing.T, path string) { + t.Helper() + orig := legacyKeyPath + legacyKeyPath = path + t.Cleanup(func() { legacyKeyPath = orig }) +} + +func withRandReader(t *testing.T, r io.Reader) { + t.Helper() + orig := randReader + randReader = r + t.Cleanup(func() { randReader = orig }) +} + +func withNewGCM(t *testing.T, fn func(cipher.Block) (cipher.AEAD, error)) { + t.Helper() + orig := newGCM + newGCM = fn + t.Cleanup(func() { newGCM = orig }) +} + func TestEncryptDecrypt(t *testing.T) { // Create a temp directory for the test tmpDir := t.TempDir() @@ -149,6 +189,33 @@ func TestEncryptionKeyFilePermissions(t *testing.T) { } } +func TestNewCryptoManagerAt_DefaultDataDir(t *testing.T) { + tmpDir := t.TempDir() + withDefaultDataDir(t, tmpDir) + + cm, err := NewCryptoManagerAt("") + if err != nil { + t.Fatalf("NewCryptoManagerAt() error: %v", err) + } + if cm.keyPath != filepath.Join(tmpDir, ".encryption.key") { + t.Fatalf("keyPath = %q, want %q", cm.keyPath, filepath.Join(tmpDir, ".encryption.key")) + } +} + +func TestNewCryptoManagerAt_KeyError(t *testing.T) { + tmpDir := t.TempDir() + withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key")) + err := os.WriteFile(filepath.Join(tmpDir, "nodes.enc"), []byte("data"), 0600) + if err != nil { + t.Fatalf("Failed to create encrypted file: %v", err) + } + + _, err = NewCryptoManagerAt(tmpDir) + if err == nil { + t.Fatal("Expected error when encrypted data exists without a key") + } +} + func TestDecryptInvalidData(t *testing.T) { tmpDir := t.TempDir() cm, err := NewCryptoManagerAt(tmpDir) @@ -175,6 +242,217 @@ func 
TestDecryptInvalidData(t *testing.T) { } } +func TestGetOrCreateKeyAt_InvalidBase64(t *testing.T) { + tmpDir := t.TempDir() + withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key")) + + keyPath := filepath.Join(tmpDir, ".encryption.key") + if err := os.WriteFile(keyPath, []byte("not-base64"), 0600); err != nil { + t.Fatalf("Failed to write key file: %v", err) + } + + key, err := getOrCreateKeyAt(tmpDir) + if err != nil { + t.Fatalf("getOrCreateKeyAt() error: %v", err) + } + if len(key) != 32 { + t.Fatalf("expected 32-byte key, got %d", len(key)) + } +} + +func TestGetOrCreateKeyAt_DefaultDataDir(t *testing.T) { + tmpDir := t.TempDir() + withDefaultDataDir(t, tmpDir) + withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key")) + + key, err := getOrCreateKeyAt("") + if err != nil { + t.Fatalf("getOrCreateKeyAt() error: %v", err) + } + if len(key) != 32 { + t.Fatalf("expected 32-byte key, got %d", len(key)) + } +} + +func TestGetOrCreateKeyAt_InvalidLength(t *testing.T) { + tmpDir := t.TempDir() + withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key")) + + shortKey := make([]byte, 16) + for i := range shortKey { + shortKey[i] = byte(i) + } + encoded := base64.StdEncoding.EncodeToString(shortKey) + if err := os.WriteFile(filepath.Join(tmpDir, ".encryption.key"), []byte(encoded), 0600); err != nil { + t.Fatalf("Failed to write key file: %v", err) + } + + key, err := getOrCreateKeyAt(tmpDir) + if err != nil { + t.Fatalf("getOrCreateKeyAt() error: %v", err) + } + if len(key) != 32 { + t.Fatalf("expected 32-byte key, got %d", len(key)) + } +} + +func TestGetOrCreateKeyAt_SkipMigrationWhenPathsMatch(t *testing.T) { + tmpDir := t.TempDir() + withLegacyKeyPath(t, filepath.Join(tmpDir, ".encryption.key")) + + key, err := getOrCreateKeyAt(tmpDir) + if err != nil { + t.Fatalf("getOrCreateKeyAt() error: %v", err) + } + if len(key) != 32 { + t.Fatalf("expected 32-byte key, got %d", len(key)) + } +} + +func TestGetOrCreateKeyAt_MigrateSuccess(t *testing.T) { + legacyDir := t.TempDir() + legacyPath := filepath.Join(legacyDir, ".encryption.key") + withLegacyKeyPath(t, legacyPath) + + oldKey := make([]byte, 32) + for i := range oldKey { + oldKey[i] = byte(i) + } + encoded := base64.StdEncoding.EncodeToString(oldKey) + if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil { + t.Fatalf("Failed to write legacy key: %v", err) + } + + newDir := t.TempDir() + key, err := getOrCreateKeyAt(newDir) + if err != nil { + t.Fatalf("getOrCreateKeyAt() error: %v", err) + } + if !bytes.Equal(key, oldKey) { + t.Fatalf("migrated key mismatch") + } + contents, err := os.ReadFile(filepath.Join(newDir, ".encryption.key")) + if err != nil { + t.Fatalf("Failed to read migrated key: %v", err) + } + if string(contents) != encoded { + t.Fatalf("migrated key contents mismatch") + } +} + +func TestGetOrCreateKeyAt_MigrateMkdirError(t *testing.T) { + legacyDir := t.TempDir() + legacyPath := filepath.Join(legacyDir, ".encryption.key") + withLegacyKeyPath(t, legacyPath) + + oldKey := make([]byte, 32) + for i := range oldKey { + oldKey[i] = byte(i) + } + encoded := base64.StdEncoding.EncodeToString(oldKey) + if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil { + t.Fatalf("Failed to write legacy key: %v", err) + } + + tmpDir := t.TempDir() + dataFile := filepath.Join(tmpDir, "datafile") + if err := os.WriteFile(dataFile, []byte("x"), 0600); err != nil { + t.Fatalf("Failed to write data file: %v", err) + } + + key, err := getOrCreateKeyAt(dataFile) + if err != nil { + 
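+		// Migration cannot create the target dir (dataFile is a regular
+		// file), but that must not surface as an error here.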
t.Fatalf("getOrCreateKeyAt() error: %v", err) + } + if !bytes.Equal(key, oldKey) { + t.Fatalf("expected legacy key on mkdir error") + } +} + +func TestGetOrCreateKeyAt_MigrateWriteError(t *testing.T) { + legacyDir := t.TempDir() + legacyPath := filepath.Join(legacyDir, ".encryption.key") + withLegacyKeyPath(t, legacyPath) + + oldKey := make([]byte, 32) + for i := range oldKey { + oldKey[i] = byte(i) + } + encoded := base64.StdEncoding.EncodeToString(oldKey) + if err := os.WriteFile(legacyPath, []byte(encoded), 0600); err != nil { + t.Fatalf("Failed to write legacy key: %v", err) + } + + newDir := t.TempDir() + keyPath := filepath.Join(newDir, ".encryption.key") + if err := os.MkdirAll(keyPath, 0700); err != nil { + t.Fatalf("Failed to create key path dir: %v", err) + } + + key, err := getOrCreateKeyAt(newDir) + if err != nil { + t.Fatalf("getOrCreateKeyAt() error: %v", err) + } + if !bytes.Equal(key, oldKey) { + t.Fatalf("expected legacy key on write error") + } +} + +func TestGetOrCreateKeyAt_EncryptedDataExists(t *testing.T) { + tmpDir := t.TempDir() + withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key")) + + if err := os.WriteFile(filepath.Join(tmpDir, "nodes.enc"), []byte("data"), 0600); err != nil { + t.Fatalf("Failed to write encrypted file: %v", err) + } + + _, err := getOrCreateKeyAt(tmpDir) + if err == nil { + t.Fatal("Expected error when encrypted data exists") + } +} + +func TestGetOrCreateKeyAt_RandReaderError(t *testing.T) { + tmpDir := t.TempDir() + withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key")) + withRandReader(t, errReader{err: errors.New("read failed")}) + + _, err := getOrCreateKeyAt(tmpDir) + if err == nil { + t.Fatal("Expected error from rand reader") + } +} + +func TestGetOrCreateKeyAt_CreateDirError(t *testing.T) { + tmpDir := t.TempDir() + withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key")) + + dataFile := filepath.Join(tmpDir, "datafile") + if err := os.WriteFile(dataFile, []byte("x"), 0600); err != nil { + t.Fatalf("Failed to write data file: %v", err) + } + + _, err := getOrCreateKeyAt(dataFile) + if err == nil { + t.Fatal("Expected error when creating directory") + } +} + +func TestGetOrCreateKeyAt_SaveKeyError(t *testing.T) { + tmpDir := t.TempDir() + withLegacyKeyPath(t, filepath.Join(t.TempDir(), ".encryption.key")) + + keyPath := filepath.Join(tmpDir, ".encryption.key") + if err := os.MkdirAll(keyPath, 0700); err != nil { + t.Fatalf("Failed to create key path dir: %v", err) + } + + _, err := getOrCreateKeyAt(tmpDir) + if err == nil { + t.Fatal("Expected error when saving key") + } +} + func TestDecryptStringInvalidBase64(t *testing.T) { tmpDir := t.TempDir() cm, err := NewCryptoManagerAt(tmpDir) @@ -216,6 +494,57 @@ func TestEncryptionUniqueness(t *testing.T) { } } +func TestEncryptInvalidKey(t *testing.T) { + cm := &CryptoManager{key: []byte("short")} + if _, err := cm.Encrypt([]byte("data")); err == nil { + t.Fatal("Expected error for invalid key length") + } +} + +func TestDecryptInvalidKey(t *testing.T) { + cm := &CryptoManager{key: []byte("short")} + if _, err := cm.Decrypt([]byte("data")); err == nil { + t.Fatal("Expected error for invalid key length") + } +} + +func TestEncryptNonceReadError(t *testing.T) { + withRandReader(t, errReader{err: errors.New("nonce read error")}) + cm := &CryptoManager{key: make([]byte, 32)} + if _, err := cm.Encrypt([]byte("data")); err == nil { + t.Fatal("Expected error reading nonce") + } +} + +func TestEncryptDecryptGCMError(t *testing.T) { + withNewGCM(t, 
func(cipher.Block) (cipher.AEAD, error) { + return nil, errors.New("gcm error") + }) + + cm := &CryptoManager{key: make([]byte, 32)} + if _, err := cm.Encrypt([]byte("data")); err == nil { + t.Fatal("Expected Encrypt error from GCM") + } + if _, err := cm.Decrypt([]byte("data")); err == nil { + t.Fatal("Expected Decrypt error from GCM") + } +} + +func TestEncryptStringError(t *testing.T) { + cm := &CryptoManager{key: []byte("short")} + if _, err := cm.EncryptString("data"); err == nil { + t.Fatal("Expected EncryptString error") + } +} + +func TestDecryptStringError(t *testing.T) { + cm := &CryptoManager{key: make([]byte, 32)} + encoded := base64.StdEncoding.EncodeToString([]byte("short")) + if _, err := cm.DecryptString(encoded); err == nil { + t.Fatal("Expected DecryptString error") + } +} + func TestNewCryptoManagerRefusesOrphanedData(t *testing.T) { // Skip if production key exists - migration code will always find and use it if _, err := os.Stat("/etc/pulse/.encryption.key"); err == nil { diff --git a/internal/logging/logging.go b/internal/logging/logging.go index f52715622..60428b1d0 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -44,6 +44,28 @@ var ( defaultTimeFmt = time.RFC3339 ) +var ( + nowFn = time.Now + isTerminalFn = term.IsTerminal + mkdirAllFn = os.MkdirAll + openFileFn = os.OpenFile + openFn = os.Open + statFn = os.Stat + readDirFn = os.ReadDir + renameFn = os.Rename + removeFn = os.Remove + copyFn = io.Copy + gzipNewWriterFn = gzip.NewWriter + statFileFn = func(file *os.File) (os.FileInfo, error) { return file.Stat() } + closeFileFn = func(file *os.File) error { return file.Close() } + compressFn = compressAndRemove +) + +var ( + defaultStatFileFn = statFileFn + defaultCloseFileFn = closeFileFn +) + func init() { baseLogger = zerolog.New(baseWriter).With().Timestamp().Logger() log.Logger = baseLogger @@ -137,7 +159,7 @@ func isTerminal(file *os.File) bool { if file == nil { return false } - return term.IsTerminal(int(file.Fd())) + return isTerminalFn(int(file.Fd())) } type rollingFileWriter struct { @@ -157,7 +179,7 @@ func newRollingFileWriter(cfg Config) (io.Writer, error) { } dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0o755); err != nil { + if err := mkdirAllFn(dir, 0o755); err != nil { return nil, fmt.Errorf("create log directory: %w", err) } @@ -205,13 +227,13 @@ func (w *rollingFileWriter) openOrCreateLocked() error { return nil } - file, err := os.OpenFile(w.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600) + file, err := openFileFn(w.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600) if err != nil { return fmt.Errorf("open log file: %w", err) } w.file = file - info, err := file.Stat() + info, err := statFileFn(file) if err != nil { w.currentSize = 0 return nil @@ -225,11 +247,11 @@ func (w *rollingFileWriter) rotateLocked() error { return err } - if _, err := os.Stat(w.path); err == nil { - rotated := fmt.Sprintf("%s.%s", w.path, time.Now().Format("20060102-150405")) - if err := os.Rename(w.path, rotated); err == nil { + if _, err := statFn(w.path); err == nil { + rotated := fmt.Sprintf("%s.%s", w.path, nowFn().Format("20060102-150405")) + if err := renameFn(w.path, rotated); err == nil { if w.compress { - go compressAndRemove(rotated) + go compressFn(rotated) } } } @@ -242,7 +264,7 @@ func (w *rollingFileWriter) closeLocked() error { if w.file == nil { return nil } - err := w.file.Close() + err := closeFileFn(w.file) w.file = nil w.currentSize = 0 return err @@ -256,9 +278,9 @@ func (w *rollingFileWriter) cleanupOldFiles() 
{ dir := filepath.Dir(w.path) base := filepath.Base(w.path) prefix := base + "." - cutoff := time.Now().Add(-w.maxAge) + cutoff := nowFn().Add(-w.maxAge) - entries, err := os.ReadDir(dir) + entries, err := readDirFn(dir) if err != nil { return } @@ -273,26 +295,26 @@ func (w *rollingFileWriter) cleanupOldFiles() { continue } if info.ModTime().Before(cutoff) { - _ = os.Remove(filepath.Join(dir, name)) + _ = removeFn(filepath.Join(dir, name)) } } } func compressAndRemove(path string) { - in, err := os.Open(path) + in, err := openFn(path) if err != nil { return } defer in.Close() outPath := path + ".gz" - out, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + out, err := openFileFn(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) if err != nil { return } - gw := gzip.NewWriter(out) - if _, err = io.Copy(gw, in); err != nil { + gw := gzipNewWriterFn(out) + if _, err = copyFn(gw, in); err != nil { gw.Close() out.Close() return @@ -302,5 +324,5 @@ func compressAndRemove(path string) { return } out.Close() - _ = os.Remove(path) + _ = removeFn(path) } diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go index 09d4aff43..832ee69a1 100644 --- a/internal/logging/logging_test.go +++ b/internal/logging/logging_test.go @@ -2,13 +2,19 @@ package logging import ( "bytes" + "compress/gzip" + "errors" + "io" "os" + "path/filepath" "reflect" "sync" "testing" + "time" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "golang.org/x/term" ) func resetLoggingState() { @@ -21,6 +27,20 @@ func resetLoggingState() { log.Logger = baseLogger zerolog.TimeFieldFormat = defaultTimeFmt zerolog.SetGlobalLevel(zerolog.InfoLevel) + nowFn = time.Now + isTerminalFn = term.IsTerminal + mkdirAllFn = os.MkdirAll + openFileFn = os.OpenFile + openFn = os.Open + statFn = os.Stat + readDirFn = os.ReadDir + renameFn = os.Rename + removeFn = os.Remove + copyFn = io.Copy + gzipNewWriterFn = gzip.NewWriter + statFileFn = defaultStatFileFn + closeFileFn = defaultCloseFileFn + compressFn = compressAndRemove } func TestInitJSONFormatSetsLevelAndComponent(t *testing.T) { @@ -267,6 +287,393 @@ func TestSelectWriter(t *testing.T) { } } +func TestSelectWriterAutoTerminal(t *testing.T) { + t.Cleanup(resetLoggingState) + isTerminalFn = func(int) bool { return true } + + w := selectWriter("auto") + if _, ok := w.(zerolog.ConsoleWriter); !ok { + t.Fatalf("expected console writer, got %#v", w) + } +} + +func TestSelectWriterDefault(t *testing.T) { + t.Cleanup(resetLoggingState) + + w := selectWriter("unknown") + if w != os.Stderr { + t.Fatalf("expected default writer to be os.Stderr, got %#v", w) + } +} + +func TestIsTerminalNil(t *testing.T) { + t.Cleanup(resetLoggingState) + + if isTerminal(nil) { + t.Fatal("expected nil file to report false") + } +} + +func TestNewRollingFileWriter_EmptyPath(t *testing.T) { + t.Cleanup(resetLoggingState) + + writer, err := newRollingFileWriter(Config{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if writer != nil { + t.Fatalf("expected nil writer, got %#v", writer) + } +} + +func TestNewRollingFileWriter_MkdirError(t *testing.T) { + t.Cleanup(resetLoggingState) + + mkdirAllFn = func(string, os.FileMode) error { + return errors.New("mkdir failed") + } + + _, err := newRollingFileWriter(Config{FilePath: "/tmp/logs/test.log"}) + if err == nil { + t.Fatal("expected error from mkdir") + } +} + +func TestInitFileWriterError(t *testing.T) { + t.Cleanup(resetLoggingState) + + mkdirAllFn = func(string, os.FileMode) error { + return 
errors.New("mkdir failed") + } + + Init(Config{ + Format: "json", + FilePath: "/tmp/logs/test.log", + }) +} + +func TestNewRollingFileWriter_DefaultMaxSize(t *testing.T) { + t.Cleanup(resetLoggingState) + + dir := t.TempDir() + writer, err := newRollingFileWriter(Config{ + FilePath: filepath.Join(dir, "app.log"), + MaxSizeMB: 0, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + w, ok := writer.(*rollingFileWriter) + if !ok { + t.Fatalf("expected rollingFileWriter, got %#v", writer) + } + if w.maxBytes != 100*1024*1024 { + t.Fatalf("expected default max bytes, got %d", w.maxBytes) + } + _ = w.closeLocked() +} + +func TestNewRollingFileWriter_OpenError(t *testing.T) { + t.Cleanup(resetLoggingState) + + openFileFn = func(string, int, os.FileMode) (*os.File, error) { + return nil, errors.New("open failed") + } + + _, err := newRollingFileWriter(Config{FilePath: filepath.Join(t.TempDir(), "app.log")}) + if err == nil { + t.Fatal("expected error from openOrCreateLocked") + } +} + +func TestOpenOrCreateLocked_StatError(t *testing.T) { + t.Cleanup(resetLoggingState) + + dir := t.TempDir() + w := &rollingFileWriter{path: filepath.Join(dir, "app.log")} + statFileFn = func(*os.File) (os.FileInfo, error) { + return nil, errors.New("stat failed") + } + if err := w.openOrCreateLocked(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if w.currentSize != 0 { + t.Fatalf("expected current size 0, got %d", w.currentSize) + } + _ = w.closeLocked() +} + +func TestOpenOrCreateLocked_AlreadyOpen(t *testing.T) { + t.Cleanup(resetLoggingState) + + file, err := os.CreateTemp(t.TempDir(), "log") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + w := &rollingFileWriter{path: file.Name(), file: file} + if err := w.openOrCreateLocked(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + _ = w.closeLocked() +} + +func TestRollingFileWriter_WriteOpenError(t *testing.T) { + t.Cleanup(resetLoggingState) + + openFileFn = func(string, int, os.FileMode) (*os.File, error) { + return nil, errors.New("open failed") + } + w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log")} + if _, err := w.Write([]byte("data")); err == nil { + t.Fatal("expected write error") + } +} + +func TestRollingFileWriter_WriteRotateError(t *testing.T) { + t.Cleanup(resetLoggingState) + + dir := t.TempDir() + path := filepath.Join(dir, "app.log") + callCount := 0 + openFileFn = func(name string, flag int, perm os.FileMode) (*os.File, error) { + callCount++ + if callCount == 1 { + return os.OpenFile(name, flag, perm) + } + return nil, errors.New("open failed") + } + w := &rollingFileWriter{path: path, maxBytes: 1} + if _, err := w.Write([]byte("too big")); err == nil { + t.Fatal("expected rotate error") + } +} + +func TestRollingFileWriter_RotateCompress(t *testing.T) { + t.Cleanup(resetLoggingState) + + dir := t.TempDir() + path := filepath.Join(dir, "app.log") + if err := os.WriteFile(path, []byte("data"), 0600); err != nil { + t.Fatalf("failed to write file: %v", err) + } + + w := &rollingFileWriter{path: path, maxBytes: 1, compress: true} + if err := w.openOrCreateLocked(); err != nil { + t.Fatalf("openOrCreateLocked error: %v", err) + } + + ch := make(chan string, 1) + compressFn = func(p string) { ch <- p } + + if err := w.rotateLocked(); err != nil { + t.Fatalf("rotateLocked error: %v", err) + } + + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatal("expected compress to be triggered") + } + _ = w.closeLocked() +} + +func TestRotateLockedCloseError(t 
*testing.T) { + t.Cleanup(resetLoggingState) + + dir := t.TempDir() + path := filepath.Join(dir, "app.log") + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + t.Fatalf("failed to open file: %v", err) + } + w := &rollingFileWriter{path: path, file: file} + closeFileFn = func(*os.File) error { + return errors.New("close failed") + } + + if err := w.rotateLocked(); err == nil { + t.Fatal("expected close error") + } + _ = file.Close() +} + +func TestCloseLocked(t *testing.T) { + t.Cleanup(resetLoggingState) + + w := &rollingFileWriter{} + if err := w.closeLocked(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + file, err := os.CreateTemp(t.TempDir(), "log") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + w.file = file + if err := w.closeLocked(); err != nil { + t.Fatalf("unexpected close error: %v", err) + } + if w.file != nil { + t.Fatal("expected file to be cleared") + } + if w.currentSize != 0 { + t.Fatalf("expected size reset, got %d", w.currentSize) + } +} + +func TestCleanupOldFilesNoMaxAge(t *testing.T) { + t.Cleanup(resetLoggingState) + + w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: 0} + w.cleanupOldFiles() +} + +func TestCleanupOldFilesReadDirError(t *testing.T) { + t.Cleanup(resetLoggingState) + + readDirFn = func(string) ([]os.DirEntry, error) { + return nil, errors.New("read dir failed") + } + w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: time.Hour} + w.cleanupOldFiles() +} + +func TestCleanupOldFilesInfoError(t *testing.T) { + t.Cleanup(resetLoggingState) + + readDirFn = func(string) ([]os.DirEntry, error) { + return []os.DirEntry{errDirEntry{name: "app.log.20200101"}}, nil + } + w := &rollingFileWriter{path: filepath.Join(t.TempDir(), "app.log"), maxAge: time.Hour} + w.cleanupOldFiles() +} + +func TestCleanupOldFilesRemovesOld(t *testing.T) { + t.Cleanup(resetLoggingState) + + dir := t.TempDir() + path := filepath.Join(dir, "app.log") + oldFile := filepath.Join(dir, "app.log.20200101-000000") + newFile := filepath.Join(dir, "app.log.20250101-000000") + otherFile := filepath.Join(dir, "other.log.20200101") + + if err := os.WriteFile(oldFile, []byte("old"), 0600); err != nil { + t.Fatalf("failed to write old file: %v", err) + } + if err := os.WriteFile(newFile, []byte("new"), 0600); err != nil { + t.Fatalf("failed to write new file: %v", err) + } + if err := os.WriteFile(otherFile, []byte("other"), 0600); err != nil { + t.Fatalf("failed to write other file: %v", err) + } + + fixedNow := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC) + nowFn = func() time.Time { return fixedNow } + + if err := os.Chtimes(oldFile, fixedNow.Add(-48*time.Hour), fixedNow.Add(-48*time.Hour)); err != nil { + t.Fatalf("failed to set old file time: %v", err) + } + if err := os.Chtimes(newFile, fixedNow.Add(-time.Hour), fixedNow.Add(-time.Hour)); err != nil { + t.Fatalf("failed to set new file time: %v", err) + } + + w := &rollingFileWriter{path: path, maxAge: 24 * time.Hour} + w.cleanupOldFiles() + + if _, err := os.Stat(oldFile); !os.IsNotExist(err) { + t.Fatalf("expected old file to be removed") + } + if _, err := os.Stat(newFile); err != nil { + t.Fatalf("expected new file to remain: %v", err) + } + if _, err := os.Stat(otherFile); err != nil { + t.Fatalf("expected other file to remain: %v", err) + } +} + +func TestStatFileFnDefault(t *testing.T) { + t.Cleanup(resetLoggingState) + + file, err := os.CreateTemp(t.TempDir(), "log") + if err != nil { + t.Fatalf("failed to 
create temp file: %v", err)
+	}
+	t.Cleanup(func() { _ = file.Close() })
+
+	if _, err := statFileFn(file); err != nil {
+		t.Fatalf("statFileFn error: %v", err)
+	}
+}
+
+func TestCompressAndRemove(t *testing.T) {
+	t.Run("OpenError", func(t *testing.T) {
+		t.Cleanup(resetLoggingState)
+		openFn = func(string) (*os.File, error) {
+			return nil, errors.New("open failed")
+		}
+		compressAndRemove("/does/not/exist")
+	})
+
+	t.Run("OpenFileError", func(t *testing.T) {
+		t.Cleanup(resetLoggingState)
+		dir := t.TempDir()
+		path := filepath.Join(dir, "app.log")
+		if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
+			t.Fatalf("failed to write file: %v", err)
+		}
+		openFileFn = func(string, int, os.FileMode) (*os.File, error) {
+			return nil, errors.New("open file failed")
+		}
+		compressAndRemove(path)
+	})
+
+	t.Run("CopyError", func(t *testing.T) {
+		t.Cleanup(resetLoggingState)
+		dir := t.TempDir()
+		path := filepath.Join(dir, "app.log")
+		if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
+			t.Fatalf("failed to write file: %v", err)
+		}
+		copyFn = func(io.Writer, io.Reader) (int64, error) {
+			return 0, errors.New("copy failed")
+		}
+		compressAndRemove(path)
+	})
+
+	t.Run("CloseError", func(t *testing.T) {
+		t.Cleanup(resetLoggingState)
+		dir := t.TempDir()
+		path := filepath.Join(dir, "app.log")
+		if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
+			t.Fatalf("failed to write file: %v", err)
+		}
+		errWriter := errWriteCloser{err: errors.New("write failed")}
+		gzipNewWriterFn = func(io.Writer) *gzip.Writer {
+			return gzip.NewWriter(errWriter)
+		}
+		copyFn = func(io.Writer, io.Reader) (int64, error) {
+			return 0, nil
+		}
+		compressAndRemove(path)
+	})
+
+	t.Run("Success", func(t *testing.T) {
+		t.Cleanup(resetLoggingState)
+		dir := t.TempDir()
+		path := filepath.Join(dir, "app.log")
+		if err := os.WriteFile(path, []byte("data"), 0600); err != nil {
+			t.Fatalf("failed to write file: %v", err)
+		}
+		compressAndRemove(path)
+		if _, err := os.Stat(path); !os.IsNotExist(err) {
+			t.Fatal("expected original file to be removed")
+		}
+		if _, err := os.Stat(path + ".gz"); err != nil {
+			t.Fatalf("expected gzip file to exist: %v", err)
+		}
+	})
+}
+
 // Test that the logging package doesn't panic under concurrent use
 func TestConcurrentLogging(t *testing.T) {
 	t.Cleanup(resetLoggingState)
@@ -292,3 +699,24 @@ func TestConcurrentLogging(t *testing.T) {
 		t.Fatal("expected log output from concurrent logging")
 	}
 }
+
+type errDirEntry struct {
+	name string
+}
+
+func (e errDirEntry) Name() string { return e.name }
+func (e errDirEntry) IsDir() bool  { return false }
+func (e errDirEntry) Type() os.FileMode {
+	return 0
+}
+func (e errDirEntry) Info() (os.FileInfo, error) {
+	return nil, errors.New("info error")
+}
+
+type errWriteCloser struct {
+	err error
+}
+
+func (e errWriteCloser) Write(p []byte) (int, error) {
+	return 0, e.err
+}
diff --git a/internal/mdadm/mdadm.go b/internal/mdadm/mdadm.go
index 6d906500e..c74b8c813 100644
--- a/internal/mdadm/mdadm.go
+++ b/internal/mdadm/mdadm.go
@@ -14,9 +14,13 @@ import (
 // Pre-compiled regexes for performance (avoid recompilation on each call)
 var (
-	mdDeviceRe = regexp.MustCompile(`^(md\d+)\s*:`)
-	slotRe     = regexp.MustCompile(`^\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(.+?)\s+(/dev/.+)$`)
-	speedRe    = regexp.MustCompile(`speed=(\S+)`)
+	mdDeviceRe       = regexp.MustCompile(`^(md\d+)\s*:`)
+	slotRe           = regexp.MustCompile(`^\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(.+?)\s+(/dev/.+)$`)
+	speedRe          = regexp.MustCompile(`speed=(\S+)`)
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		cmd := exec.CommandContext(ctx, name, args...)
+		return cmd.Output()
+	}
 )
 
 // CollectArrays discovers and collects status for all mdadm RAID arrays on the system.
@@ -56,8 +60,8 @@ func isMdadmAvailable(ctx context.Context) bool {
 	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
 	defer cancel()
 
-	cmd := exec.CommandContext(ctx, "mdadm", "--version")
-	return cmd.Run() == nil
+	_, err := runCommandOutput(ctx, "mdadm", "--version")
+	return err == nil
 }
 
 // listArrayDevices scans /proc/mdstat to find all md devices
@@ -65,8 +69,7 @@ func listArrayDevices(ctx context.Context) ([]string, error) {
 	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
 	defer cancel()
 
-	cmd := exec.CommandContext(ctx, "cat", "/proc/mdstat")
-	output, err := cmd.Output()
+	output, err := runCommandOutput(ctx, "cat", "/proc/mdstat")
 	if err != nil {
 		return nil, fmt.Errorf("read /proc/mdstat: %w", err)
 	}
@@ -89,8 +92,7 @@ func collectArrayDetail(ctx context.Context, device string) (host.RAIDArray, err
 	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
 	defer cancel()
 
-	cmd := exec.CommandContext(ctx, "mdadm", "--detail", device)
-	output, err := cmd.Output()
+	output, err := runCommandOutput(ctx, "mdadm", "--detail", device)
 	if err != nil {
 		return host.RAIDArray{}, fmt.Errorf("mdadm --detail %s: %w", device, err)
 	}
@@ -220,8 +222,7 @@ func getRebuildSpeed(device string) string {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 	defer cancel()
 
-	cmd := exec.CommandContext(ctx, "cat", "/proc/mdstat")
-	output, err := cmd.Output()
+	output, err := runCommandOutput(ctx, "cat", "/proc/mdstat")
 	if err != nil {
 		return ""
 	}
diff --git a/internal/mdadm/mdadm_test.go b/internal/mdadm/mdadm_test.go
index c29e1f087..d8ef550ed 100644
--- a/internal/mdadm/mdadm_test.go
+++ b/internal/mdadm/mdadm_test.go
@@ -1,11 +1,20 @@
 package mdadm
 
 import (
+	"context"
+	"errors"
 	"testing"
 
 	"github.com/rcourtman/pulse-go-rewrite/pkg/agents/host"
 )
 
+func withRunCommandOutput(t *testing.T, fn func(ctx context.Context, name string, args ...string) ([]byte, error)) {
+	t.Helper()
+	orig := runCommandOutput
+	runCommandOutput = fn
+	t.Cleanup(func() { runCommandOutput = orig })
+}
+
 func TestParseDetail(t *testing.T) {
 	tests := []struct {
 		name string
@@ -540,3 +549,239 @@ Consistency Policy : resync
 		})
 	}
 }
+
+func TestIsMdadmAvailable(t *testing.T) {
+	t.Run("available", func(t *testing.T) {
+		withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+			return []byte("mdadm"), nil
+		})
+
+		if !isMdadmAvailable(context.Background()) {
+			t.Fatal("expected mdadm available")
+		}
+	})
+
+	t.Run("missing", func(t *testing.T) {
+		withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+			return nil, errors.New("missing")
+		})
+
+		if isMdadmAvailable(context.Background()) {
+			t.Fatal("expected mdadm unavailable")
+		}
+	})
+}
+
+func TestListArrayDevices(t *testing.T) {
+	mdstat := `Personalities : [raid1] [raid6]
+md0 : active raid1 sdb1[1] sda1[0]
+md1 : active raid6 sdc1[2] sdb1[1] sda1[0]
+unused devices: `
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return []byte(mdstat), nil
+	})
+
+	devices, err := listArrayDevices(context.Background())
+	if err != nil {
+		t.Fatalf("listArrayDevices error: %v", err)
+	}
+	if len(devices) != 2 || devices[0] != "/dev/md0" || devices[1] != "/dev/md1" {
+		t.Fatalf("unexpected devices: %v", devices)
+	}
+}
+
+func TestListArrayDevicesError(t *testing.T) {
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return nil, errors.New("read failed")
+	})
+
+	if _, err := listArrayDevices(context.Background()); err == nil {
+		t.Fatal("expected error")
+	}
+}
+
+func TestCollectArrayDetailError(t *testing.T) {
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return nil, errors.New("detail failed")
+	})
+
+	if _, err := collectArrayDetail(context.Background(), "/dev/md0"); err == nil {
+		t.Fatal("expected error")
+	}
+}
+
+func TestCollectArraysNotAvailable(t *testing.T) {
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return nil, errors.New("missing")
+	})
+
+	arrays, err := CollectArrays(context.Background())
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if arrays != nil {
+		t.Fatalf("expected nil arrays, got %v", arrays)
+	}
+}
+
+func TestCollectArraysListError(t *testing.T) {
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		if name == "mdadm" {
+			return []byte("mdadm"), nil
+		}
+		return nil, errors.New("read failed")
+	})
+
+	if _, err := CollectArrays(context.Background()); err == nil {
+		t.Fatal("expected error")
+	}
+}
+
+func TestCollectArraysNoDevices(t *testing.T) {
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		if name == "mdadm" {
+			return []byte("mdadm"), nil
+		}
+		return []byte("unused devices: "), nil
+	})
+
+	arrays, err := CollectArrays(context.Background())
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if arrays != nil {
+		t.Fatalf("expected nil arrays, got %v", arrays)
+	}
+}
+
+func TestCollectArraysSkipsDetailError(t *testing.T) {
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		switch name {
+		case "mdadm":
+			if len(args) > 0 && args[0] == "--version" {
+				return []byte("mdadm"), nil
+			}
+			return nil, errors.New("detail failed")
+		case "cat":
+			return []byte("md0 : active raid1 sda1[0]"), nil
+		default:
+			return nil, errors.New("unexpected")
+		}
+	})
+
+	arrays, err := CollectArrays(context.Background())
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if len(arrays) != 0 {
+		t.Fatalf("expected empty arrays, got %v", arrays)
+	}
+}
+
+func TestCollectArraysSuccess(t *testing.T) {
+	detail := `/dev/md0:
+ Raid Level : raid1
+ State : clean
+ Total Devices : 2
+ Active Devices : 2
+ Working Devices : 2
+ Failed Devices : 0
+ Spare Devices : 0
+
+ Number Major Minor RaidDevice State
+ 0 8 1 0 active sync /dev/sda1`
+
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		switch name {
+		case "mdadm":
+			if len(args) > 0 && args[0] == "--version" {
+				return []byte("mdadm"), nil
+			}
+			return []byte(detail), nil
+		case "cat":
+			return []byte("md0 : active raid1 sda1[0]"), nil
+		default:
+			return nil, errors.New("unexpected")
+		}
+	})
+
+	arrays, err := CollectArrays(context.Background())
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if len(arrays) != 1 || arrays[0].Device != "/dev/md0" {
+		t.Fatalf("unexpected arrays: %v", arrays)
+	}
+}
+
+func TestGetRebuildSpeed(t *testing.T) {
+	mdstat := `md0 : active raid1 sda1[0] sdb1[1]
+ [>....................] recovery = 12.6% (37043392/293039104) finish=127.5min speed=33440K/sec
+`
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return []byte(mdstat), nil
+	})
+
+	if speed := getRebuildSpeed("/dev/md0"); speed != "33440K/sec" {
+		t.Fatalf("unexpected speed: %s", speed)
+	}
+}
+
+func TestGetRebuildSpeedNoMatch(t *testing.T) {
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return []byte("md0 : active raid1 sda1[0]"), nil
+	})
+
+	if speed := getRebuildSpeed("/dev/md0"); speed != "" {
+		t.Fatalf("expected empty speed, got %s", speed)
+	}
+}
+
+func TestGetRebuildSpeedError(t *testing.T) {
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return nil, errors.New("read failed")
+	})
+
+	if speed := getRebuildSpeed("/dev/md0"); speed != "" {
+		t.Fatalf("expected empty speed, got %s", speed)
+	}
+}
+
+func TestParseDetailSetsRebuildSpeed(t *testing.T) {
+	output := `/dev/md0:
+ Raid Level : raid1
+ State : clean
+ Rebuild Status : 12% complete
+
+ Number Major Minor RaidDevice State
+ 0 8 1 0 active sync /dev/sda1`
+
+	mdstat := `md0 : active raid1 sda1[0]
+ [>....................] recovery = 12.6% (37043392/293039104) finish=127.5min speed=1234K/sec
+`
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return []byte(mdstat), nil
+	})
+
+	array, err := parseDetail("/dev/md0", output)
+	if err != nil {
+		t.Fatalf("parseDetail error: %v", err)
+	}
+	if array.RebuildSpeed != "1234K/sec" {
+		t.Fatalf("expected rebuild speed, got %s", array.RebuildSpeed)
+	}
+}
+
+func TestGetRebuildSpeedSectionExit(t *testing.T) {
+	mdstat := `md0 : active raid1 sda1[0]
+ [>....................] recovery = 12.6% (37043392/293039104) finish=127.5min
+md1 : active raid1 sdb1[0]
+`
+	withRunCommandOutput(t, func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return []byte(mdstat), nil
+	})
+
+	if speed := getRebuildSpeed("/dev/md0"); speed != "" {
+		t.Fatalf("expected empty speed, got %s", speed)
+	}
+}
diff --git a/internal/resources/converters_coverage_test.go b/internal/resources/converters_coverage_test.go
new file mode 100644
index 000000000..0fcbbde45
--- /dev/null
+++ b/internal/resources/converters_coverage_test.go
@@ -0,0 +1,578 @@
+package resources
+
+import (
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/rcourtman/pulse-go-rewrite/internal/models"
+)
+
+func TestFromNodeTemperatureAndCluster(t *testing.T) {
+	now := time.Now()
+	node := models.Node{
+		ID: "node-1",
+		Name: "node-1",
+		Instance: "pve1",
+		Status: "online",
+		CPU: 0.5,
+		Memory: models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
+		Disk: models.Disk{Total: 200, Used: 100, Free: 100, Usage: 50},
+		Uptime: 10,
+		LastSeen: now,
+		IsClusterMember: true,
+		ClusterName: "cluster-a",
+		Temperature: &models.Temperature{
+			Available: true,
+			HasCPU: true,
+			CPUPackage: 72.5,
+		},
+	}
+	r := FromNode(node)
+	if r.Temperature == nil || *r.Temperature != 72.5 {
+		t.Fatalf("expected CPU package temperature")
+	}
+	if r.ClusterID != "pve-cluster/cluster-a" {
+		t.Fatalf("expected cluster ID")
+	}
+
+	node2 := node
+	node2.ID = "node-2"
+	node2.ClusterName = ""
+	node2.IsClusterMember = false
+	node2.Temperature = &models.Temperature{
+		Available: true,
+		HasCPU: true,
+		Cores: []models.CoreTemp{
+			{Core: 0, Temp: 60},
+			{Core: 1, Temp: 80},
+		},
+	}
+	r2 := FromNode(node2)
+	if r2.Temperature == nil || *r2.Temperature != 70 {
+		t.Fatalf("expected averaged core temperature")
+	}
+	if r2.ClusterID != "" {
+		t.Fatalf("expected empty cluster ID when not a member")
+	}
+}
+
+func TestFromHostAndDockerHost(t *testing.T) {
+	now := time.Now()
+	host := models.Host{
+		ID: "host-1",
+		Hostname: "host-1",
+		Status: "degraded",
+		Memory: models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
+		Disks: []models.Disk{
+			{Total: 100, Used: 60, Free: 40, Usage: 60},
+		},
+		NetworkInterfaces: []models.HostNetworkInterface{
+			{Addresses: []string{"10.0.0.1"}, RXBytes: 10, TXBytes: 20},
+		},
+		Sensors: models.HostSensorSummary{
+			TemperatureCelsius: map[string]float64{"cpu": 55},
+		},
+		LastSeen: now,
+	}
+	hr := FromHost(host)
+	if hr.Temperature == nil || *hr.Temperature != 55 {
+		t.Fatalf("expected host temperature")
+	}
+	if hr.Network == nil || hr.Network.RXBytes != 10 || hr.Network.TXBytes != 20 {
+		t.Fatalf("expected host network totals")
+	}
+	if hr.Disk == nil || hr.Disk.Current == 0 {
+		t.Fatalf("expected host disk metrics")
+	}
+	if hr.Status != StatusDegraded {
+		t.Fatalf("expected degraded host status")
+	}
+
+	dockerHost := models.DockerHost{
+		ID: "docker-1",
+		AgentID: "agent-1",
+		Hostname: "docker-1",
+		DisplayName: "docker-1",
+		CustomDisplayName: "custom",
+		Status: "offline",
+		Memory: models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
+		Disks: []models.Disk{
+			{Total: 100, Used: 50, Free: 50, Usage: 50},
+		},
+		NetworkInterfaces: []models.HostNetworkInterface{
+			{Addresses: []string{"10.0.0.2"}, RXBytes: 1, TXBytes: 2},
+		},
+		Swarm: &models.DockerSwarmInfo{
+			ClusterID: "swarm-1",
+		},
+		LastSeen: now,
+	}
+	dr := FromDockerHost(dockerHost)
+	if dr.DisplayName != "custom" {
+		t.Fatalf("expected custom display name")
+	}
+	if dr.ClusterID != "docker-swarm/swarm-1" {
+		t.Fatalf("expected docker swarm cluster ID")
+	}
+	if dr.Status != StatusOffline {
+		t.Fatalf("expected docker host status offline")
+	}
+	if dr.Identity == nil || len(dr.Identity.IPs) != 1 || dr.Identity.IPs[0] != "10.0.0.2" {
+		t.Fatalf("expected docker host identity IPs")
+	}
+}
+
+func TestFromHostPlatformData(t *testing.T) {
+	now := time.Now()
+	host := models.Host{
+		ID: "host-2",
+		Hostname: "host-2",
+		Status: "online",
+		DiskIO: []models.DiskIO{
+			{Device: "sda", ReadBytes: 100, WriteBytes: 200, ReadOps: 1, WriteOps: 2, ReadTime: 3, WriteTime: 4, IOTime: 5},
+		},
+		RAID: []models.HostRAIDArray{
+			{
+				Device: "md0",
+				Level: "raid1",
+				State: "clean",
+				TotalDevices: 2,
+				ActiveDevices: 2,
+				WorkingDevices: 2,
+				FailedDevices: 0,
+				SpareDevices: 0,
+				Devices: []models.HostRAIDDevice{
+					{Device: "sda1", State: "active", Slot: 0},
+				},
+				RebuildPercent: 0,
+			},
+		},
+		LastSeen: now,
+	}
+
+	r := FromHost(host)
+	var pd HostPlatformData
+	if err := r.GetPlatformData(&pd); err != nil {
+		t.Fatalf("failed to get platform data: %v", err)
+	}
+	if len(pd.DiskIO) != 1 || len(pd.RAID) != 1 {
+		t.Fatalf("expected disk IO and RAID entries")
+	}
+}
+
+func TestFromDockerContainerPodman(t *testing.T) {
+	now := time.Now()
+	container := models.DockerContainer{
+		ID: "container-1",
+		Name: "app",
+		Image: "app:latest",
+		State: "paused",
+		Status: "Paused",
+		CPUPercent: 5,
+		MemoryUsage: 50,
+		MemoryLimit: 100,
+		MemoryPercent: 50,
+		UptimeSeconds: 10,
+		Ports: []models.DockerContainerPort{
+			{PrivatePort: 80, PublicPort: 8080, Protocol: "tcp", IP: "0.0.0.0"},
+		},
+		Networks: []models.DockerContainerNetworkLink{
+			{Name: "bridge", IPv4: "172.17.0.2"},
+		},
+		Podman: &models.DockerPodmanContainer{
+			PodName: "pod",
+		},
+		CreatedAt: now,
+	}
+
+	r := FromDockerContainer(container, "host-1", "host-1")
+	if r.Status != StatusPaused {
+		t.Fatalf("expected paused status")
+	}
+	if r.Memory == nil || r.Memory.Current != 50 {
+		t.Fatalf("expected container memory metrics")
+	}
+	if r.ID != "host-1/container-1" {
+		t.Fatalf("expected container resource ID to be host-1/container-1")
+	}
+}
+
+func TestFromVMWithBackup(t *testing.T) {
+	now := time.Now()
+	vm := models.VM{
+		ID: "vm-1",
+		VMID: 100,
+		Name: "vm-1",
+		Node: "node-1",
+		Instance: "pve1",
+		Status: "running",
+		CPU: 0.25,
+		LastBackup: now,
+		LastSeen: now,
+	}
+
+	r := FromVM(vm)
+	var pd VMPlatformData
+	if err := r.GetPlatformData(&pd); err != nil {
+		t.Fatalf("failed to get platform data: %v", err)
+	}
+	if pd.LastBackup == nil || !pd.LastBackup.Equal(now) {
+		t.Fatalf("expected last backup to be set")
+	}
+}
+
+func TestFromContainerOCI(t *testing.T) {
+	now := time.Now()
+	ct := models.Container{
+		ID: "ct-1",
+		VMID: 200,
+		Name: "ct-1",
+		Node: "node-1",
+		Instance: "pve1",
+		Status: "paused",
+		Type: "lxc",
+		IsOCI: true,
+		CPU: 0.1,
+		Memory: models.Memory{Total: 100, Used: 50, Free: 50, Usage: 50},
+		Disk: models.Disk{Total: 200, Used: 100, Free: 100, Usage: 50},
+		LastBackup: now,
+		LastSeen: now,
+	}
+
+	r := FromContainer(ct)
+	if r.Type != ResourceTypeOCIContainer {
+		t.Fatalf("expected OCI container type")
+	}
+	var pd ContainerPlatformData
+	if err := r.GetPlatformData(&pd); err != nil {
+		t.Fatalf("failed to get platform data: %v", err)
+	}
+	if !pd.IsOCI || pd.LastBackup == nil {
+		t.Fatalf("expected OCI platform data with backup")
+	}
+}
+
+func TestKubernetesConversions(t *testing.T) {
+	now := time.Now()
+	cluster := models.KubernetesCluster{
+		ID: "cluster-1",
+		AgentID: "agent-1",
+		CustomDisplayName: "custom",
+		Status: "online",
+		LastSeen: now,
+		Nodes: []models.KubernetesNode{
+			{Name: "node-1", Ready: true, Unschedulable: true},
+		},
+		Pods: []models.KubernetesPod{
+			{Name: "pod-1", Namespace: "default", Phase: "Succeeded", NodeName: "node-1"},
+		},
+		Deployments: []models.KubernetesDeployment{
+			{Name: "dep-1", Namespace: "default", DesiredReplicas: 1, AvailableReplicas: 1},
+		},
+	}
+
+	cr := FromKubernetesCluster(cluster)
+	if cr.Name != "custom" || cr.DisplayName != "custom" {
+		t.Fatalf("expected custom display name to be used")
+	}
+
+	node := cluster.Nodes[0]
+	nr := FromKubernetesNode(node, cluster)
+	if nr.Status != StatusDegraded {
+		t.Fatalf("expected unschedulable node to be degraded")
+	}
+	if !strings.Contains(nr.ID, "cluster-1/node/") {
+		t.Fatalf("expected node ID to include cluster and node")
+	}
+
+	pod := cluster.Pods[0]
+	pr := FromKubernetesPod(pod, cluster)
+	if pr.Status != StatusStopped {
+		t.Fatalf("expected succeeded pod to be stopped")
+	}
+	if !strings.Contains(pr.ParentID, "cluster-1/node/") {
+		t.Fatalf("expected pod parent to be node")
+	}
+
+	dep := cluster.Deployments[0]
+	dr := FromKubernetesDeployment(dep, cluster)
+	if dr.Status != StatusRunning {
+		t.Fatalf("expected deployment running")
+	}
+
+	emptyCluster := models.KubernetesCluster{
+		ID: "cluster-2",
+		AgentID: "agent-2",
+		Status: "offline",
+		LastSeen: now,
+	}
+	er := FromKubernetesCluster(emptyCluster)
+	if er.Name != "cluster-2" || er.DisplayName != "cluster-2" {
+		t.Fatalf("expected cluster name fallback to ID")
+	}
+	if er.Status != StatusOffline {
+		t.Fatalf("expected offline cluster status")
+	}
+}
+
+func TestKubernetesNodeNotReady(t *testing.T) {
+	now := time.Now()
+	cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}
+	node := models.KubernetesNode{
+		UID: "node-uid",
+		Name: "node-offline",
+		Ready: false,
+	}
+	r := FromKubernetesNode(node, cluster)
+	if r.Status != StatusOffline {
+		t.Fatalf("expected offline node status")
+	}
+	if !strings.Contains(r.ID, "node-uid") {
+		t.Fatalf("expected node UID in resource ID")
+	}
+}
+
+func TestKubernetesPodParentCluster(t *testing.T) {
+	now := time.Now()
+	cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}
+	pod := models.KubernetesPod{
+		UID: "pod-uid",
+		Name: "pod-1",
+		Namespace: "default",
+		Phase: "Pending",
+	}
+
+	r := FromKubernetesPod(pod, cluster)
+	if r.ParentID != cluster.ID {
+		t.Fatalf("expected pod parent to be cluster ID")
+	}
+	if !strings.Contains(r.ID, "pod-uid") {
+		t.Fatalf("expected pod UID in resource ID")
+	}
+}
+
+func TestKubernetesPodContainers(t *testing.T) {
+	now := time.Now()
+	cluster := models.KubernetesCluster{ID: "cluster-3", AgentID: "agent-3", LastSeen: now}
+	pod := models.KubernetesPod{
+		Name: "pod-2",
+		Namespace: "default",
+		NodeName: "node-2",
+		Phase: "Running",
+		Containers: []models.KubernetesPodContainer{
+			{Name: "c1", Image: "busybox", Ready: true, RestartCount: 1, State: "running"},
+		},
+	}
+
+	r := FromKubernetesPod(pod, cluster)
+	if r.Status != StatusRunning {
+		t.Fatalf("expected running pod status")
+	}
+	if !strings.Contains(r.ID, "cluster-3/pod/") {
+		t.Fatalf("expected pod ID to include cluster")
+	}
+}
+
+func TestKubernetesDeploymentStatuses(t *testing.T) {
+	now := time.Now()
+	cluster := models.KubernetesCluster{ID: "cluster-1", AgentID: "agent-1", LastSeen: now}
+
+	depStopped := FromKubernetesDeployment(models.KubernetesDeployment{
+		Name: "dep-stop",
+		Namespace: "default",
+		DesiredReplicas: 0,
+	}, cluster)
+	if depStopped.Status != StatusStopped {
+		t.Fatalf("expected stopped deployment")
+	}
+
+	depDegraded := FromKubernetesDeployment(models.KubernetesDeployment{
+		Name: "dep-degraded",
+		Namespace: "default",
+		DesiredReplicas: 3,
+		AvailableReplicas: 1,
+	}, cluster)
+	if depDegraded.Status != StatusDegraded {
+		t.Fatalf("expected degraded deployment")
+	}
+
+	depUnknown := FromKubernetesDeployment(models.KubernetesDeployment{
+		Name: "dep-unknown",
+		Namespace: "default",
+		DesiredReplicas: 2,
+		AvailableReplicas: 0,
+	}, cluster)
+	if depUnknown.Status != StatusUnknown {
+		t.Fatalf("expected unknown deployment")
+	}
+}
+
+func TestFromPBSInstanceAndStorage(t *testing.T) {
+	now := time.Now()
+	pbs := models.PBSInstance{
+		ID: "pbs-1",
+		Name: "pbs",
+		Host: "pbs.local",
+		Status: "online",
+		ConnectionHealth: "unhealthy",
+		CPU: 20,
+		Memory: 50,
+		MemoryTotal: 100,
+		MemoryUsed: 50,
+		Uptime: 10,
+		LastSeen: now,
+	}
+	pr := FromPBSInstance(pbs)
+	if pr.Status != StatusDegraded {
+		t.Fatalf("expected degraded status for unhealthy pbs")
+	}
+
+	pbs2 := pbs
+	pbs2.ID = "pbs-2"
+	pbs2.ConnectionHealth = "healthy"
+	pbs2.Status = "offline"
+	pr2 := FromPBSInstance(pbs2)
+	if pr2.Status != StatusOffline {
+		t.Fatalf("expected offline status for pbs")
+	}
+
+	storageOnline := FromStorage(models.Storage{
+		ID: "storage-1",
+		Name: "local",
+		Instance: "pve1",
+		Node: "node1",
+		Total: 100,
+		Used: 50,
+		Free: 50,
+		Usage: 50,
+		Active: true,
+		Enabled: true,
+	})
+	if storageOnline.Status != StatusOnline {
+		t.Fatalf("expected online storage")
+	}
+
+	storageStopped := FromStorage(models.Storage{
+		ID: "storage-2",
+		Name: "local",
+		Instance: "pve1",
+		Node: "node1",
+		Active: true,
+		Enabled: false,
+	})
+	if storageStopped.Status != StatusStopped {
+		t.Fatalf("expected stopped storage")
+	}
+
+	storageOffline := FromStorage(models.Storage{
+		ID: "storage-3",
+		Name: "local",
+		Instance: "pve1",
+		Node: "node1",
+		Active: false,
+		Enabled: true,
+	})
+	if storageOffline.Status != StatusOffline {
+		t.Fatalf("expected offline storage")
+	}
+}
+
+func TestStatusMappings(t *testing.T) {
+	if mapGuestStatus("running") != StatusRunning {
+		t.Fatalf("expected running guest status")
+	}
+	if mapGuestStatus("stopped") != StatusStopped {
+		t.Fatalf("expected stopped guest status")
+	}
+	if mapGuestStatus("paused") != StatusPaused {
+		t.Fatalf("expected paused guest status")
+	}
+	if mapGuestStatus("unknown") != StatusUnknown {
+		t.Fatalf("expected unknown guest status")
+	}
+
+	if mapHostStatus("online") != StatusOnline {
+		t.Fatalf("expected online host status")
+	}
+	if mapHostStatus("offline") != StatusOffline {
+		t.Fatalf("expected offline host status")
+	}
+	if mapHostStatus("degraded") != StatusDegraded {
+		t.Fatalf("expected degraded host status")
+	}
+	if mapHostStatus("unknown") != StatusUnknown {
+		t.Fatalf("expected unknown host status")
+	}
+
+	if mapDockerHostStatus("online") != StatusOnline {
+		t.Fatalf("expected online docker host status")
+	}
+	if mapDockerHostStatus("offline") != StatusOffline {
+		t.Fatalf("expected offline docker host status")
+	}
+	if mapDockerHostStatus("unknown") != StatusUnknown {
+		t.Fatalf("expected unknown docker host status")
+	}
+
+	if mapDockerContainerStatus("running") != StatusRunning {
+		t.Fatalf("expected running container status")
+	}
+	if mapDockerContainerStatus("exited") != StatusStopped {
+		t.Fatalf("expected exited container status")
+	}
+	if mapDockerContainerStatus("dead") != StatusStopped {
+		t.Fatalf("expected dead container status")
+	}
+	if mapDockerContainerStatus("paused") != StatusPaused {
+		t.Fatalf("expected paused container status")
+	}
+	if mapDockerContainerStatus("restarting") != StatusUnknown {
+		t.Fatalf("expected restarting container status unknown")
+	}
+	if mapDockerContainerStatus("created") != StatusUnknown {
+		t.Fatalf("expected created container status unknown")
+	}
+	if mapDockerContainerStatus("other") != StatusUnknown {
+		t.Fatalf("expected unknown container status")
+	}
+
+	if mapKubernetesClusterStatus(" online ") != StatusOnline {
+		t.Fatalf("expected online k8s cluster status")
+	}
+	if mapKubernetesClusterStatus("offline") != StatusOffline {
+		t.Fatalf("expected offline k8s cluster status")
+	}
+	if mapKubernetesClusterStatus("unknown") != StatusUnknown {
+		t.Fatalf("expected unknown k8s cluster status")
+	}
+
+	if mapKubernetesPodStatus("running") != StatusRunning {
+		t.Fatalf("expected running pod status")
+	}
+	if mapKubernetesPodStatus("succeeded") != StatusStopped {
+		t.Fatalf("expected succeeded pod status")
+	}
+	if mapKubernetesPodStatus("failed") != StatusStopped {
+		t.Fatalf("expected failed pod status")
+	}
+	if mapKubernetesPodStatus("pending") != StatusUnknown {
+		t.Fatalf("expected pending pod status")
+	}
+	if mapKubernetesPodStatus("unknown") != StatusUnknown {
+		t.Fatalf("expected unknown pod status")
+	}
+
+	if mapPBSStatus("online", "healthy") != StatusOnline {
+		t.Fatalf("expected online PBS status")
+	}
+	if mapPBSStatus("offline", "healthy") != StatusOffline {
+		t.Fatalf("expected offline PBS status")
+	}
+	if mapPBSStatus("other", "healthy") != StatusUnknown {
+		t.Fatalf("expected unknown PBS status")
+	}
+	if mapPBSStatus("online", "unhealthy") != StatusDegraded {
+		t.Fatalf("expected degraded PBS status")
+	}
+}
diff --git a/internal/resources/coverage_test.go b/internal/resources/coverage_test.go
new file mode 100644
index 000000000..4698157d6
--- /dev/null
+++ b/internal/resources/coverage_test.go
@@ -0,0 +1,834 @@
+package resources
+
+import (
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/rcourtman/pulse-go-rewrite/internal/models"
+)
+
+func TestResourcePlatformDataAndMetrics(t *testing.T) {
+	var out struct {
+		Value string `json:"value"`
+	}
+
+	r := Resource{Name: "name"}
+	if err := r.GetPlatformData(&out); err != nil {
+		t.Fatalf("expected nil error for empty platform data, got %v", err)
+	}
+
+	if err := r.SetPlatformData(make(chan int)); err == nil {
+		t.Fatal("expected error for unserializable platform data")
+	}
+
+	if err := r.SetPlatformData(struct {
+		Value string `json:"value"`
+	}{Value: "ok"}); err != nil {
+		t.Fatalf("unexpected error setting platform data: %v", err)
+	}
+	if err := r.GetPlatformData(&out); err != nil {
+		t.Fatalf("unexpected error getting platform data: %v", err)
+	}
+	if out.Value != "ok" {
+		t.Fatalf("expected platform data value to be ok, got %q", out.Value)
+	}
+
+	r.PlatformData = []byte("{")
+	if err := r.GetPlatformData(&out); err == nil {
+		t.Fatal("expected error for invalid platform data")
+	}
+
+	r.DisplayName = "display"
+	if r.EffectiveDisplayName() != "display" {
+		t.Fatalf("expected display name to be used")
+	}
+	r.DisplayName = ""
+	if r.EffectiveDisplayName() != "name" {
+		t.Fatalf("expected name fallback")
+	}
+
+	if r.CPUPercent() != 0 {
+		t.Fatalf("expected CPUPercent 0 for nil CPU")
+	}
+	r.CPU = &MetricValue{Current: 12.5}
+	if r.CPUPercent() != 12.5 {
+		t.Fatalf("expected CPUPercent 12.5")
+	}
+
+	if r.MemoryPercent() != 0 {
+		t.Fatalf("expected MemoryPercent 0 for nil memory")
+	}
+	total := int64(100)
+	used := int64(25)
+	r.Memory = &MetricValue{Total: &total, Used: &used}
+	if r.MemoryPercent() != 25 {
+		t.Fatalf("expected MemoryPercent 25")
+	}
+	r.Memory = &MetricValue{Current: 40}
+	if r.MemoryPercent() != 40 {
+		t.Fatalf("expected MemoryPercent 40")
+	}
+
+	if r.DiskPercent() != 0 {
+		t.Fatalf("expected DiskPercent 0 for nil disk")
+	}
+	totalDisk := int64(200)
+	usedDisk := int64(50)
+	r.Disk = &MetricValue{Total: &totalDisk, Used: &usedDisk}
+	if r.DiskPercent() != 25 {
+		t.Fatalf("expected DiskPercent 25")
+	}
+	r.Disk = &MetricValue{Current: 33}
+	if r.DiskPercent() != 33 {
+		t.Fatalf("expected DiskPercent 33")
+	}
+
+	r.Type = ResourceTypeNode
+	if !r.IsInfrastructure() {
+		t.Fatalf("expected node to be infrastructure")
+	}
+	if r.IsWorkload() {
+		t.Fatalf("expected node to not be workload")
+	}
+	r.Type = ResourceTypeVM
+	if r.IsInfrastructure() {
+		t.Fatalf("expected vm to not be infrastructure")
+	}
+	if !r.IsWorkload() {
+		t.Fatalf("expected vm to be workload")
+	}
+}
+
+func TestStorePreferredAndSuppressed(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	existing := Resource{
+		ID: "agent-1",
+		Type: ResourceTypeHost,
+		SourceType: SourceAgent,
+		LastSeen: now,
+		Identity: &ResourceIdentity{
+			Hostname: "host1",
+		},
+	}
+	store.Upsert(existing)
+
+	incoming := Resource{
+		ID: "api-1",
+		Type: ResourceTypeHost,
+		SourceType: SourceAPI,
+		LastSeen: now.Add(1 * time.Minute),
+		Identity: &ResourceIdentity{
+			Hostname: "host1",
+		},
+	}
+	preferred := store.Upsert(incoming)
+	if preferred != existing.ID {
+		t.Fatalf("expected preferred ID to be %s, got %s", existing.ID, preferred)
+	}
+	if !store.IsSuppressed(incoming.ID) {
+		t.Fatalf("expected incoming to be suppressed")
+	}
+	if store.GetPreferredID(incoming.ID) != existing.ID {
+		t.Fatalf("expected preferred ID to map to %s", existing.ID)
+	}
+	if store.GetPreferredID(existing.ID) != existing.ID {
+		t.Fatalf("expected preferred ID to return itself")
+	}
+
+	got, ok := store.Get(incoming.ID)
+	if !ok || got.ID != existing.ID {
+		t.Fatalf("expected Get to return preferred resource")
+	}
+
+	if store.GetPreferredResourceFor(incoming.ID) == nil {
+		t.Fatalf("expected preferred resource for suppressed ID")
+	}
+	if store.GetPreferredResourceFor(existing.ID) == nil {
+		t.Fatalf("expected preferred resource for existing ID")
+	}
+	if store.GetPreferredResourceFor("missing") != nil {
+		t.Fatalf("expected nil for missing resource")
+	}
+
+	if !store.IsSamePhysicalMachine(existing.ID, incoming.ID) {
+		t.Fatalf("expected IDs to be same physical machine")
+	}
+	if !store.IsSamePhysicalMachine(incoming.ID, existing.ID) {
+		t.Fatalf("expected merged ID to match preferred")
+	}
+	if store.IsSamePhysicalMachine(existing.ID, "other") {
+		t.Fatalf("expected different IDs to not match")
+	}
+	if !store.IsSamePhysicalMachine(existing.ID, existing.ID) {
+		t.Fatalf("expected same ID to match")
+	}
+
+	store.Remove(existing.ID)
+	if store.IsSuppressed(incoming.ID) {
+		t.Fatalf("expected suppression to be cleared after removal")
+	}
+	if _, ok := store.Get(incoming.ID); ok {
+		t.Fatalf("expected Get to fail for removed preferred resource")
+	}
+}
+
+func TestStoreStatsAndHelpers(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	store.Upsert(Resource{
+		ID: "host-1",
+		Type: ResourceTypeHost,
+		Status: StatusOffline,
+		SourceType: SourceAgent,
+		LastSeen: now,
+		Alerts: []ResourceAlert{{ID: "a1"}},
+		Identity: &ResourceIdentity{
+			Hostname: "host-1",
+		},
+	})
+	store.Upsert(Resource{
+		ID: "host-2",
+		Type: ResourceTypeHost,
+		Status: StatusOnline,
+		SourceType: SourceAPI,
+		LastSeen: now.Add(time.Minute),
+		Identity: &ResourceIdentity{
+			Hostname: "host-1",
+		},
+	})
+
+	stats := store.GetStats()
+	if stats.SuppressedResources != 1 {
+		t.Fatalf("expected 1 suppressed resource")
+	}
+	if stats.WithAlerts != 1 {
+		t.Fatalf("expected 1 resource with alerts")
+	}
+
+	if store.sourceScore(SourceType("other")) != 0 {
+		t.Fatalf("expected default source score")
+	}
+
+	a := &Resource{ID: "a", SourceType: SourceAPI, LastSeen: now}
+	b := &Resource{ID: "b", SourceType: SourceAgent, LastSeen: now.Add(time.Second)}
+	if store.preferredResource(a, b) != b {
+		t.Fatalf("expected agent resource to be preferred")
+	}
+	c := &Resource{ID: "c", SourceType: SourceAPI, LastSeen: now.Add(time.Second)}
+	if store.preferredResource(a, c) != c {
+		t.Fatalf("expected newer resource to be preferred")
+	}
+	d := &Resource{ID: "d", SourceType: SourceAgent, LastSeen: now}
+	e := &Resource{ID: "e", SourceType: SourceAPI, LastSeen: now}
+	if store.preferredResource(d, e) != d {
+		t.Fatalf("expected higher score resource to be preferred")
+	}
+	f := &Resource{ID: "f", SourceType: SourceAPI, LastSeen: now.Add(2 * time.Second)}
+	g := &Resource{ID: "g", SourceType: SourceAPI, LastSeen: now.Add(1 * time.Second)}
+	if store.preferredResource(f, g) != f {
+		t.Fatalf("expected newer resource to be preferred when scores equal")
+	}
+
+	if store.findDuplicate(&Resource{}) != "" {
+		t.Fatalf("expected no duplicate for nil identity")
+	}
+	if store.findDuplicate(&Resource{Type: ResourceTypeVM, Identity: &ResourceIdentity{Hostname: "host-1"}}) != "" {
+		t.Fatalf("expected no duplicate for workload type")
+	}
+}
+
+func TestStorePreferredSourceAndPolling(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	store.Upsert(Resource{
+		ID: "api-1",
+		Type: ResourceTypeHost,
+		SourceType: SourceAPI,
+		LastSeen: now,
+		Identity: &ResourceIdentity{
+			Hostname: "host1",
+		},
+	})
+	if store.HasPreferredSourceForHostname("host1") {
+		t.Fatalf("expected no preferred source for API-only hostname")
+	}
+
+	store.Upsert(Resource{
+		ID: "hybrid-1",
+		Type: ResourceTypeHost,
+		SourceType: SourceHybrid,
+		LastSeen: now.Add(time.Second),
+		Identity: &ResourceIdentity{
+			Hostname: "host1",
+		},
+	})
+	if !store.HasPreferredSourceForHostname("HOST1") {
+		t.Fatalf("expected preferred source for hostname")
+	}
+	if store.HasPreferredSourceForHostname("") {
+		t.Fatalf("expected empty hostname to be false")
+	}
+	if !store.ShouldSkipAPIPolling("host1") {
+		t.Fatalf("expected skip polling for preferred source")
+	}
+	if store.HasPreferredSourceForHostname("missing") {
+		t.Fatalf("expected missing hostname to be false")
+	}
+}
+
+func TestStoreAgentHostnamesAndRecommendations(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	store.Upsert(Resource{
+		ID: "agent-1",
+		Type: ResourceTypeHost,
+		SourceType: SourceAgent,
+		LastSeen: now,
+		Identity: &ResourceIdentity{
+			Hostname: "host1",
+		},
+	})
+	store.Upsert(Resource{
+		ID: "hybrid-1",
+		Type: ResourceTypeHost,
+		SourceType: SourceHybrid,
+		LastSeen: now,
+		Identity: &ResourceIdentity{
+			Hostname: "HOST2",
+		},
+	})
+	store.Upsert(Resource{
+		ID: "api-1",
+		Type: ResourceTypeHost,
+		SourceType: SourceAPI,
+		LastSeen: now,
+		Identity: &ResourceIdentity{
+			Hostname: "host3",
+		},
+	})
+	store.Upsert(Resource{
+		ID: "no-identity",
+		Type: ResourceTypeHost,
+		SourceType: SourceAgent,
+		LastSeen: now,
+	})
+	store.Upsert(Resource{
+		ID: "node-agent",
+		Type: ResourceTypeNode,
+		SourceType: SourceAgent,
+		LastSeen: now,
+		Identity: &ResourceIdentity{
+			Hostname: "host1",
+		},
+	})
+
+	hostnames := store.GetAgentMonitoredHostnames()
+	seen := make(map[string]bool)
+	for _, h := range hostnames {
+		seen[strings.ToLower(h)] = true
+	}
+	if !seen["host1"] || !seen["host2"] {
+		t.Fatalf("expected host1 and host2 to be monitored, got %v", hostnames)
+	}
+	if len(seen) != 2 {
+		t.Fatalf("expected two unique hostnames")
+	}
+
+	recs := store.GetPollingRecommendations()
+	if recs["host1"] != 0 {
+		t.Fatalf("expected host1 recommendation 0, got %v", recs["host1"])
+	}
+	if recs["host2"] != 0.5 {
+		t.Fatalf("expected host2 recommendation 0.5, got %v", recs["host2"])
+	}
+	if _, ok := recs["host3"]; ok {
+		t.Fatalf("did not expect API-only host to have recommendation")
+	}
+}
+
+func TestStoreFindContainerHost(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	host := Resource{
+		ID: "docker-host-1",
+		Type: ResourceTypeDockerHost,
+		Name: "docker-host",
+		SourceType: SourceAgent,
+		LastSeen: now,
+		Identity: &ResourceIdentity{
+			Hostname: "docker1",
+		},
+	}
+	store.Upsert(host)
+	store.Upsert(Resource{
+		ID: "docker-host-1/container-1",
+		Type: ResourceTypeDockerContainer,
+		Name: "web-app",
+		ParentID: "docker-host-1",
+		SourceType: SourceAgent,
+		LastSeen: now,
+	})
+
+	if store.FindContainerHost("") != "" {
+		t.Fatalf("expected empty search to return empty")
+	}
+	if store.FindContainerHost("missing") != "" {
+		t.Fatalf("expected missing container to return empty")
+	}
+	if store.FindContainerHost("web-app") != "docker1" {
+		t.Fatalf("expected host name from identity")
+	}
+	if store.FindContainerHost("CONTAINER-1") != "docker1" {
+		t.Fatalf("expected match by ID")
+	}
+	if store.FindContainerHost("web") != "docker1" {
+		t.Fatalf("expected match by substring")
+	}
+
+	store2 := NewStore()
+	store2.Upsert(Resource{
+		ID: "container-2",
+		Type: ResourceTypeDockerContainer,
+		Name: "db",
+		ParentID: "missing-host",
+		SourceType: SourceAgent,
+		LastSeen: now,
+	})
+	if store2.FindContainerHost("db") != "" {
+		t.Fatalf("expected missing parent to return empty")
+	}
+
+	store3 := NewStore()
+	store3.Upsert(Resource{
+		ID: "host-no-identity",
+		Type: ResourceTypeDockerHost,
+		Name: "host-name",
+		SourceType: SourceAgent,
+		LastSeen: now,
+	})
+	store3.Upsert(Resource{
+		ID: "host-no-identity/container-3",
+		Type: ResourceTypeDockerContainer,
+		Name: "cache",
+		ParentID: "host-no-identity",
+		SourceType: SourceAgent,
+		LastSeen: now,
+	})
+	if store3.FindContainerHost("cache") != "host-name" {
+		t.Fatalf("expected fallback to host name")
+	}
+}
+
+func TestResourceQueryFiltersAndSorting(t *testing.T) {
+	store := NewStore()
+	base := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+
+	store.Upsert(Resource{
+		ID: "r1",
+		Type: ResourceTypeVM,
+		Name: "b",
+		Status: StatusRunning,
+		ClusterID: "c1",
+		CPU: &MetricValue{Current: 10},
+		Memory: &MetricValue{Current: 60},
+		Disk: &MetricValue{Current: 20},
+		LastSeen: base.Add(1 * time.Hour),
+		Alerts: []ResourceAlert{{ID: "a1"}},
+		SourceType: SourceAPI,
+	})
+	store.Upsert(Resource{
+		ID: "r2",
+		Type: ResourceTypeNode,
+		Name: "a",
+		Status: StatusOnline,
+		ClusterID: "c1",
+		CPU: &MetricValue{Current: 50},
+		Memory: &MetricValue{Current: 10},
+		Disk: &MetricValue{Current: 90},
+		LastSeen: base.Add(2 * time.Hour),
+		SourceType: SourceAPI,
+	})
+	store.Upsert(Resource{
+		ID: "r3",
+		Type: ResourceTypeVM,
+		Name: "c",
+		Status: StatusOffline,
+		ClusterID: "c2",
+		CPU: &MetricValue{Current: 5},
+		Memory: &MetricValue{Current: 30},
+		Disk: &MetricValue{Current: 10},
+		LastSeen: base.Add(3 * time.Hour),
+		SourceType: SourceAPI,
+	})
+
+	clustered := store.Query().InCluster("c1").Execute()
+	if len(clustered) != 2 {
+		t.Fatalf("expected 2 clustered resources, got %d", len(clustered))
+	}
+
+	withAlerts := store.Query().WithAlerts().Execute()
+	if len(withAlerts) != 1 || withAlerts[0].ID != "r1" {
+		t.Fatalf("expected only r1 with alerts")
+	}
+
+	sorted := store.Query().SortBy("name", false).Execute()
+	if len(sorted) < 2 || sorted[0].Name != "a" {
+		t.Fatalf("expected sorted results by name")
+	}
+
+	limited := store.Query().SortBy("cpu", true).Offset(1).Limit(1).Execute()
+	if len(limited) != 1 {
+		t.Fatalf("expected limited results")
+	}
+
+	empty := store.Query().Offset(10).Execute()
+	if len(empty) != 0 {
+		t.Fatalf("expected empty results for large offset")
+	}
+}
+
+func TestSortResourcesFields(t *testing.T) {
+	base := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+	resources := []Resource{
+		{
+			ID: "r1",
+			Type: ResourceTypeVM,
+			Name: "b",
+			Status: StatusRunning,
+			CPU: &MetricValue{Current: 20},
+			Memory: &MetricValue{Current: 40},
+			Disk: &MetricValue{Current: 30},
+			LastSeen: base.Add(1 * time.Hour),
+		},
+		{
+			ID: "r2",
+			Type: ResourceTypeNode,
+			Name: "a",
+			Status: StatusOffline,
+			CPU: &MetricValue{Current: 50},
+			Memory: &MetricValue{Current: 10},
+			Disk: &MetricValue{Current: 90},
+			LastSeen: base.Add(2 * time.Hour),
+		},
+		{
+			ID: "r3",
+			Type: ResourceTypeContainer,
+			Name: "c",
+			Status: StatusDegraded,
+			CPU: &MetricValue{Current: 5},
+			Memory: &MetricValue{Current: 80},
+			Disk: &MetricValue{Current: 10},
+			LastSeen: base.Add(3 * time.Hour),
+		},
+	}
+
+	single := []Resource{{ID: "only"}}
+	sortResources(single, "name", false)
+
+	cases := []struct {
+		field string
+		desc  bool
+		want  string
+	}{
+		{"name", false, "r2"},
+		{"name", true, "r3"},
+		{"type", false, "r3"},
+		{"type", true, "r1"},
+		{"status", false, "r3"},
+		{"status", true, "r1"},
+		{"cpu", true, "r2"},
+		{"cpu", false, "r3"},
+		{"memory", true, "r3"},
+		{"memory", false, "r2"},
+		{"disk", true, "r2"},
+		{"disk", false, "r3"},
+		{"last_seen", true, "r3"},
+		{"last_seen", false, "r1"},
+		{"lastseen", true, "r3"},
+		{"mem", true, "r3"},
+	}
+
+	for _, tc := range cases {
+		sorted := append([]Resource(nil), resources...)
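+		// append to a nil slice copies the fixture, so each case sorts pristine input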
+		sortResources(sorted, tc.field, tc.desc)
+		if sorted[0].ID != tc.want {
+			t.Fatalf("sort %s desc=%v expected %s, got %s", tc.field, tc.desc, tc.want, sorted[0].ID)
+		}
+	}
+}
+
+func TestGetTopByMemoryAndDisk(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	store.Upsert(Resource{
+		ID: "vm1",
+		Type: ResourceTypeVM,
+		Memory: &MetricValue{Current: 80},
+		Disk: &MetricValue{Current: 30},
+		LastSeen: now,
+	})
+	store.Upsert(Resource{
+		ID: "vm2",
+		Type: ResourceTypeVM,
+		Memory: &MetricValue{Current: 20},
+		Disk: &MetricValue{Current: 90},
+		LastSeen: now,
+	})
+	store.Upsert(Resource{
+		ID: "node1",
+		Type: ResourceTypeNode,
+		Memory: &MetricValue{Current: 60},
+		Disk: &MetricValue{Current: 10},
+		LastSeen: now,
+	})
+	store.Upsert(Resource{
+		ID: "skip-memory",
+		Type: ResourceTypeVM,
+		LastSeen: now,
+	})
+	store.Upsert(Resource{
+		ID: "skip-disk",
+		Type: ResourceTypeVM,
+		Disk: &MetricValue{Current: 0},
+		LastSeen: now,
+	})
+
+	topMem := store.GetTopByMemory(1, nil)
+	if len(topMem) != 1 || topMem[0].ID != "vm1" {
+		t.Fatalf("expected vm1 to be top memory")
+	}
+	topMemVMs := store.GetTopByMemory(10, []ResourceType{ResourceTypeVM})
+	if len(topMemVMs) != 2 {
+		t.Fatalf("expected 2 VM memory results, got %d", len(topMemVMs))
+	}
+
+	topDisk := store.GetTopByDisk(1, nil)
+	if len(topDisk) != 1 || topDisk[0].ID != "vm2" {
+		t.Fatalf("expected vm2 to be top disk")
+	}
+	topDiskNodes := store.GetTopByDisk(10, []ResourceType{ResourceTypeNode})
+	if len(topDiskNodes) != 1 || topDiskNodes[0].ID != "node1" {
+		t.Fatalf("expected node1 to be top disk node")
+	}
+}
+
+func TestGetTopByCPU_SkipsZero(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	store.Upsert(Resource{
+		ID: "skip-cpu-1",
+		Type: ResourceTypeVM,
+		LastSeen: now,
+	})
+	store.Upsert(Resource{
+		ID: "skip-cpu-2",
+		Type: ResourceTypeVM,
+		CPU: &MetricValue{Current: 0},
+		LastSeen: now,
+	})
+	store.Upsert(Resource{
+		ID: "cpu-1",
+		Type: ResourceTypeVM,
+		CPU: &MetricValue{Current: 10},
+		LastSeen: now,
+	})
+
+	top := store.GetTopByCPU(10, nil)
+	if len(top) != 1 || top[0].ID != "cpu-1" {
+		t.Fatalf("expected only cpu-1 to be returned")
+	}
+}
+
+func TestGetRelatedWithChildren(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	store.Upsert(Resource{
+		ID: "parent",
+		Type: ResourceTypeNode,
+		Name: "parent",
+		LastSeen: now,
+	})
+	store.Upsert(Resource{
+		ID: "child-1",
+		Type: ResourceTypeVM,
+		Name: "child-1",
+		ParentID: "parent",
+		LastSeen: now,
+	})
+
+	related := store.GetRelated("parent")
+	if children, ok := related["children"]; !ok || len(children) != 1 {
+		t.Fatalf("expected children for parent")
+	}
+
+	if len(store.GetRelated("missing")) != 0 {
+		t.Fatalf("expected no related resources for missing ID")
+	}
+}
+
+func TestResourceSummaryWithDegradedAndAlerts(t *testing.T) {
+	store := NewStore()
+	now := time.Now()
+
+	store.Upsert(Resource{
+		ID: "healthy",
+		Type: ResourceTypeNode,
+		Status: StatusOnline,
+		LastSeen: now,
+		CPU: &MetricValue{Current: 20},
+		Memory: &MetricValue{Current: 50},
+	})
+	store.Upsert(Resource{
+		ID: "degraded",
+		Type: ResourceTypeVM,
+		Status: StatusDegraded,
+		LastSeen: now,
+		Alerts: []ResourceAlert{{ID: "a1"}},
+	})
+	store.Upsert(Resource{
+		ID: "unknown",
+		Type: ResourceTypeVM,
+		Status: StatusUnknown,
+		LastSeen: now,
+	})
+
+	summary := store.GetResourceSummary()
+	if summary.Degraded != 1 {
+		t.Fatalf("expected degraded count 1")
+	}
+	if summary.Offline != 1 {
+		t.Fatalf("expected offline count 1")
+	}
+	if summary.WithAlerts != 1 {
+		t.Fatalf("expected alerts count 1")
+	}
+}
+
+func TestPopulateFromSnapshotFull(t *testing.T) {
+	store := NewStore()
+	store.Upsert(Resource{ID: "old-resource", Type: ResourceTypeNode, LastSeen: time.Now()})
+
+	now := time.Now()
+	cluster := models.KubernetesCluster{
+		ID: "cluster-1",
+		AgentID: "agent-1",
+		Status: "online",
+		LastSeen: now,
+		Nodes: []models.KubernetesNode{
+			{Name: "node-1", Ready: true},
+		},
+		Pods: []models.KubernetesPod{
+			{Name: "pod-1", Namespace: "default", Phase: "Running", NodeName: "node-1"},
+		},
+		Deployments: []models.KubernetesDeployment{
+			{Name: "dep-1", Namespace: "default", DesiredReplicas: 1, AvailableReplicas: 1},
+		},
+	}
+	dockerHost := models.DockerHost{
+		ID: "docker-1",
+		AgentID: "agent-docker",
+		Hostname: "docker-host",
+		Status: "online",
+		Memory: models.Memory{
+			Total: 100,
+			Used: 50,
+			Free: 50,
+			Usage: 50,
+		},
+		Disks: []models.Disk{
+			{Total: 100, Used: 60, Free: 40, Usage: 60},
+		},
+		NetworkInterfaces: []models.HostNetworkInterface{
+			{Addresses: []string{"10.0.0.1"}, RXBytes: 1, TXBytes: 2},
+		},
+		Containers: []models.DockerContainer{
+			{
+				ID: "container-1",
+				Name: "web",
+				State: "running",
+				Status: "Up",
+				CPUPercent: 5,
+				MemoryUsage: 50,
+				MemoryLimit: 100,
+				MemoryPercent: 50,
+			},
+		},
+		LastSeen: now,
+	}
+	snapshot := models.StateSnapshot{
+		DockerHosts: []models.DockerHost{dockerHost},
+		KubernetesClusters: []models.KubernetesCluster{
+			cluster,
+		},
+		PBSInstances: []models.PBSInstance{
+			{
+				ID: "pbs-1",
+				Name: "pbs",
+				Host: "pbs.local",
+				Status: "online",
+				ConnectionHealth: "healthy",
+				CPU: 20,
+				Memory: 50,
+				MemoryTotal: 100,
+				MemoryUsed: 50,
+				Uptime: 10,
+				LastSeen: now,
+			},
+		},
+		Storage: []models.Storage{
+			{
+				ID: "storage-1",
+				Name: "local",
+				Instance: "pve1",
+				Node: "node1",
+				Total: 100,
+				Used: 50,
+				Free: 50,
+				Usage: 50,
+				Active: true,
+				Enabled: true,
+			},
+		},
+	}
+
+	store.PopulateFromSnapshot(snapshot)
+
+	if _, ok := store.Get("old-resource"); ok {
+		t.Fatalf("expected old resource to be removed")
+	}
+
+	if len(store.Query().OfType(ResourceTypeDockerHost).Execute()) != 1 {
+		t.Fatalf("expected 1 docker host")
+	}
+	if len(store.Query().OfType(ResourceTypeDockerContainer).Execute()) != 1 {
+		t.Fatalf("expected 1 docker container")
+	}
+	if len(store.Query().OfType(ResourceTypeK8sCluster).Execute()) != 1 {
+		t.Fatalf("expected 1 k8s cluster")
+	}
+	if len(store.Query().OfType(ResourceTypeK8sNode).Execute()) != 1 {
+		t.Fatalf("expected 1 k8s node")
+	}
+	if len(store.Query().OfType(ResourceTypePod).Execute()) != 1 {
+		t.Fatalf("expected 1 k8s pod")
+	}
+	if len(store.Query().OfType(ResourceTypeK8sDeployment).Execute()) != 1 {
+		t.Fatalf("expected 1 k8s deployment")
+	}
+	if len(store.Query().OfType(ResourceTypePBS).Execute()) != 1 {
+		t.Fatalf("expected 1 PBS instance")
+	}
+	if len(store.Query().OfType(ResourceTypeStorage).Execute()) != 1 {
+		t.Fatalf("expected 1 storage resource")
+	}
+}
diff --git a/internal/smartctl/collector.go b/internal/smartctl/collector.go
index 3828c566b..047a8d94d 100644
--- a/internal/smartctl/collector.go
+++ b/internal/smartctl/collector.go
@@ -13,17 +13,25 @@ import (
 	"github.com/rs/zerolog/log"
 )
 
+var (
+	execLookPath     = exec.LookPath
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return exec.CommandContext(ctx, name, args...).Output()
+	}
+	timeNow = time.Now
+)
+
 // DiskSMART represents S.M.A.R.T. data for a single disk.
 type DiskSMART struct {
-	Device string `json:"device"` // Device path (e.g., /dev/sda)
-	Model string `json:"model,omitempty"` // Disk model
-	Serial string `json:"serial,omitempty"` // Serial number
-	WWN string `json:"wwn,omitempty"` // World Wide Name
-	Type string `json:"type,omitempty"` // Transport type: sata, sas, nvme
-	Temperature int `json:"temperature"` // Temperature in Celsius
-	Health string `json:"health,omitempty"` // PASSED, FAILED, UNKNOWN
-	Standby bool `json:"standby,omitempty"` // True if disk was in standby
-	LastUpdated time.Time `json:"lastUpdated"` // When this reading was taken
+	Device      string    `json:"device"`            // Device path (e.g., /dev/sda)
+	Model       string    `json:"model,omitempty"`   // Disk model
+	Serial      string    `json:"serial,omitempty"`  // Serial number
+	WWN         string    `json:"wwn,omitempty"`     // World Wide Name
+	Type        string    `json:"type,omitempty"`    // Transport type: sata, sas, nvme
+	Temperature int       `json:"temperature"`       // Temperature in Celsius
+	Health      string    `json:"health,omitempty"`  // PASSED, FAILED, UNKNOWN
+	Standby     bool      `json:"standby,omitempty"` // True if disk was in standby
+	LastUpdated time.Time `json:"lastUpdated"`       // When this reading was taken
 }
 
 // smartctlJSON represents the JSON output from smartctl --json.
@@ -85,8 +93,7 @@ func CollectLocal(ctx context.Context) ([]DiskSMART, error) {
 // listBlockDevices returns a list of block devices suitable for SMART queries.
 func listBlockDevices(ctx context.Context) ([]string, error) {
 	// Use lsblk to find disks (not partitions)
-	cmd := exec.CommandContext(ctx, "lsblk", "-d", "-n", "-o", "NAME,TYPE")
-	output, err := cmd.Output()
+	output, err := runCommandOutput(ctx, "lsblk", "-d", "-n", "-o", "NAME,TYPE")
 	if err != nil {
 		return nil, err
 	}
@@ -114,7 +121,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
 	defer cancel()
 
 	// Check if smartctl is available
-	smartctlPath, err := exec.LookPath("smartctl")
+	smartctlPath, err := execLookPath("smartctl")
 	if err != nil {
 		return nil, err
 	}
@@ -124,8 +131,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
 	// -i: device info
 	// -A: attributes (for temperature)
 	// --json=o: output original smartctl JSON format
-	cmd := exec.CommandContext(cmdCtx, smartctlPath, "-n", "standby", "-i", "-A", "-H", "--json=o", device)
-	output, err := cmd.Output()
+	output, err := runCommandOutput(cmdCtx, smartctlPath, "-n", "standby", "-i", "-A", "-H", "--json=o", device)
 
 	// smartctl returns non-zero exit codes for various conditions
 	// Exit code 2 means drive is in standby - that's okay
@@ -137,7 +143,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
 		return &DiskSMART{
 			Device:      filepath.Base(device),
 			Standby:     true,
-			LastUpdated: time.Now(),
+			LastUpdated: timeNow(),
 		}, nil
 	}
 	// Other exit codes might still have valid JSON output
@@ -161,7 +167,7 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error)
 		Model:       smartData.ModelName,
 		Serial:      smartData.SerialNumber,
 		Type:        detectDiskType(smartData),
-		LastUpdated: time.Now(),
+		LastUpdated: timeNow(),
 	}
 
 	// Build WWN string if available
diff --git a/internal/smartctl/collector_coverage_test.go b/internal/smartctl/collector_coverage_test.go
new file mode 100644
index 000000000..6cdd90298
--- /dev/null
+++ b/internal/smartctl/collector_coverage_test.go
@@ -0,0 +1,303 @@
+package smartctl
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+)
+
+func TestListBlockDevices(t *testing.T) {
+	origRun := runCommandOutput
+	t.Cleanup(func() { runCommandOutput = origRun })
+
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		if name != "lsblk" {
+			return nil, errors.New("unexpected command")
+		}
+		out := "sda disk\nsda1 part\nsr0 rom\nnvme0n1 disk\n\n"
+		return []byte(out), nil
+	}
+
+	devices, err := listBlockDevices(context.Background())
+	if err != nil {
+		t.Fatalf("listBlockDevices error: %v", err)
+	}
+	if len(devices) != 2 || devices[0] != "/dev/sda" || devices[1] != "/dev/nvme0n1" {
+		t.Fatalf("unexpected devices: %#v", devices)
+	}
+}
+
+func TestListBlockDevicesError(t *testing.T) {
+	origRun := runCommandOutput
+	t.Cleanup(func() { runCommandOutput = origRun })
+
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return nil, errors.New("boom")
+	}
+
+	if _, err := listBlockDevices(context.Background()); err == nil {
+		t.Fatalf("expected error")
+	}
+}
+
+func TestDefaultRunCommandOutput(t *testing.T) {
+	origRun := runCommandOutput
+	t.Cleanup(func() { runCommandOutput = origRun })
+	runCommandOutput = origRun
+
+	if _, err := runCommandOutput(context.Background(), "sh", "-c", "printf ''"); err != nil {
+		t.Fatalf("expected default runCommandOutput to work, got %v", err)
+	}
+}
+
+func TestCollectLocalNoDevices(t *testing.T) {
+	origRun := runCommandOutput
+	t.Cleanup(func() { runCommandOutput = origRun })
+
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		if name == "lsblk" {
+			return []byte(""), nil
+		}
+		return nil, errors.New("unexpected command")
+	}
+
+	result, err := CollectLocal(context.Background())
+	if err != nil {
+		t.Fatalf("CollectLocal error: %v", err)
+	}
+	if result != nil {
+		t.Fatalf("expected nil result for no devices")
+	}
+}
+
+func TestCollectLocalListDevicesError(t *testing.T) {
+	origRun := runCommandOutput
+	t.Cleanup(func() { runCommandOutput = origRun })
+
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return nil, errors.New("lsblk failed")
+	}
+
+	if _, err := CollectLocal(context.Background()); err == nil {
+		t.Fatalf("expected list error")
+	}
+}
+
+func TestCollectLocalSkipsErrors(t *testing.T) {
+	origRun := runCommandOutput
+	origLook := execLookPath
+	origNow := timeNow
+	t.Cleanup(func() {
+		runCommandOutput = origRun
+		execLookPath = origLook
+		timeNow = origNow
+	})
+
+	fixed := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
+	timeNow = func() time.Time { return fixed }
+	execLookPath = func(string) (string, error) { return "smartctl", nil }
+
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		if name == "lsblk" {
+			return []byte("sda disk\nsdb disk\n"), nil
+		}
+		if name == "smartctl" {
+			device := args[len(args)-1]
+			if strings.Contains(device, "sda") {
+				return nil, errors.New("read error")
+			}
+			payload := smartctlJSON{
+				ModelName: "Model",
+				SerialNumber: "Serial",
+			}
+			payload.Device.Protocol = "ATA"
+			payload.SmartStatus.Passed = true
+			payload.Temperature.Current = 30
+			out, _ := json.Marshal(payload)
+			return out, nil
+		}
+		return nil, errors.New("unexpected command")
+	}
+
+	result, err := CollectLocal(context.Background())
+	if err != nil {
+		t.Fatalf("CollectLocal error: %v", err)
+	}
+	if len(result) != 1 {
+		t.Fatalf("expected 1 result, got %d", len(result))
+	}
+	if result[0].Device != "sdb" || !result[0].LastUpdated.Equal(fixed) {
+		t.Fatalf("unexpected result: %#v", result[0])
+	}
+}
+
+func TestCollectDeviceSMARTLookPathError(t *testing.T) {
+	origLook := execLookPath
+	t.Cleanup(func() { execLookPath = origLook })
+	execLookPath = func(string) (string, error) { return "", errors.New("missing") }
+
+	if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
+		t.Fatalf("expected lookpath error")
+	}
+}
+
+func TestCollectDeviceSMARTStandby(t *testing.T) {
+	if _, err := exec.LookPath("sh"); err != nil {
+		t.Skip("sh not available")
+	}
+
+	origRun := runCommandOutput
+	origLook := execLookPath
+	origNow := timeNow
+	t.Cleanup(func() {
+		runCommandOutput = origRun
+		execLookPath = origLook
+		timeNow = origNow
+	})
+
+	fixed := time.Date(2024, 2, 3, 4, 5, 6, 0, time.UTC)
+	timeNow = func() time.Time { return fixed }
+	execLookPath = func(string) (string, error) { return "smartctl", nil }
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return exec.CommandContext(ctx, "sh", "-c", "exit 2").Output()
+	}
+
+	result, err := collectDeviceSMART(context.Background(), "/dev/sda")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result == nil || !result.Standby || result.Device != "sda" || !result.LastUpdated.Equal(fixed) {
+		t.Fatalf("unexpected standby result: %#v", result)
+	}
+}
+
+func TestCollectDeviceSMARTExitErrorNoOutput(t *testing.T) {
+	if _, err := exec.LookPath("sh"); err != nil {
+		t.Skip("sh not available")
+	}
+
+	origRun := runCommandOutput
+	origLook := execLookPath
+	t.Cleanup(func() {
+		runCommandOutput = origRun
+		execLookPath = origLook
+	})
+
+	execLookPath = func(string) (string, error) { return "smartctl", nil }
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return exec.CommandContext(ctx, "sh", "-c", "exit 1").Output()
+	}
+
+	if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
+		t.Fatalf("expected error for exit code without output")
+	}
+}
+
+func TestCollectDeviceSMARTExitErrorWithOutput(t *testing.T) {
+	if _, err := exec.LookPath("sh"); err != nil {
+		t.Skip("sh not available")
+	}
+
+	origRun := runCommandOutput
+	origLook := execLookPath
+	t.Cleanup(func() {
+		runCommandOutput = origRun
+		execLookPath = origLook
+	})
+
+	execLookPath = func(string) (string, error) { return "smartctl", nil }
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		payload := `{"model_name":"Model","serial_number":"Serial","device":{"protocol":"ATA"},"smart_status":{"passed":false},"temperature":{"current":45}}`
+		return exec.CommandContext(ctx, "sh", "-c", "echo '"+payload+"'; exit 1").Output()
+	}
+
+	result, err := collectDeviceSMART(context.Background(), "/dev/sda")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result == nil || result.Health != "FAILED" || result.Temperature != 45 {
+		t.Fatalf("unexpected result: %#v", result)
+	}
+}
+
+func TestCollectDeviceSMARTJSONError(t *testing.T) {
+	origRun := runCommandOutput
+	origLook := execLookPath
+	t.Cleanup(func() {
+		runCommandOutput = origRun
+		execLookPath = origLook
+	})
+
+	execLookPath = func(string) (string, error) { return "smartctl", nil }
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		return []byte("{"), nil
+	}
+
+	if _, err := collectDeviceSMART(context.Background(), "/dev/sda"); err == nil {
+		t.Fatalf("expected json error")
+	}
+}
+
+func TestCollectDeviceSMARTNVMeTempFallback(t *testing.T) {
+	origRun := runCommandOutput
+	origLook := execLookPath
+	origNow := timeNow
+	t.Cleanup(func() {
+		runCommandOutput = origRun
+		execLookPath = origLook
+		timeNow = origNow
+	})
+
+	fixed := time.Date(2024, 4, 5, 6, 7, 8, 0, time.UTC)
+	timeNow = func() time.Time { return fixed }
+	execLookPath = func(string) (string, error) { return "smartctl", nil }
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		payload := smartctlJSON{}
+		payload.Device.Protocol = "NVMe"
+		payload.NVMeSmartHealthInformationLog.Temperature = 55
+		payload.SmartStatus.Passed = true
+		out, _ := json.Marshal(payload)
+		return out, nil
+	}
+
+	result, err := collectDeviceSMART(context.Background(), "/dev/nvme0n1")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result.Temperature != 55 || result.Type != "nvme" || result.Health != "PASSED" || !result.LastUpdated.Equal(fixed) {
+		t.Fatalf("unexpected result: %#v", result)
+	}
+}
+
+func TestCollectDeviceSMARTWWN(t *testing.T) {
+	origRun := runCommandOutput
+	origLook := execLookPath
+	t.Cleanup(func() {
+		runCommandOutput = origRun
+		execLookPath = origLook
+	})
+
+	execLookPath = func(string) (string, error) { return "smartctl", nil }
+	runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) {
+		payload := smartctlJSON{}
+		payload.WWN.NAA = 5
+		payload.WWN.OUI = 0xabc
+		payload.WWN.ID = 0x1234
+		payload.Device.Protocol = "SAS"
+		payload.SmartStatus.Passed = true
+		out, _ := json.Marshal(payload)
+		return out, nil
+	}
+
+	result, err := collectDeviceSMART(context.Background(), "/dev/sda")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result.WWN != "5-abc-1234" {
+		t.Fatalf("unexpected WWN: %q", result.WWN)
+	}
+}
diff --git a/internal/ssh/knownhosts/manager.go b/internal/ssh/knownhosts/manager.go
index 5d654ecd7..04c91fa25 100644
--- a/internal/ssh/knownhosts/manager.go
+++ b/internal/ssh/knownhosts/manager.go
@@ -6,6 +6,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"os"
 	"os/exec"
 	"path/filepath"
@@ -44,12 +45,33 @@ const (
 )
 
 var (
+	mkdirAllFn       = os.MkdirAll
+	statFn           = os.Stat
+	openFileFn       = os.OpenFile
+	openFn           = os.Open
+	appendOpenFileFn = func(path string) (io.WriteCloser, error) {
+		return openFileFn(path, os.O_APPEND|os.O_WRONLY, 0o600)
+	}
+	keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
+		cmd := exec.CommandContext(ctx, "ssh-keyscan", args...)
+		return cmd.CombinedOutput()
+	}
+
 	// ErrNoHostKeys is returned when ssh-keyscan yields no usable entries.
 	ErrNoHostKeys = errors.New("knownhosts: no host keys discovered")
 	// ErrHostKeyChanged signals that a host key already exists with a different fingerprint.
 	ErrHostKeyChanged = errors.New("knownhosts: host key changed")
 )
 
+var (
+	defaultMkdirAllFn       = mkdirAllFn
+	defaultStatFn           = statFn
+	defaultOpenFileFn       = openFileFn
+	defaultOpenFn           = openFn
+	defaultAppendOpenFileFn = appendOpenFileFn
+	defaultKeyscanCmdRunner = keyscanCmdRunner
+)
+
 // HostKeyChangeError describes a detected host key mismatch.
type HostKeyChangeError struct { Host string @@ -214,17 +236,17 @@ func (m *manager) Path() string { func (m *manager) ensureKnownHostsFile() error { dir := filepath.Dir(m.path) - if err := os.MkdirAll(dir, 0o700); err != nil { + if err := mkdirAllFn(dir, 0o700); err != nil { return fmt.Errorf("knownhosts: mkdir %s: %w", dir, err) } - if _, err := os.Stat(m.path); err == nil { + if _, err := statFn(m.path); err == nil { return nil } else if !os.IsNotExist(err) { return err } - f, err := os.OpenFile(m.path, os.O_CREATE|os.O_WRONLY, 0o600) + f, err := openFileFn(m.path, os.O_CREATE|os.O_WRONLY, 0o600) if err != nil { return fmt.Errorf("knownhosts: create %s: %w", m.path, err) } @@ -232,7 +254,7 @@ func (m *manager) ensureKnownHostsFile() error { } func appendHostKey(path string, entries [][]byte) error { - f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0o600) + f, err := appendOpenFileFn(path) if err != nil { return fmt.Errorf("knownhosts: open %s: %w", path, err) } @@ -287,7 +309,7 @@ func normalizeHostEntry(host string, entry []byte) ([]byte, string, error) { } func findHostKeyLine(path, host, keyType string) (string, error) { - f, err := os.Open(path) + f, err := openFn(path) if err != nil { if os.IsNotExist(err) { return "", nil @@ -328,10 +350,6 @@ func hostLineMatches(host, line string) bool { } fields := strings.Fields(trimmed) - if len(fields) == 0 { - return false - } - return hostFieldMatches(host, fields[0]) } @@ -396,8 +414,7 @@ func defaultKeyscan(ctx context.Context, host string, port int, timeout time.Dur } args = append(args, host) - cmd := exec.CommandContext(scanCtx, "ssh-keyscan", args...) - output, err := cmd.CombinedOutput() + output, err := keyscanCmdRunner(scanCtx, args...) if err != nil { return nil, fmt.Errorf("%w (output: %s)", err, strings.TrimSpace(string(output))) } diff --git a/internal/ssh/knownhosts/manager_test.go b/internal/ssh/knownhosts/manager_test.go index 5e407c211..fd6f5762d 100644 --- a/internal/ssh/knownhosts/manager_test.go +++ b/internal/ssh/knownhosts/manager_test.go @@ -1,15 +1,27 @@ package knownhosts import ( + "bytes" "context" "errors" + "io" "os" "path/filepath" + "runtime" "strings" "testing" "time" ) +func resetKnownHostsFns() { + mkdirAllFn = defaultMkdirAllFn + statFn = defaultStatFn + openFileFn = defaultOpenFileFn + openFn = defaultOpenFn + appendOpenFileFn = defaultAppendOpenFileFn + keyscanCmdRunner = defaultKeyscanCmdRunner +} + func TestEnsureCreatesFileAndCaches(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "known_hosts") @@ -110,6 +122,167 @@ func TestEnsureRespectsContextCancellation(t *testing.T) { } } +func TestNewManagerEmptyPath(t *testing.T) { + if _, err := NewManager(""); err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestEnsureWithPortMissingHost(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + + mgr, err := NewManager(path) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithPort(context.Background(), "", 22); err == nil { + t.Fatal("expected error for missing host") + } +} + +func TestEnsureWithPortDefaultsPort(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + + var gotPort int + keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) { + gotPort = port + return []byte(host + " ssh-ed25519 AAAA"), nil + } + + mgr, err := NewManager(path, WithKeyscanFunc(keyscan)) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := 
mgr.EnsureWithPort(context.Background(), "example.com", 0); err != nil { + t.Fatalf("EnsureWithPort: %v", err) + } + if gotPort != 22 { + t.Fatalf("expected port 22, got %d", gotPort) + } +} + +func TestEnsureWithPortCustomPort(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + + keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) { + return []byte("[example.com]:2222 ssh-ed25519 AAAA"), nil + } + + mgr, err := NewManager(path, WithKeyscanFunc(keyscan)) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithPort(context.Background(), "example.com", 2222); err != nil { + t.Fatalf("EnsureWithPort: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if got := strings.TrimSpace(string(data)); got != "[example.com]:2222 ssh-ed25519 AAAA" { + t.Fatalf("unexpected known_hosts contents: %s", got) + } +} + +func TestEnsureWithPortKeyscanError(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + + keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) { + return nil, errors.New("scan failed") + } + + mgr, err := NewManager(path, WithKeyscanFunc(keyscan)) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithPort(context.Background(), "example.com", 22); err == nil { + t.Fatal("expected keyscan error") + } +} + +func TestEnsureWithEntriesMissingHost(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + + mgr, err := NewManager(path) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithEntries(context.Background(), "", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil { + t.Fatal("expected missing host error") + } +} + +func TestEnsureWithEntriesNoEntries(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + + mgr, err := NewManager(path) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, nil); err == nil { + t.Fatal("expected no entries error") + } +} + +func TestEnsureWithEntriesNormalizeError(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + + mgr, err := NewManager(path) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("invalid")}); err == nil { + t.Fatal("expected normalize error") + } +} + +func TestEnsureWithEntriesEnsureKnownHostsFileError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + mkdirAllFn = func(string, os.FileMode) error { + return errors.New("mkdir failed") + } + + mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts")) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil { + t.Fatal("expected ensureKnownHostsFile error") + } +} + +func TestEnsureWithEntriesFindHostKeyLineError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + openFn = func(string) (*os.File, error) { + return nil, errors.New("open failed") + } + + mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts")) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 
AAAA")}); err == nil { + t.Fatal("expected open error") + } +} + func TestHostCandidates(t *testing.T) { tests := []struct { input string @@ -293,3 +466,312 @@ func TestHostFieldMatches(t *testing.T) { }) } } + +func TestEnsureKnownHostsFileMkdirError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + mkdirAllFn = func(string, os.FileMode) error { + return errors.New("mkdir failed") + } + + m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")} + if err := m.ensureKnownHostsFile(); err == nil { + t.Fatal("expected mkdir error") + } +} + +func TestEnsureKnownHostsFileStatError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + statFn = func(string) (os.FileInfo, error) { + return nil, errors.New("stat failed") + } + + m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")} + if err := m.ensureKnownHostsFile(); err == nil { + t.Fatal("expected stat error") + } +} + +func TestEnsureKnownHostsFileCreateError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + statFn = func(string) (os.FileInfo, error) { + return nil, os.ErrNotExist + } + openFileFn = func(string, int, os.FileMode) (*os.File, error) { + return nil, errors.New("open failed") + } + + m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")} + if err := m.ensureKnownHostsFile(); err == nil { + t.Fatal("expected create error") + } +} + +func TestAppendHostKeyOpenError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + appendOpenFileFn = func(string) (io.WriteCloser, error) { + return nil, errors.New("open failed") + } + + if err := appendHostKey("ignored", [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil { + t.Fatal("expected open error") + } +} + +func TestAppendHostKeyWriteError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + appendOpenFileFn = func(string) (io.WriteCloser, error) { + return errWriteCloser{err: errors.New("write failed")}, nil + } + + if err := appendHostKey("ignored", [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil { + t.Fatal("expected write error") + } +} + +func TestNormalizeHostEntryWithComment(t *testing.T) { + entry := []byte("example.com ssh-ed25519 AAAA comment here") + normalized, keyType, err := normalizeHostEntry("example.com", entry) + if err != nil { + t.Fatalf("normalizeHostEntry error: %v", err) + } + if keyType != "ssh-ed25519" { + t.Fatalf("expected key type ssh-ed25519, got %s", keyType) + } + if string(normalized) != "example.com ssh-ed25519 AAAA comment here" { + t.Fatalf("unexpected normalized entry: %s", string(normalized)) + } +} + +func TestFindHostKeyLineNotExists(t *testing.T) { + line, err := findHostKeyLine(filepath.Join(t.TempDir(), "missing"), "example.com", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if line != "" { + t.Fatalf("expected empty line, got %q", line) + } +} + +func TestFindHostKeyLineOpenError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + openFn = func(string) (*os.File, error) { + return nil, errors.New("open failed") + } + + if _, err := findHostKeyLine("ignored", "example.com", ""); err == nil { + t.Fatal("expected open error") + } +} + +func TestFindHostKeyLineScannerError(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + longLine := strings.Repeat("a", 70000) + if err := os.WriteFile(path, []byte(longLine+"\n"), 0600); err != nil { + t.Fatalf("failed to write file: %v", err) + } + + if _, err := findHostKeyLine(path, "example.com", ""); err == nil { + t.Fatal("expected scanner error") + } +} + +func TestHostLineMatchesSkips(t *testing.T) { + if 
hostLineMatches("example.com", "") { + t.Fatal("expected empty line to be false") + } + if hostLineMatches("example.com", "# comment") { + t.Fatal("expected comment line to be false") + } + if hostLineMatches("example.com", "|1|hash|salt ssh-ed25519 AAAA") { + t.Fatal("expected hashed entry to be false") + } +} + +func TestDefaultKeyscanSuccess(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) { + return []byte("example.com ssh-ed25519 AAAA"), nil + } + + out, err := defaultKeyscan(context.Background(), "example.com", 22, time.Second) + if err != nil { + t.Fatalf("defaultKeyscan error: %v", err) + } + if string(out) != "example.com ssh-ed25519 AAAA" { + t.Fatalf("unexpected output: %s", string(out)) + } +} + +func TestDefaultKeyscanError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) { + return []byte("boom"), errors.New("scan failed") + } + + if _, err := defaultKeyscan(context.Background(), "example.com", 22, time.Second); err == nil { + t.Fatal("expected error") + } +} + +func TestEnsureWithEntriesDefaultsPort(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + + mgr, err := NewManager(path) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithEntries(context.Background(), "example.com", 0, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err != nil { + t.Fatalf("EnsureWithEntries: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if got := strings.TrimSpace(string(data)); got != "example.com ssh-ed25519 AAAA" { + t.Fatalf("unexpected known_hosts contents: %s", got) + } +} + +func TestEnsureWithEntriesAppendError(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + appendOpenFileFn = func(string) (io.WriteCloser, error) { + return nil, errors.New("open failed") + } + + mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts")) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil { + t.Fatal("expected append error") + } +} + +func TestAppendHostKeySkipsEmptyEntry(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + buf := &bufferWriteCloser{} + appendOpenFileFn = func(string) (io.WriteCloser, error) { + return buf, nil + } + + if err := appendHostKey("ignored", [][]byte{nil, []byte("example.com ssh-ed25519 AAAA")}); err != nil { + t.Fatalf("appendHostKey error: %v", err) + } + if !strings.Contains(buf.String(), "example.com ssh-ed25519 AAAA") { + t.Fatalf("expected entry to be written, got %q", buf.String()) + } +} + +func TestFindHostKeyLineSkipsInvalidLines(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + contents := strings.Join([]string{ + "other.com ssh-ed25519 AAAA", + "example.com ssh-ed25519", + "example.com ssh-ed25519 AAAA", + }, "\n") + "\n" + if err := os.WriteFile(path, []byte(contents), 0600); err != nil { + t.Fatalf("failed to write file: %v", err) + } + + line, err := findHostKeyLine(path, "example.com", "ssh-ed25519") + if err != nil { + t.Fatalf("findHostKeyLine error: %v", err) + } + if line != "example.com ssh-ed25519 AAAA" { + t.Fatalf("unexpected line: %q", line) + } +} + +func TestDefaultKeyscanArgs(t *testing.T) { + t.Cleanup(resetKnownHostsFns) + + var gotArgs []string + keyscanCmdRunner = func(ctx context.Context, 
args ...string) ([]byte, error) { + gotArgs = append([]string{}, args...) + return []byte("ok"), nil + } + + if _, err := defaultKeyscan(context.Background(), "example.com", 0, 0); err != nil { + t.Fatalf("defaultKeyscan error: %v", err) + } + for _, arg := range gotArgs { + if arg == "-p" { + t.Fatal("did not expect -p for default port") + } + } + if len(gotArgs) < 3 || gotArgs[len(gotArgs)-1] != "example.com" { + t.Fatalf("unexpected args: %v", gotArgs) + } + + // The stub installed above stays in place and re-captures gotArgs, so no reassignment is needed for the custom-port case. + if _, err := defaultKeyscan(context.Background(), "example.com", 2222, time.Second); err != nil { + t.Fatalf("defaultKeyscan error: %v", err) + } + hasPort := false + for i := 0; i < len(gotArgs)-1; i++ { + if gotArgs[i] == "-p" && gotArgs[i+1] == "2222" { + hasPort = true + break + } + } + if !hasPort { + t.Fatalf("expected -p 2222 in args, got %v", gotArgs) + } +} + +func TestKeyscanCmdRunnerDefault(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("ssh-keyscan helper script requires sh") + } + t.Cleanup(resetKnownHostsFns) + + dir := t.TempDir() + scriptPath := filepath.Join(dir, "ssh-keyscan") + script := []byte("#!/bin/sh\necho example.com ssh-ed25519 AAAA\n") + if err := os.WriteFile(scriptPath, script, 0700); err != nil { + t.Fatalf("failed to write script: %v", err) + } + + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", dir+string(os.PathListSeparator)+oldPath); err != nil { + t.Fatalf("failed to set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + output, err := keyscanCmdRunner(context.Background(), "example.com") + if err != nil { + t.Fatalf("keyscanCmdRunner error: %v", err) + } + if strings.TrimSpace(string(output)) != "example.com ssh-ed25519 AAAA" { + t.Fatalf("unexpected output: %s", string(output)) + } +} + +type errWriteCloser struct { + err error +} + +func (e errWriteCloser) Write(p []byte) (int, error) { + return 0, e.err +} + +func (e errWriteCloser) Close() error { + return nil +} + +type bufferWriteCloser struct { + bytes.Buffer +} + +func (b *bufferWriteCloser) Close() error { + return nil +} diff --git a/internal/system/container.go b/internal/system/container.go index 96bd6813d..f34aa2e24 100644 --- a/internal/system/container.go +++ b/internal/system/container.go @@ -5,6 +5,13 @@ import ( "strings" ) +var ( + envGetFn = os.Getenv + statFn = os.Stat + readFileFn = os.ReadFile + hostnameFn = os.Hostname +) + var containerMarkers = []string{ "docker", "lxc", @@ -19,24 +26,24 @@ var containerMarkers = []string{ // InContainer reports whether Pulse is running inside a containerised environment. func InContainer() bool { // Allow operators to force container behaviour when automatic detection falls short. - if isTruthy(os.Getenv("PULSE_FORCE_CONTAINER")) { + if isTruthy(envGetFn("PULSE_FORCE_CONTAINER")) { return true } - if _, err := os.Stat("/.dockerenv"); err == nil { + if _, err := statFn("/.dockerenv"); err == nil { return true } - if _, err := os.Stat("/run/.containerenv"); err == nil { + if _, err := statFn("/run/.containerenv"); err == nil { return true } // Check common environment hints provided by systemd/nspawn, LXC, etc.
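+ // e.g. systemd-nspawn sets container=systemd-nspawn and LXC sets container=lxc; the literal value "host" marks a non-container environment.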
- if val := strings.ToLower(strings.TrimSpace(os.Getenv("container"))); val != "" && val != "host" { + if val := strings.ToLower(strings.TrimSpace(envGetFn("container"))); val != "" && val != "host" { return true } // Some distros expose the container hint through PID 1's environment. - if data, err := os.ReadFile("/proc/1/environ"); err == nil { + if data, err := readFileFn("/proc/1/environ"); err == nil { lower := strings.ToLower(string(data)) if strings.Contains(lower, "container=") && !strings.Contains(lower, "container=host") { return true @@ -44,7 +51,7 @@ func InContainer() bool { } // Fall back to cgroup inspection which covers older Docker/LXC setups. - if data, err := os.ReadFile("/proc/1/cgroup"); err == nil { + if data, err := readFileFn("/proc/1/cgroup"); err == nil { content := strings.ToLower(string(data)) for _, marker := range containerMarkers { if strings.Contains(content, marker) { @@ -60,7 +67,7 @@ func InContainer() bool { // Returns empty string if not in Docker or name cannot be determined. func DetectDockerContainerName() string { // Method 1: Check hostname (Docker uses container ID or name as hostname) - if hostname, err := os.Hostname(); err == nil && hostname != "" { + if hostname, err := hostnameFn(); err == nil && hostname != "" { // Docker hostnames are either short container ID (12 chars) or custom name // If it looks like a container ID (hex), skip it - user needs to use name if !isHexString(hostname) || len(hostname) > 12 { @@ -69,7 +76,7 @@ func DetectDockerContainerName() string { } // Method 2: Try reading from /proc/self/cgroup - if data, err := os.ReadFile("/proc/self/cgroup"); err == nil { + if data, err := readFileFn("/proc/self/cgroup"); err == nil { // Look for patterns like: 0::/docker/ // But we can't get name from cgroup, only ID _ = data // placeholder for future enhancement @@ -91,7 +98,7 @@ func isHexString(s string) bool { // Returns empty string if not in an LXC container or CTID cannot be determined. 
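+// Detection checks /proc/1/cgroup markers first (lxc/<ctid>, lxc.payload.<ctid>, machine-lxc-<ctid>) and falls back to a purely numeric hostname.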
func DetectLXCCTID() string { // Method 1: Parse /proc/1/cgroup for LXC container ID - if data, err := os.ReadFile("/proc/1/cgroup"); err == nil { + if data, err := readFileFn("/proc/1/cgroup"); err == nil { lines := strings.Split(string(data), "\n") for _, line := range lines { // Look for patterns like: 0::/lxc/123 or 0::/lxc.payload.123 @@ -119,7 +126,7 @@ func DetectLXCCTID() string { } // Method 2: Check hostname (some LXC containers use CTID as hostname) - if hostname, err := os.Hostname(); err == nil && isNumeric(hostname) { + if hostname, err := hostnameFn(); err == nil && isNumeric(hostname) { return hostname } diff --git a/internal/system/container_test.go b/internal/system/container_test.go index f7e160a22..6aa28f060 100644 --- a/internal/system/container_test.go +++ b/internal/system/container_test.go @@ -1,9 +1,18 @@ package system import ( + "errors" + "os" "testing" ) +func resetSystemFns() { + envGetFn = os.Getenv + statFn = os.Stat + readFileFn = os.ReadFile + hostnameFn = os.Hostname +} + func TestIsHexString(t *testing.T) { tests := []struct { input string @@ -133,20 +142,226 @@ func TestContainerMarkers(t *testing.T) { } func TestInContainer(t *testing.T) { - // This is an integration test that depends on the runtime environment - // We can't mock the file system easily, but we can verify it doesn't panic - result := InContainer() - t.Logf("InContainer() = %v (depends on test environment)", result) + t.Run("Forced", func(t *testing.T) { + t.Cleanup(resetSystemFns) + envGetFn = func(key string) string { + if key == "PULSE_FORCE_CONTAINER" { + return "true" + } + return "" + } + + if !InContainer() { + t.Fatal("expected forced container") + } + }) + + t.Run("DockerEnvFile", func(t *testing.T) { + t.Cleanup(resetSystemFns) + envGetFn = func(string) string { return "" } + statFn = func(path string) (os.FileInfo, error) { + if path == "/.dockerenv" { + return nil, nil + } + return nil, errors.New("missing") + } + readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") } + + if !InContainer() { + t.Fatal("expected docker env file") + } + }) + + t.Run("ContainerEnvFile", func(t *testing.T) { + t.Cleanup(resetSystemFns) + envGetFn = func(string) string { return "" } + statFn = func(path string) (os.FileInfo, error) { + if path == "/run/.containerenv" { + return nil, nil + } + return nil, errors.New("missing") + } + readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") } + + if !InContainer() { + t.Fatal("expected container env file") + } + }) + + t.Run("ContainerEnvVar", func(t *testing.T) { + t.Cleanup(resetSystemFns) + envGetFn = func(key string) string { + if key == "container" { + return "lxc" + } + return "" + } + statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") } + readFileFn = func(string) ([]byte, error) { return nil, errors.New("missing") } + + if !InContainer() { + t.Fatal("expected container env var") + } + }) + + t.Run("ProcEnviron", func(t *testing.T) { + t.Cleanup(resetSystemFns) + envGetFn = func(string) string { return "" } + statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") } + readFileFn = func(path string) ([]byte, error) { + if path == "/proc/1/environ" { + return []byte("container=lxc\x00"), nil + } + return nil, errors.New("missing") + } + + if !InContainer() { + t.Fatal("expected container environ") + } + }) + + t.Run("CgroupMarker", func(t *testing.T) { + t.Cleanup(resetSystemFns) + envGetFn = func(string) string { return "" } + statFn = func(string) 
(os.FileInfo, error) { return nil, errors.New("missing") } + readFileFn = func(path string) ([]byte, error) { + if path == "/proc/1/environ" { + return []byte("container=host"), nil + } + if path == "/proc/1/cgroup" { + return []byte("0::/docker/abc"), nil + } + return nil, errors.New("missing") + } + + if !InContainer() { + t.Fatal("expected container cgroup marker") + } + }) + + t.Run("NotContainer", func(t *testing.T) { + t.Cleanup(resetSystemFns) + envGetFn = func(string) string { return "" } + statFn = func(string) (os.FileInfo, error) { return nil, errors.New("missing") } + readFileFn = func(path string) ([]byte, error) { + if path == "/proc/1/environ" { + return []byte("container=host"), nil + } + if path == "/proc/1/cgroup" { + return []byte("0::/user.slice"), nil + } + return nil, errors.New("missing") + } + + if InContainer() { + t.Fatal("expected non-container") + } + }) } func TestDetectDockerContainerName(t *testing.T) { - // This is an integration test that depends on the runtime environment - result := DetectDockerContainerName() - t.Logf("DetectDockerContainerName() = %q (depends on test environment)", result) + t.Run("HostnameName", func(t *testing.T) { + t.Cleanup(resetSystemFns) + hostnameFn = func() (string, error) { return "my-container", nil } + + if got := DetectDockerContainerName(); got != "my-container" { + t.Fatalf("expected hostname name, got %q", got) + } + }) + + t.Run("HostnameHexShort", func(t *testing.T) { + t.Cleanup(resetSystemFns) + hostnameFn = func() (string, error) { return "abcdef123456", nil } + readFileFn = func(string) ([]byte, error) { return []byte("0::/docker/abcdef"), nil } + + if got := DetectDockerContainerName(); got != "" { + t.Fatalf("expected empty name, got %q", got) + } + }) + + t.Run("HostnameHexLong", func(t *testing.T) { + t.Cleanup(resetSystemFns) + hostnameFn = func() (string, error) { return "abcdef1234567890abcdef1234567890", nil } + + if got := DetectDockerContainerName(); got == "" { + t.Fatal("expected hostname for long hex") + } + }) + + t.Run("HostnameError", func(t *testing.T) { + t.Cleanup(resetSystemFns) + hostnameFn = func() (string, error) { return "", errors.New("fail") } + readFileFn = func(string) ([]byte, error) { return []byte("0::/docker/abcdef"), nil } + + if got := DetectDockerContainerName(); got != "" { + t.Fatalf("expected empty name, got %q", got) + } + }) } func TestDetectLXCCTID(t *testing.T) { - // This is an integration test that depends on the runtime environment - result := DetectLXCCTID() - t.Logf("DetectLXCCTID() = %q (depends on test environment)", result) + t.Run("FromCgroupLXC", func(t *testing.T) { + t.Cleanup(resetSystemFns) + readFileFn = func(path string) ([]byte, error) { + if path == "/proc/1/cgroup" { + return []byte("0::/lxc/123"), nil + } + return nil, errors.New("missing") + } + hostnameFn = func() (string, error) { return "999", nil } + + if got := DetectLXCCTID(); got != "123" { + t.Fatalf("expected CTID 123, got %q", got) + } + }) + + t.Run("FromCgroupPayload", func(t *testing.T) { + t.Cleanup(resetSystemFns) + readFileFn = func(path string) ([]byte, error) { + if path == "/proc/1/cgroup" { + return []byte("0::/lxc.payload.456"), nil + } + return nil, errors.New("missing") + } + hostnameFn = func() (string, error) { return "999", nil } + + if got := DetectLXCCTID(); got != "456" { + t.Fatalf("expected CTID 456, got %q", got) + } + }) + + t.Run("FromMachineLXC", func(t *testing.T) { + t.Cleanup(resetSystemFns) + readFileFn = func(path string) ([]byte, error) { + if path == 
"/proc/1/cgroup" { + return []byte("0::/machine.slice/machine-lxc-789"), nil + } + return nil, errors.New("missing") + } + hostnameFn = func() (string, error) { return "999", nil } + + if got := DetectLXCCTID(); got != "789" { + t.Fatalf("expected CTID 789, got %q", got) + } + }) + + t.Run("FromHostname", func(t *testing.T) { + t.Cleanup(resetSystemFns) + readFileFn = func(string) ([]byte, error) { return []byte("0::/user.slice"), nil } + hostnameFn = func() (string, error) { return "321", nil } + + if got := DetectLXCCTID(); got != "321" { + t.Fatalf("expected CTID 321, got %q", got) + } + }) + + t.Run("NoMatch", func(t *testing.T) { + t.Cleanup(resetSystemFns) + readFileFn = func(string) ([]byte, error) { return []byte("0::/user.slice"), nil } + hostnameFn = func() (string, error) { return "host", nil } + + if got := DetectLXCCTID(); got != "" { + t.Fatalf("expected empty CTID, got %q", got) + } + }) } diff --git a/internal/tempproxy/client.go b/internal/tempproxy/client.go index 5a17e0271..9da0d7618 100644 --- a/internal/tempproxy/client.go +++ b/internal/tempproxy/client.go @@ -24,6 +24,13 @@ const ( maxRetries = 3 ) +var statFn = os.Stat + +var dialContextFn = func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) { + dialer := net.Dialer{Timeout: timeout} + return dialer.DialContext(ctx, network, address) +} + // ErrorType classifies proxy errors for better error handling type ErrorType int @@ -66,9 +73,9 @@ type Client struct { func NewClient() *Client { socketPath := os.Getenv("PULSE_SENSOR_PROXY_SOCKET") if socketPath == "" { - if _, err := os.Stat(defaultSocketPath); err == nil { + if _, err := statFn(defaultSocketPath); err == nil { socketPath = defaultSocketPath - } else if _, err := os.Stat(containerSocketPath); err == nil { + } else if _, err := statFn(containerSocketPath); err == nil { socketPath = containerSocketPath } else { socketPath = defaultSocketPath @@ -83,7 +90,7 @@ func NewClient() *Client { // IsAvailable checks if the proxy is running and accessible func (c *Client) IsAvailable() bool { - _, err := os.Stat(c.socketPath) + _, err := statFn(c.socketPath) return err == nil } @@ -312,13 +319,8 @@ func (c *Client) callWithContext(ctx context.Context, method string, params map[ // callOnce sends a single RPC request without retries func (c *Client) callOnce(ctx context.Context, method string, params map[string]interface{}) (*RPCResponse, error) { - // Create a dialer with context - dialer := net.Dialer{ - Timeout: c.timeout, - } - // Connect to unix socket with context - conn, err := dialer.DialContext(ctx, "unix", c.socketPath) + conn, err := dialContextFn(ctx, "unix", c.socketPath, c.timeout) if err != nil { return nil, fmt.Errorf("failed to connect to proxy: %w", err) } @@ -366,10 +368,6 @@ func (c *Client) GetStatus() (map[string]interface{}, error) { return nil, err } - if !resp.Success { - return nil, fmt.Errorf("proxy error: %s", resp.Error) - } - return resp.Data, nil } @@ -380,13 +378,6 @@ func (c *Client) RegisterNodes() ([]map[string]interface{}, error) { return nil, err } - if !resp.Success { - if proxyErr := classifyError(nil, resp.Error); proxyErr != nil { - return nil, proxyErr - } - return nil, fmt.Errorf("proxy error: %s", resp.Error) - } - // Extract nodes array from data nodesRaw, ok := resp.Data["nodes"] if !ok { @@ -422,18 +413,6 @@ func (c *Client) GetTemperature(nodeHost string) (string, error) { return "", err } - if !resp.Success { - if proxyErr := classifyError(nil, resp.Error); proxyErr != nil { - return 
"", proxyErr - } - return "", &ProxyError{ - Type: ErrorTypeUnknown, - Message: resp.Error, - Retryable: false, - Wrapped: fmt.Errorf("%s", resp.Error), - } - } - // Extract temperature JSON string tempRaw, ok := resp.Data["temperature"] if !ok { @@ -455,17 +434,9 @@ func (c *Client) RequestCleanup(host string) error { params["host"] = host } - resp, err := c.call("request_cleanup", params) - if err != nil { + if _, err := c.call("request_cleanup", params); err != nil { return err } - if !resp.Success { - if resp.Error != "" { - return fmt.Errorf("proxy error: %s", resp.Error) - } - return fmt.Errorf("proxy rejected cleanup request") - } - return nil } diff --git a/internal/tempproxy/client_coverage_test.go b/internal/tempproxy/client_coverage_test.go new file mode 100644 index 000000000..fc27437fd --- /dev/null +++ b/internal/tempproxy/client_coverage_test.go @@ -0,0 +1,488 @@ +package tempproxy + +import ( + "context" + "encoding/json" + "errors" + "net" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +type writeErrConn struct { + net.Conn +} + +func (c writeErrConn) Write(p []byte) (int, error) { + return 0, errors.New("write failed") +} + +func startUnixServer(t *testing.T, handler func(net.Conn)) (string, func()) { + t.Helper() + + dir := t.TempDir() + socketPath := filepath.Join(dir, "proxy.sock") + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("listen: %v", err) + } + + done := make(chan struct{}) + go func() { + defer close(done) + for { + conn, err := ln.Accept() + if err != nil { + return + } + go handler(conn) + } + }() + + cleanup := func() { + ln.Close() + <-done + } + return socketPath, cleanup +} + +func startRPCServer(t *testing.T, handler func(RPCRequest) RPCResponse) (string, func()) { + t.Helper() + return startUnixServer(t, func(conn net.Conn) { + defer conn.Close() + decoder := json.NewDecoder(conn) + var req RPCRequest + if err := decoder.Decode(&req); err != nil { + return + } + resp := handler(req) + encoder := json.NewEncoder(conn) + _ = encoder.Encode(resp) + }) +} + +func TestNewClientSocketSelection(t *testing.T) { + origStat := statFn + t.Cleanup(func() { statFn = origStat }) + + t.Run("EnvOverride", func(t *testing.T) { + origEnv := os.Getenv("PULSE_SENSOR_PROXY_SOCKET") + t.Cleanup(func() { + if origEnv == "" { + os.Unsetenv("PULSE_SENSOR_PROXY_SOCKET") + } else { + os.Setenv("PULSE_SENSOR_PROXY_SOCKET", origEnv) + } + }) + os.Setenv("PULSE_SENSOR_PROXY_SOCKET", "/tmp/custom.sock") + statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist } + + client := NewClient() + if client.socketPath != "/tmp/custom.sock" { + t.Fatalf("expected env socket, got %q", client.socketPath) + } + }) + + t.Run("DefaultExists", func(t *testing.T) { + statFn = func(path string) (os.FileInfo, error) { + if path == defaultSocketPath { + return nil, nil + } + return nil, os.ErrNotExist + } + + client := NewClient() + if client.socketPath != defaultSocketPath { + t.Fatalf("expected default socket, got %q", client.socketPath) + } + }) + + t.Run("ContainerExists", func(t *testing.T) { + statFn = func(path string) (os.FileInfo, error) { + if path == containerSocketPath { + return nil, nil + } + return nil, os.ErrNotExist + } + + client := NewClient() + if client.socketPath != containerSocketPath { + t.Fatalf("expected container socket, got %q", client.socketPath) + } + }) + + t.Run("FallbackDefault", func(t *testing.T) { + statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist } + client := NewClient() + if 
client.socketPath != defaultSocketPath { + t.Fatalf("expected fallback socket, got %q", client.socketPath) + } + }) +} + +func TestClientIsAvailable(t *testing.T) { + origStat := statFn + t.Cleanup(func() { statFn = origStat }) + + client := &Client{socketPath: "/tmp/socket"} + statFn = func(string) (os.FileInfo, error) { return nil, nil } + if !client.IsAvailable() { + t.Fatalf("expected available") + } + + statFn = func(string) (os.FileInfo, error) { return nil, os.ErrNotExist } + if client.IsAvailable() { + t.Fatalf("expected unavailable") + } +} + +func TestCallOnceSuccessAndDeadline(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + if req.Method != "ping" { + t.Fatalf("unexpected method: %s", req.Method) + } + return RPCResponse{Success: true} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + + resp, err := client.callOnce(context.Background(), "ping", nil) + if err != nil || resp == nil || !resp.Success { + t.Fatalf("expected success: resp=%v err=%v", resp, err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + resp, err = client.callOnce(ctx, "ping", nil) + if err != nil || resp == nil || !resp.Success { + t.Fatalf("expected success with deadline: resp=%v err=%v", resp, err) + } +} + +func TestCallOnceErrors(t *testing.T) { + t.Run("ConnectError", func(t *testing.T) { + client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond} + if _, err := client.callOnce(context.Background(), "ping", nil); err == nil { + t.Fatalf("expected connect error") + } + }) + + t.Run("EncodeError", func(t *testing.T) { + socketPath, cleanup := startUnixServer(t, func(conn net.Conn) { + conn.Close() + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.callOnce(context.Background(), "ping", nil); err == nil { + t.Fatalf("expected encode error") + } + }) + + t.Run("EncodeErrorDial", func(t *testing.T) { + origDial := dialContextFn + t.Cleanup(func() { dialContextFn = origDial }) + + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { _ = serverConn.Close() }) + + dialContextFn = func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) { + return writeErrConn{Conn: clientConn}, nil + } + + client := &Client{socketPath: "/tmp/unused.sock", timeout: 50 * time.Millisecond} + if _, err := client.callOnce(context.Background(), "ping", nil); err == nil { + t.Fatalf("expected encode error") + } + }) + + t.Run("DecodeError", func(t *testing.T) { + socketPath, cleanup := startUnixServer(t, func(conn net.Conn) { + defer conn.Close() + _, _ = conn.Write([]byte("{")) + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.callOnce(context.Background(), "ping", nil); err == nil { + t.Fatalf("expected decode error") + } + }) +} + +func TestCallWithContextBehavior(t *testing.T) { + t.Run("Success", func(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + resp, err := client.callWithContext(context.Background(), "ping", nil) + if err != nil || resp == nil || !resp.Success { + t.Fatalf("expected success: resp=%v err=%v", resp, err) + } + }) + + t.Run("NonRetryable", func(t *testing.T) { 
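+ // An "unauthorized" response is expected to classify as non-retryable, so the call should fail fast instead of exhausting retries.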
+ socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: "unauthorized"} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil { + t.Fatalf("expected auth error") + } + }) + + t.Run("CancelledBeforeRetry", func(t *testing.T) { + client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + if _, err := client.callWithContext(ctx, "ping", nil); err == nil { + t.Fatalf("expected cancelled error") + } + }) + + t.Run("CancelledDuringBackoff", func(t *testing.T) { + client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond} + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + if _, err := client.callWithContext(ctx, "ping", nil); err == nil { + t.Fatalf("expected backoff cancel error") + } + }) + + t.Run("MaxRetriesExhausted", func(t *testing.T) { + client := &Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond} + if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil { + t.Fatalf("expected retry error") + } + }) +} + +func TestGetStatus(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true, Data: map[string]interface{}{"ok": true}} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + data, err := client.GetStatus() + if err != nil || data["ok"] != true { + t.Fatalf("unexpected status: %v err=%v", data, err) + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: "boom"} + }) + t.Cleanup(cleanup) + + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.GetStatus(); err == nil { + t.Fatalf("expected status error") + } +} + +func TestRegisterNodes(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true, Data: map[string]interface{}{ + "nodes": []interface{}{map[string]interface{}{"name": "node1"}}, + }} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + nodes, err := client.RegisterNodes() + if err != nil || len(nodes) != 1 || nodes[0]["name"] != "node1" { + t.Fatalf("unexpected nodes: %v err=%v", nodes, err) + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: "unauthorized"} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.RegisterNodes(); err == nil { + t.Fatalf("expected proxy error") + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true, Data: map[string]interface{}{}} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.RegisterNodes(); err == nil { + t.Fatalf("expected missing nodes error") + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true, Data: map[string]interface{}{"nodes": "bad"}} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * 
time.Millisecond} + if _, err := client.RegisterNodes(); err == nil { + t.Fatalf("expected nodes type error") + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true, Data: map[string]interface{}{"nodes": []interface{}{"bad"}}} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.RegisterNodes(); err == nil { + t.Fatalf("expected node map error") + } +} + +func TestGetTemperature(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true, Data: map[string]interface{}{"temperature": "42"}} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + temp, err := client.GetTemperature("node1") + if err != nil || temp != "42" { + t.Fatalf("unexpected temp: %q err=%v", temp, err) + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: "unauthorized"} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.GetTemperature("node1"); err == nil { + t.Fatalf("expected proxy error") + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: "boom"} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.GetTemperature("node1"); err == nil { + t.Fatalf("expected unknown error") + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true, Data: map[string]interface{}{}} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.GetTemperature("node1"); err == nil { + t.Fatalf("expected missing temperature error") + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true, Data: map[string]interface{}{"temperature": 12}} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.GetTemperature("node1"); err == nil { + t.Fatalf("expected type error") + } +} + +func TestRequestCleanup(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + if req.Method != "request_cleanup" { + t.Fatalf("unexpected method: %s", req.Method) + } + if req.Params["host"] != "node1" { + t.Fatalf("unexpected params: %v", req.Params) + } + return RPCResponse{Success: true} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if err := client.RequestCleanup("node1"); err != nil { + t.Fatalf("unexpected cleanup error: %v", err) + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: "boom"} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if err := client.RequestCleanup(""); err == nil { + t.Fatalf("expected proxy error") + } + + socketPath, cleanup = startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: ""} + }) + t.Cleanup(cleanup) + client = &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if err := client.RequestCleanup(""); err == nil { + t.Fatalf("expected rejected error") + } + + client = 
&Client{socketPath: "/tmp/missing.sock", timeout: 10 * time.Millisecond} + if err := client.RequestCleanup(""); err == nil { + t.Fatalf("expected call error") + } +} + +func TestCallWrapper(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: true} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + resp, err := client.call("ping", nil) + if err != nil || resp == nil || !resp.Success { + t.Fatalf("expected call success: resp=%v err=%v", resp, err) + } +} + +func TestCallWithContextRetryableResponseError(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: "ssh timeout"} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil { + t.Fatalf("expected retry exhaustion") + } +} + +func TestCallWithContextSuccessOnRetry(t *testing.T) { + count := 0 + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + count++ + if count == 1 { + return RPCResponse{Success: false, Error: "ssh timeout"} + } + return RPCResponse{Success: true} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + resp, err := client.callWithContext(context.Background(), "ping", nil) + if err != nil || resp == nil || !resp.Success { + t.Fatalf("expected retry success: resp=%v err=%v", resp, err) + } +} + +func TestCallWithContextRespErrorString(t *testing.T) { + socketPath, cleanup := startRPCServer(t, func(req RPCRequest) RPCResponse { + return RPCResponse{Success: false, Error: strings.Repeat("x", 1)} + }) + t.Cleanup(cleanup) + + client := &Client{socketPath: socketPath, timeout: 50 * time.Millisecond} + if _, err := client.callWithContext(context.Background(), "ping", nil); err == nil { + t.Fatalf("expected error") + } +} diff --git a/internal/tempproxy/client_test.go b/internal/tempproxy/client_test.go index 54c0c0bc0..5191c3050 100644 --- a/internal/tempproxy/client_test.go +++ b/internal/tempproxy/client_test.go @@ -200,6 +200,12 @@ func TestContains(t *testing.T) { substrs: []string{"ssh error"}, expected: true, }, + { + name: "case insensitive upper substr", + s: "connection refused", + substrs: []string{"REFUSED"}, + expected: true, + }, { name: "empty string", s: "", diff --git a/internal/tempproxy/http_client.go b/internal/tempproxy/http_client.go index a56ffd268..a1d044e91 100644 --- a/internal/tempproxy/http_client.go +++ b/internal/tempproxy/http_client.go @@ -80,7 +80,11 @@ func (c *HTTPClient) GetTemperature(nodeHost string) (string, error) { req.Header.Set("Accept", "application/json") // Execute request with retries - var lastErr error + var lastErr error = &ProxyError{ + Type: ErrorTypeUnknown, + Message: "all retry attempts failed", + Retryable: false, + } for attempt := 0; attempt < maxRetries; attempt++ { if attempt > 0 { backoff := calculateBackoff(attempt) @@ -167,15 +171,7 @@ func (c *HTTPClient) GetTemperature(nodeHost string) (string, error) { } // All retries exhausted - if lastErr != nil { - return "", lastErr - } - - return "", &ProxyError{ - Type: ErrorTypeUnknown, - Message: "all retry attempts failed", - Retryable: false, - } + return "", lastErr } // HealthCheck calls the proxy /health endpoint to verify connectivity. 
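Aside: the GetTemperature change above seeds lastErr with a generic ProxyError before the retry loop, so the single post-loop return covers both "every attempt failed" and "no attempt ever recorded a cause". A minimal, self-contained sketch of that seeded-error retry shape, assuming nothing beyond the standard library (retry, doAttempt, and maxAttempts are illustrative names, not from this codebase, and the sleep stands in for the real backoff calculation):

package main

import (
	"errors"
	"fmt"
	"time"
)

// retry seeds err up front so the fallback message survives even if the loop
// body never assigns a more specific error (e.g. maxAttempts <= 0).
func retry(maxAttempts int, doAttempt func() error) error {
	err := errors.New("all retry attempts failed") // seeded fallback
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			time.Sleep(time.Duration(attempt) * 10 * time.Millisecond) // stand-in backoff
		}
		if err = doAttempt(); err == nil {
			return nil
		}
	}
	return err // single exit: either the seed or the last attempt's error
}

func main() {
	calls := 0
	err := retry(3, func() error {
		calls++
		if calls < 3 {
			return fmt.Errorf("attempt %d failed", calls)
		}
		return nil
	})
	fmt.Printf("err=%v calls=%d\n", err, calls) // err=<nil> calls=3
}

The seed only matters on paths that reach the final return without assigning anything, which is exactly the case the deleted post-loop branch used to handle; pre-seeding trades that branch for constructing the fallback error eagerly.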
diff --git a/internal/tempproxy/http_client_test.go b/internal/tempproxy/http_client_test.go index afe55c152..6bfe3eed5 100644 --- a/internal/tempproxy/http_client_test.go +++ b/internal/tempproxy/http_client_test.go @@ -2,6 +2,8 @@ package tempproxy import ( "encoding/json" + "errors" + "io" "net/http" "net/http/httptest" "strings" @@ -360,6 +362,55 @@ func TestHTTPClient_GetTemperature_RetryOnTransportError(t *testing.T) { } } +func TestHTTPClient_GetTemperature_RequestBuildError(t *testing.T) { + client := NewHTTPClient("http://[::1", "token") + _, err := client.GetTemperature("node1") + + if err == nil { + t.Fatal("Expected error for request creation") + } + + proxyErr, ok := err.(*ProxyError) + if !ok { + t.Fatalf("Expected *ProxyError, got %T", err) + } + if proxyErr.Type != ErrorTypeTransport { + t.Errorf("Type = %v, want ErrorTypeTransport", proxyErr.Type) + } + if !strings.Contains(proxyErr.Message, "failed to create HTTP request") { + t.Errorf("Message = %q, want request creation failure", proxyErr.Message) + } +} + +func TestHTTPClient_GetTemperature_ReadBodyError(t *testing.T) { + attempts := 0 + client := NewHTTPClient("https://example.com", "token") + client.httpClient.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + if attempts == 1 { + return &http.Response{ + StatusCode: http.StatusOK, + Body: errReadCloser{}, + Header: make(http.Header), + }, nil + } + body := `{"node":"node1","temperature":"{}"}` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + }) + + _, err := client.GetTemperature("node1") + if err != nil { + t.Fatalf("Expected success after retry, got %v", err) + } + if attempts < 2 { + t.Fatalf("Expected retry after read error, got %d attempts", attempts) + } +} + func TestHTTPClient_HealthCheck_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/health" { @@ -398,6 +449,26 @@ func TestHTTPClient_HealthCheck_NotConfigured(t *testing.T) { } } +func TestHTTPClient_HealthCheck_RequestBuildError(t *testing.T) { + client := NewHTTPClient("http://[::1", "token") + err := client.HealthCheck() + + if err == nil { + t.Fatal("Expected error for request creation") + } + + proxyErr, ok := err.(*ProxyError) + if !ok { + t.Fatalf("Expected *ProxyError, got %T", err) + } + if proxyErr.Type != ErrorTypeTransport { + t.Errorf("Type = %v, want ErrorTypeTransport", proxyErr.Type) + } + if !strings.Contains(proxyErr.Message, "failed to create HTTP request") { + t.Errorf("Message = %q, want request creation failure", proxyErr.Message) + } +} + func TestHTTPClient_HealthCheck_ServerError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) @@ -493,3 +564,19 @@ func TestHTTPClient_Fields(t *testing.T) { t.Errorf("timeout = %v, want 60s", client.timeout) } } + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +type errReadCloser struct{} + +func (errReadCloser) Read(p []byte) (int, error) { + return 0, errors.New("read failed") +} + +func (errReadCloser) Close() error { + return nil +} diff --git a/internal/updatedetection/manager_coverage_test.go b/internal/updatedetection/manager_coverage_test.go new file mode 100644 index 000000000..ea0fe6039 --- /dev/null 
diff --git a/internal/updatedetection/manager_coverage_test.go b/internal/updatedetection/manager_coverage_test.go
new file mode 100644
index 000000000..ea0fe6039
--- /dev/null
+++ b/internal/updatedetection/manager_coverage_test.go
@@ -0,0 +1,326 @@
+package updatedetection
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/rs/zerolog"
+)
+
+func TestDefaultManagerConfig(t *testing.T) {
+	cfg := DefaultManagerConfig()
+	if !cfg.Enabled {
+		t.Fatalf("expected enabled by default")
+	}
+	if cfg.CheckInterval != 6*time.Hour {
+		t.Fatalf("expected check interval 6h, got %v", cfg.CheckInterval)
+	}
+	if cfg.AlertDelayHours != 24 {
+		t.Fatalf("expected alert delay 24, got %d", cfg.AlertDelayHours)
+	}
+	if !cfg.EnableDockerUpdates {
+		t.Fatalf("expected docker updates enabled by default")
+	}
+}
+
+func TestManagerConfigAccessors(t *testing.T) {
+	cfg := ManagerConfig{
+		Enabled:             false,
+		CheckInterval:       time.Minute,
+		AlertDelayHours:     12,
+		EnableDockerUpdates: false,
+	}
+	mgr := NewManager(cfg, zerolog.Nop())
+	if mgr.Enabled() {
+		t.Fatalf("expected disabled manager")
+	}
+	mgr.SetEnabled(true)
+	if !mgr.Enabled() {
+		t.Fatalf("expected enabled manager after SetEnabled")
+	}
+	if mgr.AlertDelayHours() != 12 {
+		t.Fatalf("expected alert delay 12, got %d", mgr.AlertDelayHours())
+	}
+}
+
+func TestManagerProcessDockerContainerUpdate(t *testing.T) {
+	now := time.Now()
+	status := &ContainerUpdateStatus{
+		UpdateAvailable: true,
+		CurrentDigest:   "sha256:old",
+		LatestDigest:    "sha256:new",
+		LastChecked:     now,
+		Error:           "boom",
+	}
+
+	t.Run("Disabled", func(t *testing.T) {
+		cfg := DefaultManagerConfig()
+		cfg.Enabled = false
+		mgr := NewManager(cfg, zerolog.Nop())
+		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
+		if mgr.store.Count() != 0 {
+			t.Fatalf("expected no updates when disabled")
+		}
+	})
+
+	t.Run("DockerUpdatesDisabled", func(t *testing.T) {
+		cfg := DefaultManagerConfig()
+		cfg.EnableDockerUpdates = false
+		mgr := NewManager(cfg, zerolog.Nop())
+		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
+		if mgr.store.Count() != 0 {
+			t.Fatalf("expected no updates when docker updates disabled")
+		}
+	})
+
+	t.Run("NilStatus", func(t *testing.T) {
+		mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
+		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", nil)
+		if mgr.store.Count() != 0 {
+			t.Fatalf("expected no updates for nil status")
+		}
+	})
+
+	t.Run("NoUpdateDeletes", func(t *testing.T) {
+		mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
+		mgr.store.UpsertUpdate(&UpdateInfo{
+			ID:         "docker:host-1:container-1",
+			ResourceID: "container-1",
+			HostID:     "host-1",
+		})
+		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", &ContainerUpdateStatus{
+			UpdateAvailable: false,
+			LastChecked:     now,
+		})
+		if mgr.store.Count() != 0 {
+			t.Fatalf("expected update to be deleted when no update available")
+		}
+	})
+
+	t.Run("UpdateAvailable", func(t *testing.T) {
+		mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
+		mgr.ProcessDockerContainerUpdate("host-1", "container-1", "nginx", "nginx:latest", "sha256:old", status)
+		update := mgr.store.GetUpdatesForResource("container-1")
+		if update == nil {
+			t.Fatalf("expected update to be stored")
+		}
+		if update.ID != "docker:host-1:container-1" {
+			t.Fatalf("unexpected update ID %q", update.ID)
+		}
+		if update.Error != "boom" {
+			t.Fatalf("expected error to be stored")
+		}
+		if update.CurrentVersion != "nginx:latest" {
+			t.Fatalf("expected current version to be stored")
+		}
+	})
+}
+
+func TestManagerCheckImageUpdate(t *testing.T) {
+	ctx := context.Background()
+
+	t.Run("Disabled", func(t *testing.T) {
+		cfg := DefaultManagerConfig()
+		cfg.Enabled = false
+		mgr := NewManager(cfg, zerolog.Nop())
+		info, err := mgr.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		if info != nil {
+			t.Fatalf("expected nil info when disabled")
+		}
+	})
+
+	t.Run("Cached", func(t *testing.T) {
+		mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
+		cacheKey := "registry-1.docker.io/library/nginx:latest"
+		mgr.registry.cacheDigest(cacheKey, "sha256:new")
+
+		info, err := mgr.CheckImageUpdate(ctx, "nginx:latest", "sha256:old")
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		if info == nil || !info.UpdateAvailable {
+			t.Fatalf("expected update available from cache")
+		}
+	})
+}
+
+func TestManagerGetUpdatesAndAccessors(t *testing.T) {
+	mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
+	now := time.Now()
+
+	update1 := &UpdateInfo{
+		ID:           "update-1",
+		ResourceID:   "res-1",
+		ResourceType: "docker",
+		ResourceName: "nginx",
+		HostID:       "host-1",
+		Type:         UpdateTypeDockerImage,
+		Severity:     SeveritySecurity,
+		LastChecked:  now,
+	}
+	update2 := &UpdateInfo{
+		ID:           "update-2",
+		ResourceID:   "res-2",
+		ResourceType: "vm",
+		ResourceName: "vm-1",
+		HostID:       "host-2",
+		Type:         UpdateTypePackage,
+		LastChecked:  now.Add(-time.Minute),
+	}
+	mgr.store.UpsertUpdate(update1)
+	mgr.store.UpsertUpdate(update2)
+
+	if mgr.GetTotalCount() != 2 {
+		t.Fatalf("expected total count 2, got %d", mgr.GetTotalCount())
+	}
+
+	all := mgr.GetUpdates(UpdateFilters{})
+	if len(all) != 2 {
+		t.Fatalf("expected 2 updates, got %d", len(all))
+	}
+
+	filtered := mgr.GetUpdates(UpdateFilters{HostID: "host-1"})
+	if len(filtered) != 1 || filtered[0].ID != "update-1" {
+		t.Fatalf("expected one update for host-1")
+	}
+
+	if len(mgr.GetUpdatesForHost("host-1")) != 1 {
+		t.Fatalf("expected 1 update for host-1")
+	}
+	if mgr.GetUpdatesForResource("res-2") == nil {
+		t.Fatalf("expected update for resource res-2")
+	}
+
+	summary := mgr.GetSummary()
+	if summary["host-1"].TotalUpdates != 1 {
+		t.Fatalf("expected summary for host-1")
+	}
+
+	mgr.AddRegistryConfig(RegistryConfig{Host: "example.com", Username: "user"})
+	if _, ok := mgr.registry.configs["example.com"]; !ok {
+		t.Fatalf("expected registry config to be stored")
+	}
+
+	mgr.DeleteUpdatesForHost("host-1")
+	if len(mgr.GetUpdatesForHost("host-1")) != 0 {
+		t.Fatalf("expected updates removed for host-1")
+	}
+}
+
+func TestManagerCleanupStale(t *testing.T) {
+	mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
+	now := time.Now()
+
+	mgr.store.UpsertUpdate(&UpdateInfo{
+		ID:          "stale",
+		ResourceID:  "res-1",
+		HostID:      "host-1",
+		LastChecked: now.Add(-2 * time.Hour),
+	})
+	mgr.store.UpsertUpdate(&UpdateInfo{
+		ID:          "fresh",
+		ResourceID:  "res-2",
+		HostID:      "host-1",
+		LastChecked: now,
+	})
+
+	removed := mgr.CleanupStale(time.Hour)
+	if removed != 1 {
+		t.Fatalf("expected 1 stale update removed, got %d", removed)
+	}
+	if mgr.GetTotalCount() != 1 {
+		t.Fatalf("expected one update remaining")
+	}
+}
+
+func TestManagerGetUpdatesReadyForAlert(t *testing.T) {
+	mgr := NewManager(DefaultManagerConfig(), zerolog.Nop())
+	mgr.alertDelayHours = 1
+	now := time.Now()
+
+	mgr.store.UpsertUpdate(&UpdateInfo{
+		ID:            "ready",
+		ResourceID:    "res-1",
+		HostID:        "host-1",
+		FirstDetected: now.Add(-2 * time.Hour),
+		LastChecked:   now,
+	})
+	mgr.store.UpsertUpdate(&UpdateInfo{
+		ID:            "recent",
+		ResourceID:    "res-2",
+		HostID:        "host-1",
+		FirstDetected: now.Add(-30 * time.Minute),
+		LastChecked:   now,
+	})
+	mgr.store.UpsertUpdate(&UpdateInfo{
+		ID:            "error",
+		ResourceID:    "res-3",
+		HostID:        "host-1",
+		FirstDetected: now.Add(-2 * time.Hour),
+		LastChecked:   now,
+		Error:         "rate limited",
+	})
+
+	ready := mgr.GetUpdatesReadyForAlert()
+	if len(ready) != 1 || ready[0].ID != "ready" {
+		t.Fatalf("expected only ready update, got %d", len(ready))
+	}
+}
+
+func TestUpdateFilters(t *testing.T) {
+	update := &UpdateInfo{
+		HostID:       "host-1",
+		ResourceType: "docker",
+		Type:         UpdateTypeDockerImage,
+		Severity:     SeveritySecurity,
+		Error:        "",
+	}
+
+	if !(&UpdateFilters{}).IsEmpty() {
+		t.Fatalf("expected empty filters to be empty")
+	}
+	if (&UpdateFilters{HostID: "host-1"}).IsEmpty() {
+		t.Fatalf("expected non-empty filters")
+	}
+
+	if (&UpdateFilters{HostID: "other"}).Matches(update) {
+		t.Fatalf("expected host mismatch to fail")
+	}
+	if (&UpdateFilters{ResourceType: "vm"}).Matches(update) {
+		t.Fatalf("expected resource type mismatch to fail")
+	}
+	if (&UpdateFilters{UpdateType: UpdateTypePackage}).Matches(update) {
+		t.Fatalf("expected update type mismatch to fail")
+	}
+	if (&UpdateFilters{Severity: SeverityBugfix}).Matches(update) {
+		t.Fatalf("expected severity mismatch to fail")
+	}
+
+	hasError := true
+	if (&UpdateFilters{HasError: &hasError}).Matches(update) {
+		t.Fatalf("expected error filter to fail")
+	}
+
+	update.Error = "boom"
+	hasError = false
+	if (&UpdateFilters{HasError: &hasError}).Matches(update) {
+		t.Fatalf("expected error=false filter to fail")
+	}
+
+	update.Error = ""
+	hasError = false
+	filters := UpdateFilters{
+		HostID:       "host-1",
+		ResourceType: "docker",
+		UpdateType:   UpdateTypeDockerImage,
+		Severity:     SeveritySecurity,
+		HasError:     &hasError,
+	}
+	if !filters.Matches(update) {
+		t.Fatalf("expected filters to match update")
+	}
+}
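Note: TestManagerGetUpdatesReadyForAlert pins down the alert-gating rule implied by the fixtures above: an update becomes alertable only once it has been known for at least the configured delay, and never while its last check errored. A hypothetical predicate capturing that rule (the type and function names below are illustrative, not the repo's UpdateInfo or GetUpdatesReadyForAlert):

    package main

    import (
    	"fmt"
    	"time"
    )

    // update mirrors just the fields the readiness rule needs.
    type update struct {
    	ID            string
    	FirstDetected time.Time
    	Error         string
    }

    // readyForAlert: old enough to pass the delay, and the last check was clean.
    func readyForAlert(u update, delay time.Duration, now time.Time) bool {
    	if u.Error != "" {
    		return false
    	}
    	return now.Sub(u.FirstDetected) >= delay
    }

    func main() {
    	now := time.Now()
    	updates := []update{
    		{ID: "ready", FirstDetected: now.Add(-2 * time.Hour)},
    		{ID: "recent", FirstDetected: now.Add(-30 * time.Minute)},
    		{ID: "error", FirstDetected: now.Add(-2 * time.Hour), Error: "rate limited"},
    	}
    	for _, u := range updates {
    		fmt.Println(u.ID, readyForAlert(u, time.Hour, now)) // only "ready" prints true
    	}
    }

The delay keeps transient registry flaps (like the rate-limited fixture) from paging anyone before the update has been stable for a full window.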
"sha256:abc") + r.cacheError("key-error", "boom") + + if r.CacheSize() != 2 { + t.Fatalf("expected cache size 2, got %d", r.CacheSize()) + } + if entry := r.getCached("key-digest"); entry == nil || entry.digest != "sha256:abc" { + t.Fatalf("expected cached digest") + } + if entry := r.getCached("missing"); entry != nil { + t.Fatalf("expected missing cache entry to be nil") + } + + r.cache.entries["expired"] = cacheEntry{ + digest: "old", + expiresAt: time.Now().Add(-time.Minute), + } + if entry := r.getCached("expired"); entry != nil { + t.Fatalf("expected expired cache entry to be nil") + } + + r.CleanupCache() + if r.CacheSize() != 2 { + t.Fatalf("expected expired cache entry to be removed") + } +} + +func TestRegistryCheckImageUpdate(t *testing.T) { + ctx := context.Background() + + t.Run("DigestPinned", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + info, err := r.CheckImageUpdate(ctx, "nginx@sha256:abc", "sha256:old") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.Error == "" { + t.Fatalf("expected digest-pinned error") + } + }) + + t.Run("CachedError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + key := "registry-1.docker.io/library/nginx:latest" + r.cache.entries[key] = cacheEntry{ + err: "cached error", + expiresAt: time.Now().Add(time.Hour), + } + + info, err := r.CheckImageUpdate(ctx, "nginx:latest", "sha256:old") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.Error != "cached error" { + t.Fatalf("expected cached error") + } + }) + + t.Run("CachedDigest", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + key := "registry-1.docker.io/library/nginx:latest" + r.cache.entries[key] = cacheEntry{ + digest: "sha256:new", + expiresAt: time.Now().Add(time.Hour), + } + + info, err := r.CheckImageUpdate(ctx, "nginx:latest", "sha256:old") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !info.UpdateAvailable || info.LatestDigest != "sha256:new" { + t.Fatalf("expected update available from cache") + } + }) + + t.Run("FetchErrorCaches", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return newResponse(http.StatusInternalServerError, nil, nil), nil + }), + } + + info, err := r.CheckImageUpdate(ctx, "example.com/repo:tag", "sha256:old") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.Error == "" { + t.Fatalf("expected error from fetch") + } + if cached := r.getCached("example.com/repo:tag"); cached == nil || cached.err == "" { + t.Fatalf("expected error to be cached") + } + }) + + t.Run("FetchSuccessCaches", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + headers := http.Header{"Docker-Content-Digest": []string{"sha256:new"}} + return newResponse(http.StatusOK, headers, nil), nil + }), + } + + info, err := r.CheckImageUpdate(ctx, "example.com/repo:tag", "sha256:old") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !info.UpdateAvailable { + t.Fatalf("expected update available") + } + if cached := r.getCached("example.com/repo:tag"); cached == nil || cached.digest == "" { + t.Fatalf("expected digest to be cached") + } + }) +} + +func TestFetchDigest(t *testing.T) { + ctx := context.Background() + + t.Run("AuthError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = 
&http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("auth fail") + }), + } + + if _, err := r.fetchDigest(ctx, "registry-1.docker.io", "library/nginx", "latest"); err == nil { + t.Fatalf("expected auth error") + } + }) + + t.Run("InsecureScheme", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.AddRegistryConfig(RegistryConfig{Host: "insecure.local", Insecure: true}) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.Scheme != "http" { + t.Fatalf("expected http scheme, got %q", req.URL.Scheme) + } + headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}} + return newResponse(http.StatusOK, headers, nil), nil + }), + } + + if _, err := r.fetchDigest(ctx, "insecure.local", "repo", "latest"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("RequestError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + if _, err := r.fetchDigest(ctx, "bad host", "repo", "latest"); err == nil { + t.Fatalf("expected request creation error") + } + }) + + t.Run("TokenHeader", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.AddRegistryConfig(RegistryConfig{Host: "ghcr.io", Password: "pat"}) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if got := req.Header.Get("Authorization"); got != "Bearer pat" { + t.Fatalf("expected bearer token, got %q", got) + } + headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}} + return newResponse(http.StatusOK, headers, nil), nil + }), + } + + if _, err := r.fetchDigest(ctx, "ghcr.io", "owner/repo", "latest"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("DoError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("network") + }), + } + + if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil { + t.Fatalf("expected request error") + } + }) + + t.Run("StatusUnauthorized", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return newResponse(http.StatusUnauthorized, nil, nil), nil + }), + } + + if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil { + t.Fatalf("expected unauthorized error") + } + }) + + t.Run("StatusNotFound", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return newResponse(http.StatusNotFound, nil, nil), nil + }), + } + + if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil { + t.Fatalf("expected not found error") + } + }) + + t.Run("StatusRateLimited", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return newResponse(http.StatusTooManyRequests, nil, nil), nil + }), + } + + if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil { + t.Fatalf("expected rate limit error") + } + }) + + t.Run("StatusOtherError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: 
roundTripperFunc(func(*http.Request) (*http.Response, error) { + return newResponse(http.StatusInternalServerError, nil, nil), nil + }), + } + + if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil { + t.Fatalf("expected registry error") + } + }) + + t.Run("DigestHeader", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + headers := http.Header{"Docker-Content-Digest": []string{"sha256:abc"}} + return newResponse(http.StatusOK, headers, nil), nil + }), + } + + digest, err := r.fetchDigest(ctx, "example.com", "repo", "latest") + if err != nil || digest != "sha256:abc" { + t.Fatalf("expected digest, got %q err %v", digest, err) + } + }) + + t.Run("EtagHeader", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + headers := http.Header{"Etag": []string{`"sha256:etag"`}} + return newResponse(http.StatusOK, headers, nil), nil + }), + } + + digest, err := r.fetchDigest(ctx, "example.com", "repo", "latest") + if err != nil || digest != "sha256:etag" { + t.Fatalf("expected etag digest, got %q err %v", digest, err) + } + }) + + t.Run("NoDigest", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return newResponse(http.StatusOK, nil, nil), nil + }), + } + + if _, err := r.fetchDigest(ctx, "example.com", "repo", "latest"); err == nil { + t.Fatalf("expected missing digest error") + } + }) +} + +func TestGetAuthToken(t *testing.T) { + ctx := context.Background() + + t.Run("DockerHubSuccess", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.AddRegistryConfig(RegistryConfig{Host: "registry-1.docker.io", Username: "user", Password: "pass"}) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if !strings.HasPrefix(req.Header.Get("Authorization"), "Basic ") { + t.Fatalf("expected basic auth header") + } + body := `{"token":"tok"}` + return newResponse(http.StatusOK, nil, strings.NewReader(body)), nil + }), + } + + token, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx") + if err != nil || token != "tok" { + t.Fatalf("expected token, got %q err %v", token, err) + } + }) + + t.Run("DockerHubRequestError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "bad\nrepo"); err == nil { + t.Fatalf("expected request error") + } + }) + + t.Run("DockerHubDoError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("network") + }), + } + + if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil { + t.Fatalf("expected network error") + } + }) + + t.Run("DockerHubStatusError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return newResponse(http.StatusInternalServerError, nil, nil), nil + }), + } + + if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil { + t.Fatalf("expected status error") + } + }) + + t.Run("DockerHubReadError", func(t *testing.T) { + r := 
NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + body := &errorReader{err: errors.New("read")} + return newResponse(http.StatusOK, nil, body), nil + }), + } + + if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil { + t.Fatalf("expected read error") + } + }) + + t.Run("DockerHubJSONError", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.httpClient = &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + body := "{" + return newResponse(http.StatusOK, nil, strings.NewReader(body)), nil + }), + } + + if _, err := r.getAuthToken(ctx, "registry-1.docker.io", "library/nginx"); err == nil { + t.Fatalf("expected json error") + } + }) + + t.Run("GHCRToken", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.AddRegistryConfig(RegistryConfig{Host: "ghcr.io", Password: "pat"}) + token, err := r.getAuthToken(ctx, "ghcr.io", "owner/repo") + if err != nil || token != "pat" { + t.Fatalf("expected ghcr token") + } + }) + + t.Run("GHCRAnonymous", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + token, err := r.getAuthToken(ctx, "ghcr.io", "owner/repo") + if err != nil || token != "" { + t.Fatalf("expected empty ghcr token") + } + }) + + t.Run("OtherRegistryBasicAuth", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + r.AddRegistryConfig(RegistryConfig{Host: "example.com", Username: "user", Password: "pass"}) + token, err := r.getAuthToken(ctx, "example.com", "repo") + if err != nil || token != "" { + t.Fatalf("expected empty token for basic auth registry") + } + }) + + t.Run("OtherRegistryNoAuth", func(t *testing.T) { + r := NewRegistryChecker(zerolog.Nop()) + token, err := r.getAuthToken(ctx, "example.com", "repo") + if err != nil || token != "" { + t.Fatalf("expected empty token for registry without auth") + } + }) +} diff --git a/internal/updatedetection/store_coverage_test.go b/internal/updatedetection/store_coverage_test.go new file mode 100644 index 000000000..7509ba1f8 --- /dev/null +++ b/internal/updatedetection/store_coverage_test.go @@ -0,0 +1,43 @@ +package updatedetection + +import "testing" + +func TestStore_DeleteUpdateNotFound(t *testing.T) { + store := NewStore() + store.DeleteUpdate("missing") + if store.Count() != 0 { + t.Fatalf("expected empty store") + } +} + +func TestStore_DeleteUpdatesForResourceMissing(t *testing.T) { + store := NewStore() + store.DeleteUpdatesForResource("missing") + if store.Count() != 0 { + t.Fatalf("expected empty store") + } +} + +func TestStore_DeleteUpdatesForResourceNilUpdate(t *testing.T) { + store := NewStore() + store.byResource["res-1"] = "update-1" + store.updates["update-1"] = nil + + store.DeleteUpdatesForResource("res-1") + if _, ok := store.byResource["res-1"]; ok { + t.Fatalf("expected byResource entry to be removed") + } +} + +func TestStore_CountForHost(t *testing.T) { + store := NewStore() + store.UpsertUpdate(&UpdateInfo{ID: "update-1", ResourceID: "res-1", HostID: "host-1"}) + store.UpsertUpdate(&UpdateInfo{ID: "update-2", ResourceID: "res-2", HostID: "host-1"}) + + if store.CountForHost("host-1") != 2 { + t.Fatalf("expected count 2 for host-1") + } + if store.CountForHost("missing") != 0 { + t.Fatalf("expected count 0 for missing host") + } +}
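Note: TestFetchDigest and TestGetAuthToken together outline the conventional registry digest probe these tests pin down: a HEAD request against the v2 manifests endpoint, a bearer token when one applies, and the Docker-Content-Digest response header (with Etag as a fallback). A standalone sketch of that flow, assuming only standard OCI/Docker distribution conventions; headDigest is a hypothetical helper, not this repo's fetchDigest:

    package main

    import (
    	"context"
    	"fmt"
    	"net/http"
    )

    // headDigest reads an image digest without pulling the image:
    // HEAD /v2/<repo>/manifests/<tag>, then inspect the digest header.
    func headDigest(ctx context.Context, host, repo, tag, token string) (string, error) {
    	url := fmt.Sprintf("https://%s/v2/%s/manifests/%s", host, repo, tag)
    	req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
    	if err != nil {
    		return "", err
    	}
    	// Accept both classic Docker and OCI manifest media types.
    	req.Header.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json, application/vnd.oci.image.index.v1+json")
    	if token != "" {
    		req.Header.Set("Authorization", "Bearer "+token)
    	}
    	resp, err := http.DefaultClient.Do(req)
    	if err != nil {
    		return "", err
    	}
    	defer resp.Body.Close()
    	if resp.StatusCode != http.StatusOK {
    		return "", fmt.Errorf("registry returned %s", resp.Status)
    	}
    	if d := resp.Header.Get("Docker-Content-Digest"); d != "" {
    		return d, nil
    	}
    	return "", fmt.Errorf("no digest header in response")
    }

    func main() {
    	digest, err := headDigest(context.Background(), "ghcr.io", "owner/repo", "latest", "")
    	fmt.Println(digest, err)
    }

Using HEAD keeps the check cheap (no manifest body transferred), which is why the tests also exercise the caching and rate-limit paths around it.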