beta
This commit is contained in:
parent
a45ddf13d5
commit
bcd986e3f7
46 changed files with 6166 additions and 454 deletions
353
middleware/middleware.go
Normal file
353
middleware/middleware.go
Normal file
|
@ -0,0 +1,353 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// Response represents a standard API response
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponse sends a JSON error response
|
||||
func ErrorResponse(w http.ResponseWriter, code int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// SuccessResponse sends a JSON success response
|
||||
func SuccessResponse(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Code: http.StatusOK,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Logger is a middleware that logs request details with structured format
|
||||
func Logger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
r.Header.Set("X-Request-ID", requestID)
|
||||
}
|
||||
|
||||
// Create a custom response writer to capture status code
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
// Process request
|
||||
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID)))
|
||||
|
||||
// Log structured request details
|
||||
log.Printf(
|
||||
"method=%s path=%s status=%d duration=%s ip=%s request_id=%s",
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
rw.status,
|
||||
time.Since(start).String(),
|
||||
r.RemoteAddr,
|
||||
requestID,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// generateRequestID creates a unique request identifier
|
||||
func generateRequestID() string {
|
||||
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
|
||||
}
|
||||
|
||||
// randomString generates a random string of given length
|
||||
func randomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[rand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// Recover is a middleware that recovers from panics with detailed logging
|
||||
func Recover(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
// Get request ID from context
|
||||
requestID := ""
|
||||
if ctx := r.Context(); ctx != nil {
|
||||
if id, ok := ctx.Value("request_id").(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Log error with stack trace and request context
|
||||
log.Printf(
|
||||
"panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s",
|
||||
err,
|
||||
requestID,
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
r.RemoteAddr,
|
||||
debug.Stack(),
|
||||
)
|
||||
|
||||
// Determine error type
|
||||
var message string
|
||||
var status int
|
||||
|
||||
switch e := err.(type) {
|
||||
case error:
|
||||
message = e.Error()
|
||||
if isClientError(e) {
|
||||
status = http.StatusBadRequest
|
||||
} else {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
case string:
|
||||
message = e
|
||||
status = http.StatusInternalServerError
|
||||
default:
|
||||
message = "Internal Server Error"
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
ErrorResponse(w, status, message)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// isClientError checks if error should be treated as client error
|
||||
func isClientError(err error) bool {
|
||||
// Add more client error types as needed
|
||||
return strings.Contains(err.Error(), "validation") ||
|
||||
strings.Contains(err.Error(), "invalid") ||
|
||||
strings.Contains(err.Error(), "missing")
|
||||
}
|
||||
|
||||
// CORS is a middleware that handles CORS
|
||||
func CORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Timeout is a middleware that safely handles request timeouts
|
||||
func Timeout(duration time.Duration) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), duration)
|
||||
defer cancel()
|
||||
|
||||
// Use buffered channels to prevent goroutine leaks
|
||||
done := make(chan struct{}, 1)
|
||||
panicChan := make(chan interface{}, 1)
|
||||
|
||||
// Track request processing in goroutine
|
||||
go func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicChan <- p
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
// Wait for completion, timeout or panic
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case p := <-panicChan:
|
||||
panic(p) // Re-throw panic to be caught by Recover middleware
|
||||
case <-ctx.Done():
|
||||
// Get request context for logging
|
||||
requestID := ""
|
||||
if ctx := r.Context(); ctx != nil {
|
||||
if id, ok := ctx.Value("request_id").(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Log timeout details
|
||||
log.Printf(
|
||||
"request_timeout: request_id=%s method=%s path=%s timeout=%s",
|
||||
requestID,
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
duration.String(),
|
||||
)
|
||||
|
||||
// Send timeout response
|
||||
ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf(
|
||||
"Request timed out after %s", duration.String(),
|
||||
))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Auth is a middleware that validates JWT tokens with enhanced security
|
||||
func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get request ID for logging
|
||||
requestID := ""
|
||||
if ctx := r.Context(); ctx != nil {
|
||||
if id, ok := ctx.Value("request_id").(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Get token from Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate header format
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer <token>'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// Parse and validate token
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
log.Printf("auth_failed: request_id=%s error=invalid_token", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate claims
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims")
|
||||
return
|
||||
}
|
||||
|
||||
// Check required claims
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
// Check token expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
log.Printf("auth_failed: request_id=%s error=token_expired", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Token has expired")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check required roles if specified
|
||||
if len(requiredRoles) > 0 {
|
||||
roles, ok := claims["roles"].([]interface{})
|
||||
if !ok {
|
||||
log.Printf("auth_failed: request_id=%s error=missing_roles", requestID)
|
||||
ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles")
|
||||
return
|
||||
}
|
||||
|
||||
hasRequiredRole := false
|
||||
for _, requiredRole := range requiredRoles {
|
||||
for _, role := range roles {
|
||||
if r, ok := role.(string); ok && r == requiredRole {
|
||||
hasRequiredRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRequiredRole {
|
||||
log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID)
|
||||
ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Add claims to context
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, "user_id", userID)
|
||||
ctx = context.WithValue(ctx, "claims", claims)
|
||||
|
||||
log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// responseWriter is a custom response writer that captures the status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.status = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// GetUserID gets the user ID from the request context
|
||||
func GetUserID(r *http.Request) (string, error) {
|
||||
userID, ok := r.Context().Value("user_id").(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("user ID not found in context")
|
||||
}
|
||||
return userID, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue