Files
Pulse/internal/api/middleware.go
rcourtman 5c0d760d54 feat: improve request ID handling in middleware
Enhance request ID middleware to support distributed tracing:

- Honor incoming X-Request-ID headers from upstream proxies/load balancers
- Use logging.WithRequestID() for consistent ID generation across codebase
- Return X-Request-ID in response headers for client correlation
- Include request_id in panic recovery logs for debugging

This enables better request tracing across multiple Pulse instances
and integrates with standard distributed tracing practices.
2025-10-21 11:37:57 +00:00

195 lines
5.4 KiB
Go

package api
import (
	"bufio"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"runtime/debug"
	"strings"
	"time"

	"github.com/rcourtman/pulse-go-rewrite/internal/logging"
	"github.com/rs/zerolog/log"
)
// APIError is the structured JSON body sent to clients when a request fails.
// Because *APIError implements error, handlers may also return it directly.
//
// Field order is significant: encoding/json emits keys in declaration order.
type APIError struct {
	ErrorMessage string            `json:"error"`                // client-facing message
	Code         string            `json:"code,omitempty"`       // machine-readable error code
	StatusCode   int               `json:"status_code"`          // HTTP status sent alongside the body
	Timestamp    int64             `json:"timestamp"`            // Unix time (seconds) the error was built
	RequestID    string            `json:"request_id,omitempty"` // request correlation ID, if populated
	Details      map[string]string `json:"details,omitempty"`    // optional extra context, e.g. per-field messages
}
// Error returns the client-facing message, which lets *APIError be used
// anywhere an error is expected.
func (e *APIError) Error() string {
	msg := e.ErrorMessage
	return msg
}
// ErrorHandler is middleware that normalizes an empty request path, attaches a
// request ID, recovers from handler panics, and logs 4xx/5xx responses.
//
// A client-supplied X-Request-ID is reused only when validIncomingRequestID
// accepts it. Otherwise an empty value is handed to logging.WithRequestID,
// which is the same input used for requests that carry no header at all
// (presumably WithRequestID generates a fresh ID in that case — confirm in the
// logging package). The resulting ID is echoed in the X-Request-ID response
// header and included in panic and failure logs.
//
// WebSocket upgrade requests bypass all of this and go straight to next.
func ErrorHandler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Fix for issue #334: Normalize empty path to "/" before ServeMux processes it
		// This prevents the automatic redirect from "" to "./"
		if r.URL.Path == "" {
			r.URL.Path = "/"
		}

		// Skip error handling for WebSocket endpoints. Upgrade tokens are
		// case-insensitive, so clients sending "WebSocket" must match too.
		if strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
			next.ServeHTTP(w, r)
			return
		}

		// Add request ID to context, honoring a well-formed incoming header value.
		incomingID := strings.TrimSpace(r.Header.Get("X-Request-ID"))
		if !validIncomingRequestID(incomingID) {
			incomingID = ""
		}
		ctxWithID, requestID := logging.WithRequestID(r.Context(), incomingID)
		r = r.WithContext(ctxWithID)

		// Wrap the writer to capture the status code the handler sends.
		rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		rw.Header().Set("X-Request-ID", requestID)

		// Recover from panics.
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			// http.ErrAbortHandler is net/http's sentinel for deliberately
			// aborting a response; re-panic so the server handles it quietly.
			if rec == http.ErrAbortHandler {
				panic(rec)
			}
			log.Error().
				Interface("error", rec).
				Str("path", r.URL.Path).
				Str("method", r.Method).
				Str("request_id", requestID).
				Bytes("stack", debug.Stack()).
				Msg("Panic recovered in API handler")
			// Only send the JSON error if the handler had not started the
			// response. Otherwise the second WriteHeader would be ignored and
			// the JSON would be appended to a partially written body.
			if !rw.written {
				writeErrorResponse(rw, http.StatusInternalServerError, "internal_error",
					"An unexpected error occurred", nil)
			}
		}()

		// Call the next handler.
		next.ServeHTTP(rw, r)

		// Log errors (4xx and 5xx).
		if rw.statusCode >= 400 {
			log.Warn().
				Str("path", r.URL.Path).
				Str("method", r.Method).
				Int("status", rw.statusCode).
				Str("request_id", requestID).
				Msg("Request failed")
		}
	})
}

