// otpm/middleware/middleware.go
// "xHuPo" bcd986e3f7 beta
// 2025-05-23 18:57:11 +08:00
//
// 353 lines
// 9.6 KiB
// Go

package middleware
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"runtime/debug"
	"strings"
	"sync"
	"time"

	"github.com/golang-jwt/jwt"
)
// Response is the JSON envelope used for every API reply written by this
// package.
//
// Code mirrors the HTTP status code, Message is a short human-readable
// summary, and Data carries the optional payload; Data is left out of the
// encoded JSON entirely when it is nil.
type Response struct {
	Code    int         `json:"code"`           // HTTP status code
	Message string      `json:"message"`        // short status text
	Data    interface{} `json:"data,omitempty"` // optional payload
}
// ErrorResponse writes a JSON error envelope with the given HTTP status.
//
// The status is sent as the response status and repeated in the envelope's
// "code" field; message becomes the "message" field. The status line is
// committed before the body is encoded, so an encoding/write failure cannot
// be reported to the client any more — it is logged instead of being
// silently discarded.
func ErrorResponse(w http.ResponseWriter, code int, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	if err := json.NewEncoder(w).Encode(Response{
		Code:    code,
		Message: message,
	}); err != nil {
		// Typically the client has gone away; nothing more can be sent.
		log.Printf("response_write_failed: status=%d error=%v", code, err)
	}
}
// SuccessResponse writes data wrapped in a 200 OK JSON envelope
// ({"code":200,"message":"success","data":...}).
//
// The envelope is marshalled before any header is written. Previously an
// unencodable data value (a channel, a func, a NaN float, ...) produced a
// 200 status with an empty body, because the status had already been sent
// when encoding failed. Now such a value is logged and answered with a 500
// error envelope via ErrorResponse. Write failures after the status has been
// sent are logged.
func SuccessResponse(w http.ResponseWriter, data interface{}) {
	body, err := json.Marshal(Response{
		Code:    http.StatusOK,
		Message: "success",
		Data:    data,
	})
	if err != nil {
		log.Printf("response_encode_failed: error=%v", err)
		ErrorResponse(w, http.StatusInternalServerError, "Internal Server Error")
		return
	}
	// json.Encoder terminates each value with '\n'; keep that framing so the
	// bytes on the wire are unchanged.
	body = append(body, '\n')

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(body); err != nil {
		log.Printf("response_write_failed: status=%d error=%v", http.StatusOK, err)
	}
}
// Logger is a middleware that emits one key=value log line per request.
//
// It takes the request ID from the incoming X-Request-ID header. When that
// header is empty, it generates an ID and sets it on the request headers.
// The ID is handed to downstream handlers through the context key
// "request_id". Once the wrapped handler returns, Logger logs the method,
// path, captured status code, latency, remote address and request ID. No
// line is emitted if the handler panics through this middleware.
func Logger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()

		id := r.Header.Get("X-Request-ID")
		if len(id) == 0 {
			id = generateRequestID()
			r.Header.Set("X-Request-ID", id)
		}

		// Wrap the writer so the status chosen by the handler can be logged;
		// it starts at 200 because that is what an untouched response sends.
		recorder := &responseWriter{ResponseWriter: w, status: http.StatusOK}
		ctx := context.WithValue(r.Context(), "request_id", id)
		next.ServeHTTP(recorder, r.WithContext(ctx))

		elapsed := time.Since(began)
		log.Printf(
			"method=%s path=%s status=%d duration=%s ip=%s request_id=%s",
			r.Method, r.URL.Path, recorder.status, elapsed.String(), r.RemoteAddr, id,
		)
	})
}
// generateRequestID builds a request identifier of the form
// "<unix-nanoseconds>-<8 random alphanumeric characters>".
func generateRequestID() string {
	nanos := time.Now().UnixNano()
	suffix := randomString(8)
	return fmt.Sprintf("%d-%s", nanos, suffix)
}
// randomString returns length characters drawn uniformly from [a-zA-Z0-9]
// using math/rand's shared source. It is not suitable for secrets.
func randomString(length int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	out := make([]byte, length)
	for i := 0; i < len(out); i++ {
		out[i] = alphabet[rand.Intn(len(alphabet))]
	}
	return string(out)
}
// Recover is a middleware that turns panics in downstream handlers into JSON
// error responses.
//
// Every recovered panic is logged with the request ID (read from the context
// key "request_id"), method, path, remote address and the stack trace. The
// client then receives one of two responses:
//   - 400 with the error text, when the panic value is an error that
//     isClientError classifies as a client mistake;
//   - 500 with the generic message "Internal Server Error" otherwise. Raw
//     panic values (runtime errors, wrapped DB errors, strings) can carry
//     internal details and are therefore only logged, never echoed back.
//
// A panic with http.ErrAbortHandler is re-raised untouched: net/http uses it
// to abort a response deliberately and expects it to propagate.
//
// NOTE(review): if the handler already wrote headers before panicking, the
// error response below cannot change the status that was sent.
func Recover(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			p := recover()
			if p == nil {
				return
			}
			if p == http.ErrAbortHandler {
				panic(p)
			}

			requestID, _ := r.Context().Value("request_id").(string)

			// Log error with stack trace and request context
			log.Printf(
				"panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s",
				p,
				requestID,
				r.Method,
				r.URL.Path,
				r.RemoteAddr,
				debug.Stack(),
			)

			status := http.StatusInternalServerError
			message := "Internal Server Error"
			if e, ok := p.(error); ok && isClientError(e) {
				status = http.StatusBadRequest
				message = e.Error()
			}
			ErrorResponse(w, status, message)
		}()
		next.ServeHTTP(w, r)
	})
}
// isClientError reports whether a recovered error should be answered with
// 400 Bad Request rather than 500.
//
// Classification is a case-sensitive substring match on the error text for
// "validation", "invalid" or "missing". Go runtime errors are always treated
// as server faults; this covers any error in the Unwrap chain that has a
// RuntimeError method, such as nil dereference or index out of range.
// Without that guard, the runtime message "invalid memory address or nil
// pointer dereference" would match "invalid" and be reported as a 400.
// A nil error is not a client error.
func isClientError(err error) bool {
	if err == nil {
		return false
	}
	// Walk the single-error Unwrap chain looking for a runtime.Error
	// (identified by its RuntimeError marker method).
	for e := err; e != nil; {
		if _, ok := e.(interface{ RuntimeError() }); ok {
			return false
		}
		u, ok := e.(interface{ Unwrap() error })
		if !ok {
			break
		}
		e = u.Unwrap()
	}
	msg := err.Error()
	// Add more client error markers as needed.
	return strings.Contains(msg, "validation") ||
		strings.Contains(msg, "invalid") ||
		strings.Contains(msg, "missing")
}
// CORS is a middleware that attaches permissive cross-origin headers to every
// response. It allows any origin, the GET/POST/PUT/DELETE/OPTIONS methods and
// the Content-Type and Authorization request headers.
//
// Any OPTIONS request is treated as a preflight and answered directly with
// 200 OK; the wrapped handler is not called for it.
func CORS(next http.Handler) http.Handler {
	corsHeaders := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		for _, kv := range corsHeaders {
			h.Set(kv[0], kv[1])
		}
		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
	})
}
// Timeout returns a middleware that limits the wrapped handler to duration.
//
// The handler runs in its own goroutine with a context that is cancelled at
// the deadline. It writes into an in-memory buffer rather than into the real
// ResponseWriter.
//
// There are three outcomes:
//   - The handler finishes first: the buffered headers, status and body are
//     copied to the client.
//   - The deadline, or a cancellation of the parent context, comes first:
//     further writes by the handler fail with http.ErrHandlerTimeout, and
//     the client receives a 504 JSON error.
//   - The handler panics: the panic is re-raised on the calling goroutine,
//     so an outer Recover middleware can handle it.
//
// Buffering ensures the handler goroutine and this goroutine never write to
// the same ResponseWriter concurrently. It also prevents writes after
// ServeHTTP has returned. The original version had both races once a
// timeout fired.
//
// The trade-off is that responses are held in memory until the handler
// returns. The writer given to the handler also does not implement
// http.Flusher or http.Hijacker, so streaming handlers should not sit behind
// this middleware.
func Timeout(duration time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, cancel := context.WithTimeout(r.Context(), duration)
			defer cancel()

			tw := &timeoutWriter{header: make(http.Header)}

			// Buffered so the handler goroutine can always finish its send,
			// even after this function has returned on timeout.
			done := make(chan struct{}, 1)
			panicChan := make(chan interface{}, 1)

			go func() {
				defer func() {
					if p := recover(); p != nil {
						panicChan <- p
					}
				}()
				next.ServeHTTP(tw, r.WithContext(ctx))
				done <- struct{}{}
			}()

			select {
			case <-done:
				tw.flushTo(w)
			case p := <-panicChan:
				panic(p) // Re-throw panic to be caught by Recover middleware
			case <-ctx.Done():
				// Reject any further output from the handler before replying,
				// so nothing it writes later can reach the client.
				tw.markTimedOut()

				requestID, _ := r.Context().Value("request_id").(string)
				log.Printf(
					"request_timeout: request_id=%s method=%s path=%s timeout=%s",
					requestID,
					r.Method,
					r.URL.Path,
					duration.String(),
				)
				ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf(
					"Request timed out after %s", duration.String(),
				))
			}
		})
	}
}

