package middleware import ( "context" "encoding/json" "fmt" "log" "math/rand" "net/http" "runtime/debug" "strings" "time" "github.com/golang-jwt/jwt" ) // Response represents a standard API response type Response struct { Code int `json:"code"` Message string `json:"message"` Data interface{} `json:"data,omitempty"` } // ErrorResponse sends a JSON error response func ErrorResponse(w http.ResponseWriter, code int, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) json.NewEncoder(w).Encode(Response{ Code: code, Message: message, }) } // SuccessResponse sends a JSON success response func SuccessResponse(w http.ResponseWriter, data interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(Response{ Code: http.StatusOK, Message: "success", Data: data, }) } // Logger is a middleware that logs request details with structured format func Logger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() requestID := r.Header.Get("X-Request-ID") if requestID == "" { requestID = generateRequestID() r.Header.Set("X-Request-ID", requestID) } // Create a custom response writer to capture status code rw := &responseWriter{ ResponseWriter: w, status: http.StatusOK, } // Process request next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID))) // Log structured request details log.Printf( "method=%s path=%s status=%d duration=%s ip=%s request_id=%s", r.Method, r.URL.Path, rw.status, time.Since(start).String(), r.RemoteAddr, requestID, ) }) } // generateRequestID creates a unique request identifier func generateRequestID() string { return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8)) } // randomString generates a random string of given length func randomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" b := make([]byte, length) for i := range b { b[i] = charset[rand.Intn(len(charset))] } return string(b) } // Recover is a middleware that recovers from panics with detailed logging func Recover(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { // Get request ID from context requestID := "" if ctx := r.Context(); ctx != nil { if id, ok := ctx.Value("request_id").(string); ok { requestID = id } } // Log error with stack trace and request context log.Printf( "panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s", err, requestID, r.Method, r.URL.Path, r.RemoteAddr, debug.Stack(), ) // Determine error type var message string var status int switch e := err.(type) { case error: message = e.Error() if isClientError(e) { status = http.StatusBadRequest } else { status = http.StatusInternalServerError } case string: message = e status = http.StatusInternalServerError default: message = "Internal Server Error" status = http.StatusInternalServerError } ErrorResponse(w, status, message) } }() next.ServeHTTP(w, r) }) } // isClientError checks if error should be treated as client error func isClientError(err error) bool { // Add more client error types as needed return strings.Contains(err.Error(), "validation") || strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "missing") } // CORS is a middleware that handles CORS func CORS(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) } // Timeout is a middleware that safely handles request timeouts func Timeout(duration time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), duration) defer cancel() // Use buffered channels to prevent goroutine leaks done := make(chan struct{}, 1) panicChan := make(chan interface{}, 1) // Track request processing in goroutine go func() { defer func() { if p := recover(); p != nil { panicChan <- p } }() next.ServeHTTP(w, r.WithContext(ctx)) done <- struct{}{} }() // Wait for completion, timeout or panic select { case <-done: return case p := <-panicChan: panic(p) // Re-throw panic to be caught by Recover middleware case <-ctx.Done(): // Get request context for logging requestID := "" if ctx := r.Context(); ctx != nil { if id, ok := ctx.Value("request_id").(string); ok { requestID = id } } // Log timeout details log.Printf( "request_timeout: request_id=%s method=%s path=%s timeout=%s", requestID, r.Method, r.URL.Path, duration.String(), ) // Send timeout response ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf( "Request timed out after %s", duration.String(), )) } }) } } // Auth is a middleware that validates JWT tokens with enhanced security func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get request ID for logging requestID := "" if ctx := r.Context(); ctx != nil { if id, ok := ctx.Value("request_id").(string); ok { requestID = id } } // Get token from Authorization header authHeader := r.Header.Get("Authorization") if authHeader == "" { log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID) ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required") return } // Validate header format parts := strings.Split(authHeader, " ") if len(parts) != 2 || parts[0] != "Bearer" { log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID) ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer '") return } tokenString := parts[1] // Parse and validate token token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { // Validate signing method if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(jwtSecret), nil }) if err != nil { log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err) ErrorResponse(w, http.StatusUnauthorized, "Invalid token") return } if !token.Valid { log.Printf("auth_failed: request_id=%s error=invalid_token", requestID) ErrorResponse(w, http.StatusUnauthorized, "Invalid token") return } // Validate claims claims, ok := token.Claims.(jwt.MapClaims) if !ok { log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID) ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims") return } // Check required claims userID, ok := claims["user_id"].(string) if !ok || userID == "" { log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID) ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token") return } // Check token expiration if exp, ok := claims["exp"].(float64); ok { if time.Now().Unix() > int64(exp) { log.Printf("auth_failed: request_id=%s error=token_expired", requestID) ErrorResponse(w, http.StatusUnauthorized, "Token has expired") return } } // Check required roles if specified if len(requiredRoles) > 0 { roles, ok := claims["roles"].([]interface{}) if !ok { log.Printf("auth_failed: request_id=%s error=missing_roles", requestID) ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles") return } hasRequiredRole := false for _, requiredRole := range requiredRoles { for _, role := range roles { if r, ok := role.(string); ok && r == requiredRole { hasRequiredRole = true break } } } if !hasRequiredRole { log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID) ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions") return } } // Add claims to context ctx := r.Context() ctx = context.WithValue(ctx, "user_id", userID) ctx = context.WithValue(ctx, "claims", claims) log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // responseWriter is a custom response writer that captures the status code type responseWriter struct { http.ResponseWriter status int } func (rw *responseWriter) WriteHeader(code int) { rw.status = code rw.ResponseWriter.WriteHeader(code) } // GetUserID gets the user ID from the request context func GetUserID(r *http.Request) (string, error) { userID, ok := r.Context().Value("user_id").(string) if !ok { return "", fmt.Errorf("user ID not found in context") } return userID, nil }