// validIncomingRequestID reports whether a client-supplied X-Request-ID is
// safe to reuse. The value must be non-empty, at most 128 bytes long, and made
// only of printable, non-space ASCII characters.
//
// The value is untrusted and gets echoed into the response header and every
// log line for the request, so oversized or odd values are discarded.
func validIncomingRequestID(id string) bool {
	const maxLen = 128
	if id == "" || len(id) > maxLen {
		return false
	}
	for i := 0; i < len(id); i++ {
		if c := id[i]; c < 0x21 || c > 0x7e {
			return false
		}
	}
	return true
}
// TimeoutHandler returns middleware that limits how long each request may run,
// using http.TimeoutHandler. When the limit is exceeded, the client receives
// 503 Service Unavailable with the body "Request timeout".
//
// WebSocket upgrades and Server-Sent Events requests bypass the timeout. Those
// responses are long-lived, and http.TimeoutHandler buffers the response
// without supporting http.Hijacker or http.Flusher.
func TimeoutHandler(timeout time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		// Build the timeout wrapper once per wrapped handler rather than on
		// every request; it holds no per-request state.
		timed := http.TimeoutHandler(next, timeout, "Request timeout")
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Skip timeout for WebSocket and SSE endpoints. Upgrade tokens
			// compare case-insensitively, and Accept may list several media
			// types (e.g. "text/event-stream, */*").
			isWebSocket := strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket")
			isEventStream := strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/event-stream")
			if isWebSocket || isEventStream {
				next.ServeHTTP(w, r)
				return
			}
			timed.ServeHTTP(w, r)
		})
	}
}
// JSONHandler adapts an error-returning handler into an http.HandlerFunc whose
// responses always carry Content-Type: application/json.
//
// If the handler returns an error that is, or wraps, an *APIError, that
// error's status code, code, message and details are sent to the client. Any
// other error is logged and reported as a generic 500 "internal_error", so
// internal details are not leaked.
func JSONHandler(handler func(w http.ResponseWriter, r *http.Request) error) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")

		err := handler(w, r)
		if err == nil {
			return
		}

		// errors.As also finds an *APIError wrapped with fmt.Errorf("...: %w", e).
		// A plain type assertion would miss it and report a generic 500 instead.
		var apiErr *APIError
		if errors.As(err, &apiErr) {
			writeErrorResponse(w, apiErr.StatusCode, apiErr.Code, apiErr.ErrorMessage, apiErr.Details)
			return
		}

		// Generic error
		log.Error().Err(err).
			Str("path", r.URL.Path).
			Str("method", r.Method).
			Msg("Handler error")
		writeErrorResponse(w, http.StatusInternalServerError, "internal_error",
			"An error occurred processing the request", nil)
	}
}
// writeErrorResponse sends statusCode with a JSON-encoded APIError body built
// from code, message and the optional details map, stamped with the current
// Unix time.
//
// An encoding failure is only logged: by then the status line has already
// been committed, so there is nothing better to send.
func writeErrorResponse(w http.ResponseWriter, statusCode int, code, message string, details map[string]string) {
	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)

	var payload APIError
	payload.ErrorMessage = message
	payload.Code = code
	payload.StatusCode = statusCode
	payload.Timestamp = time.Now().Unix()
	payload.Details = details

	encoder := json.NewEncoder(w)
	encodeErr := encoder.Encode(payload)
	if encodeErr == nil {
		return
	}
	log.Error().Err(encodeErr).Msg("Failed to encode error response")
}
// responseWriter decorates an http.ResponseWriter so middleware can observe
// the status code a handler sent.
type responseWriter struct {
	// Underlying writer; methods not overridden here delegate to it.
	http.ResponseWriter

	// First status passed to WriteHeader, or whatever default the creator set.
	statusCode int
	// Set once WriteHeader has forwarded a status to the underlying writer.
	written bool
}
// WriteHeader forwards only the first status code to the wrapped writer and
// records it. Later calls are ignored, so statusCode always matches what the
// client actually received.
func (rw *responseWriter) WriteHeader(code int) {
	if rw.written {
		return
	}
	rw.statusCode = code
	rw.ResponseWriter.WriteHeader(code)
	rw.written = true
}
// Write sends b to the wrapped writer and returns its byte count and error.
// If no status has been sent yet, it first commits 200 OK through
// rw.WriteHeader so that statusCode records it.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.written {
		// Mirror net/http: a body written without an explicit status implies 200 OK.
		rw.WriteHeader(http.StatusOK)
	}
	n, err := rw.ResponseWriter.Write(b)
	return n, err
}
// Hijack implements http.Hijacker by delegating to the wrapped writer, so
// connection-taking handlers keep working behind the middleware. It returns
// an error when the wrapped writer cannot be hijacked.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := rw.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("ResponseWriter does not implement http.Hijacker")
	}
	return hijacker.Hijack()
}

// Flush implements http.Flusher by delegating to the wrapped writer when it
// supports flushing; otherwise it is a no-op.
//
// Without this method the wrapper hides the underlying Flusher. Streaming
// handlers behind ErrorHandler, such as Server-Sent Events, would then fail
// their w.(http.Flusher) check.
func (rw *responseWriter) Flush() {
	flusher, ok := rw.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	// Flushing commits the header. Record the implicit 200 OK the same way
	// Write does, so statusCode and written stay accurate.
	if !rw.written {
		rw.WriteHeader(http.StatusOK)
	}
	flusher.Flush()
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// capabilities this wrapper does not re-expose.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// NewAPIError builds an *APIError from an HTTP status, a machine-readable code
// and a client-facing message, stamped with the current Unix time. It returns
// the value as an error so handlers can return it directly.
func NewAPIError(statusCode int, code, message string) error {
	apiErr := new(APIError)
	apiErr.ErrorMessage = message
	apiErr.Code = code
	apiErr.StatusCode = statusCode
	apiErr.Timestamp = time.Now().Unix()
	return apiErr
}
// ValidationError builds a 400 Bad Request *APIError with code
// "validation_error" and message "Validation failed". The given fields map is
// attached as Details as-is, without copying.
func ValidationError(fields map[string]string) error {
	apiErr := &APIError{
		StatusCode:   http.StatusBadRequest,
		Code:         "validation_error",
		ErrorMessage: "Validation failed",
		Details:      fields,
	}
	apiErr.Timestamp = time.Now().Unix()
	return apiErr
}