// timeoutWriter is the http.ResponseWriter handed to handlers behind Timeout.
// It buffers the status and body under mu until either flushTo copies them to
// the real writer or markTimedOut makes every later Write fail.
type timeoutWriter struct {
	mu          sync.Mutex
	header      http.Header
	body        []byte
	status      int
	wroteHeader bool
	timedOut    bool
}

var _ http.ResponseWriter = (*timeoutWriter)(nil)

// Header returns the buffered header map. It is only read by Timeout after
// the handler has finished, so it needs no locking.
func (tw *timeoutWriter) Header() http.Header { return tw.header }

// WriteHeader records the first status code; later calls are ignored.
func (tw *timeoutWriter) WriteHeader(code int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.timedOut || tw.wroteHeader {
		return
	}
	tw.status = code
	tw.wroteHeader = true
}

// Write appends p to the buffered body (implying status 200 if none was set),
// or fails with http.ErrHandlerTimeout once the deadline has passed.
func (tw *timeoutWriter) Write(p []byte) (int, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.timedOut {
		return 0, http.ErrHandlerTimeout
	}
	if !tw.wroteHeader {
		tw.status = http.StatusOK
		tw.wroteHeader = true
	}
	tw.body = append(tw.body, p...)
	return len(p), nil
}

// markTimedOut makes all subsequent writes by the handler fail.
func (tw *timeoutWriter) markTimedOut() {
	tw.mu.Lock()
	tw.timedOut = true
	tw.mu.Unlock()
}

// flushTo copies the buffered headers, status and body to w.
func (tw *timeoutWriter) flushTo(w http.ResponseWriter) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	dst := w.Header()
	for k, vv := range tw.header {
		dst[k] = vv
	}
	status := tw.status
	if !tw.wroteHeader {
		status = http.StatusOK
	}
	w.WriteHeader(status)
	if len(tw.body) > 0 {
		// Best effort: a failure here means the client is gone.
		_, _ = w.Write(tw.body)
	}
}
// Auth returns a middleware that authenticates requests with an HMAC-signed
// JWT taken from an "Authorization: Bearer <token>" header.
//
// The request is rejected with 401 when any of these holds:
//   - the header is missing or malformed;
//   - the token fails to parse or verify against jwtSecret;
//   - the token is not HMAC-signed;
//   - the "user_id" claim is not a non-empty string;
//   - the "exp" claim is present as a number and lies in the past.
//
// If requiredRoles is non-empty, the "roles" claim must be an array that
// contains at least one of them; otherwise the request is rejected with 403.
//
// On success, the user ID and the full jwt.MapClaims are stored in the
// request context under the string keys "user_id" and "claims".
//
// An empty jwtSecret makes every token fail verification, rather than
// accepting signatures made with an empty key. Every decision is logged with
// the request ID taken from the context key "request_id".
func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Request ID placed in the context by Logger, if it ran first.
			requestID, _ := r.Context().Value("request_id").(string)

			// Get token from Authorization header
			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID)
				ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required")
				return
			}

			// Auth schemes are case-insensitive (RFC 7235 §2.1), so "bearer"
			// is accepted as well as "Bearer".
			parts := strings.Split(authHeader, " ")
			if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
				log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID)
				ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer <token>'")
				return
			}
			tokenString := parts[1]

			// Parse and validate token
			token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
				// Pin the algorithm family so a token cannot choose another
				// verification scheme (e.g. "none" or an asymmetric alg).
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
				}
				// Fail closed on misconfiguration: with an empty secret anyone
				// could compute a valid HMAC signature.
				if jwtSecret == "" {
					return nil, fmt.Errorf("jwt secret is not configured")
				}
				return []byte(jwtSecret), nil
			})
			if err != nil {
				log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err)
				ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
				return
			}
			if !token.Valid {
				log.Printf("auth_failed: request_id=%s error=invalid_token", requestID)
				ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
				return
			}

			// Validate claims
			claims, ok := token.Claims.(jwt.MapClaims)
			if !ok {
				log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID)
				ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims")
				return
			}

			// Check required claims
			userID, ok := claims["user_id"].(string)
			if !ok || userID == "" {
				log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID)
				ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token")
				return
			}

			// Explicit expiry check. jwt's own MapClaims validation likely
			// covers this already; kept as a backstop. A token without
			// "exp" is not rejected here.
			if exp, ok := claims["exp"].(float64); ok {
				if time.Now().Unix() > int64(exp) {
					log.Printf("auth_failed: request_id=%s error=token_expired", requestID)
					ErrorResponse(w, http.StatusUnauthorized, "Token has expired")
					return
				}
			}

			// Check required roles if specified: any one match suffices.
			if len(requiredRoles) > 0 {
				roles, ok := claims["roles"].([]interface{})
				if !ok {
					log.Printf("auth_failed: request_id=%s error=missing_roles", requestID)
					ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles")
					return
				}
				hasRequiredRole := false
			search:
				for _, requiredRole := range requiredRoles {
					for _, role := range roles {
						if name, ok := role.(string); ok && name == requiredRole {
							hasRequiredRole = true
							break search
						}
					}
				}
				if !hasRequiredRole {
					log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID)
					ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions")
					return
				}
			}

			// Add claims to context. The string keys are kept for
			// compatibility with existing readers (e.g. GetUserID).
			ctx := r.Context()
			ctx = context.WithValue(ctx, "user_id", userID)
			ctx = context.WithValue(ctx, "claims", claims)
			log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// responseWriter wraps an http.ResponseWriter and remembers the status code
// that was actually sent, so Logger can report it.
//
// Only the first final status is recorded. A redundant second WriteHeader
// call, which net/http ignores apart from logging a warning, no longer
// overwrites it. A Write without a prior WriteHeader fixes the status at its
// current value, and Logger initialises that value to 200, matching
// net/http's implicit status. Informational 1xx codes other than 101 are
// forwarded without being recorded, because a final status still follows.
type responseWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool
}

// WriteHeader records code if it is the first final status, then forwards it.
func (rw *responseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.status = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b. The first Write commits the status, so any WriteHeader
// after it must not change the recorded value.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		if rw.status == 0 {
			rw.status = http.StatusOK
		}
		rw.wroteHeader = true
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so http.ResponseController (Go 1.20+)
// can still reach optional capabilities such as Flush or Hijack.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// GetUserID returns the user ID stored in the request context under the
// "user_id" key, which is where Auth puts it. It returns an error when no
// string value is present under that key.
func GetUserID(r *http.Request) (string, error) {
	if id, found := r.Context().Value("user_id").(string); found {
		return id, nil
	}
	return "", fmt.Errorf("user ID not found in context")
}