package handlers

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"time"

	"otpm/api"
	"otpm/services"

	"github.com/golang-jwt/jwt"
)

// AuthHandler handles authentication related requests.
//
// Its Login and VerifyToken methods delegate the actual credential work to
// the wrapped authService (LoginWithWeChatCode and ValidateToken
// respectively); the handler itself only parses and validates the HTTP
// request and writes the response.
type AuthHandler struct {
	authService *services.AuthService
}

// NewAuthHandler builds an AuthHandler whose requests are served using the
// given auth service. The service pointer is stored as-is.
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
	handler := new(AuthHandler)
	handler.authService = authService
	return handler
}

// LoginRequest represents a login request.
//
// It is decoded from the JSON request body by Login. Code is the value that
// Login forwards to AuthService.LoginWithWeChatCode; Login rejects the
// request when Code is empty.
type LoginRequest struct {
	Code string `json:"code"`
}

// LoginResponse represents a login response.
//
// Token carries the token returned by the auth service for a successful
// login. OpenID is serialized under the "openid" key.
//
// NOTE(review): Login in this file only sets Token, so OpenID is always sent
// as an empty string from that handler — confirm whether it should be
// populated or whether other callers fill it in.
type LoginResponse struct {
	Token  string `json:"token"`
	OpenID string `json:"openid"`
}

// Login handles WeChat login.
//
// The request body is capped at 1KB and must be a JSON object matching
// LoginRequest with a non-empty "code". The code is passed to the auth
// service's LoginWithWeChatCode together with the request context; on
// success the returned token is written back wrapped in a LoginResponse.
//
// Failure responses:
//   - an undecodable body or an empty code is answered with
//     api.CodeInvalidParams;
//   - an error from the auth service is answered with api.InternalError.
//
// The login code is a client-supplied credential, so it is never logged
// verbatim: it is masked via maskToken and printed with %q so that control
// characters in the remaining bytes cannot forge log lines.
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
	start := time.Now()

	// Limit request body size to prevent DOS
	r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request

	// Parse request
	var req LoginRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
			fmt.Sprintf("Invalid request body: %v", err))
		log.Printf("Login request parse error: %v", err)
		return
	}

	// Validate request
	if req.Code == "" {
		api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
			"Code is required")
		log.Printf("Login request validation failed: empty code")
		return
	}

	// Login with WeChat code
	token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
	if err != nil {
		api.NewResponseWriter(w).WriteError(api.InternalError(err))
		log.Printf("Login failed for code %q: %v", maskToken(req.Code), err)
		return
	}

	// Log successful login (masked code only)
	log.Printf("Login successful for code %q (took %v)",
		maskToken(req.Code), time.Since(start))

	// Return token
	api.NewResponseWriter(w).WriteSuccess(LoginResponse{
		Token: token,
	})
}

// VerifyToken handles token verification.
//
// The token is read from the Authorization header, which must have the form
// "Bearer <token>" with a token of at least 32 bytes. The token is then
// checked with the auth service's ValidateToken.
//
// Failure responses:
//   - a missing header, a header without the "Bearer " prefix, or a token
//     shorter than 32 bytes is answered with api.CodeInvalidParams;
//   - a token rejected by the auth service is answered with
//     api.ErrUnauthorized.
//
// On success the response body is {"valid": true}. The user_id claim is
// used for logging only; a missing or malformed claim does not change the
// response.
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
	const (
		bearerPrefix   = "Bearer "
		minTokenLength = 32
	)

	start := time.Now()

	// Get token from Authorization header
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
			"Authorization header is required")
		log.Printf("Token verification failed: missing Authorization header")
		return
	}

	// Validate token format
	if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
		api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
			"Invalid token format. Expected 'Bearer <token>'")
		log.Printf("Token verification failed: invalid token format")
		return
	}

	token := authHeader[len(bearerPrefix):]
	if len(token) < minTokenLength { // Basic length check
		api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
			"Invalid token length")
		log.Printf("Token verification failed: token too short")
		return
	}

	// Validate token
	parsed, err := h.authService.ValidateToken(token)
	if err != nil {
		api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
		log.Printf("Token verification failed for token %s: %v",
			maskToken(token), err) // Mask token in logs
		return
	}

	// Log successful verification. Both type assertions use the comma-ok
	// form: an unchecked assertion to jwt.MapClaims would panic the handler
	// if the service ever produced a different claims type.
	var userID string
	mapClaims, ok := parsed.Claims.(jwt.MapClaims)
	if ok {
		userID, ok = mapClaims["user_id"].(string)
	}
	if !ok {
		log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
	} else {
		log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
	}

	// Token is valid
	api.NewResponseWriter(w).WriteSuccess(map[string]bool{
		"valid": true,
	})
}

// maskToken masks sensitive parts of token for logging.
//
// Tokens of at least 16 bytes are rendered as their first four bytes,
// "****", and their last four bytes. Anything shorter is replaced entirely
// by "****": revealing eight bytes of a short token would expose most or
// all of it (an 8-byte token would otherwise be printed in full).
func maskToken(token string) string {
	const (
		visible   = 4  // bytes kept at each end of a long token
		minLength = 16 // below this, revealing 2*visible bytes leaks too much
		mask      = "****"
	)

	if len(token) < minLength {
		return mask
	}
	return token[:visible] + mask + token[len(token)-visible:]
}

// Routes returns the auth handler's route table: a map from URL path to the
// method that serves it ("/login" to Login, "/verify-token" to VerifyToken).
func (h *AuthHandler) Routes() map[string]http.HandlerFunc {
	routes := make(map[string]http.HandlerFunc, 2)
	routes["/login"] = h.Login
	routes["/verify-token"] = h.VerifyToken
	return routes
}