package security

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/argon2"
)

type (
	// SecurityService groups the package's security helpers — CSRF tokens and
	// middleware, rate-limiting middleware, Argon2id password hashing and
	// response security headers — around a single Config.
	SecurityService struct {
		config *Config
	}

	// Config holds the tunable settings consumed by SecurityService and
	// RateLimiter. DefaultConfig returns a ready-to-use instance.
	Config struct {
		// CSRF protection: token size in random bytes, cookie lifetime,
		// cookie/header names and the cookie's security attributes.
		CSRFTokenLength    int
		CSRFTokenExpiry    time.Duration
		CSRFCookieName     string
		CSRFHeaderName     string
		CSRFCookieSecure   bool
		CSRFCookieHTTPOnly bool
		CSRFCookieSameSite http.SameSite

		// Rate limiting: at most RateLimitRequests requests per key within
		// any sliding RateLimitWindow.
		RateLimitRequests int
		RateLimitWindow   time.Duration

		// Password hashing: Argon2id parameters (passes, memory in KiB,
		// parallelism, derived-key length and salt length in bytes).
		Argon2Time    uint32
		Argon2Memory  uint32
		Argon2Threads uint8
		Argon2KeyLen  uint32
		Argon2SaltLen uint32
	}
)

// DefaultConfig returns a freshly allocated Config with the package defaults:
// 32-byte CSRF tokens valid for 24 hours in a Secure, HttpOnly, SameSite=Strict
// cookie named "csrf_token" and echoed in the "X-CSRF-Token" header; 100
// requests per minute per rate-limit key; and Argon2id with 1 pass, 64 MiB of
// memory, 4 threads, a 32-byte key and a 16-byte salt.
//
// NOTE(review): with CSRFCookieHTTPOnly set, client-side script cannot read the
// cookie to copy it into the header, so the token must reach the page some
// other way (e.g. rendered by the server) — confirm this matches the intended
// double-submit flow.
func DefaultConfig() *Config {
	cfg := new(Config)

	cfg.CSRFTokenLength = 32
	cfg.CSRFTokenExpiry = 24 * time.Hour
	cfg.CSRFCookieName = "csrf_token"
	cfg.CSRFHeaderName = "X-CSRF-Token"
	cfg.CSRFCookieSecure = true
	cfg.CSRFCookieHTTPOnly = true
	cfg.CSRFCookieSameSite = http.SameSiteStrictMode

	cfg.RateLimitRequests = 100
	cfg.RateLimitWindow = time.Minute

	cfg.Argon2Time = 1
	cfg.Argon2Memory = 64 * 1024 // KiB, i.e. 64 MiB
	cfg.Argon2Threads = 4
	cfg.Argon2KeyLen = 32
	cfg.Argon2SaltLen = 16

	return cfg
}

// NewSecurityService returns a SecurityService that uses config, or
// DefaultConfig() when config is nil. The pointer is stored as-is, so later
// changes to *config are visible to the service.
func NewSecurityService(config *Config) *SecurityService {
	svc := &SecurityService{config: config}
	if svc.config == nil {
		svc.config = DefaultConfig()
	}
	return svc
}

// GenerateCSRFToken returns a new CSRF token: CSRFTokenLength bytes read from
// crypto/rand, encoded with standard (padded) base64.
//
// It returns an error if the random source fails or if CSRFTokenLength is not
// positive. A zero length would otherwise produce an empty token, and a
// negative one would panic in make.
func (s *SecurityService) GenerateCSRFToken() (string, error) {
	n := s.config.CSRFTokenLength
	if n <= 0 {
		return "", fmt.Errorf("invalid CSRF token length %d: must be positive", n)
	}

	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return base64.StdEncoding.EncodeToString(buf), nil
}

// SetCSRFCookie writes token to the response as the CSRF cookie. The cookie is
// scoped to the whole site ("/"), expires CSRFTokenExpiry from now, and takes
// its Secure, HttpOnly and SameSite attributes from the configuration.
func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
	cfg := s.config

	cookie := http.Cookie{
		Name:  cfg.CSRFCookieName,
		Value: token,
		Path:  "/",
	}
	cookie.Expires = time.Now().Add(cfg.CSRFTokenExpiry)
	cookie.Secure = cfg.CSRFCookieSecure
	cookie.HttpOnly = cfg.CSRFCookieHTTPOnly
	cookie.SameSite = cfg.CSRFCookieSameSite

	http.SetCookie(w, &cookie)
}

// ValidateCSRFToken reports whether r carries a CSRF token in the configured
// header that matches the token in the configured cookie (double-submit
// check). A missing cookie, an empty cookie value or an empty header is
// always rejected. Otherwise two empty strings would compare equal and the
// check would pass without any token at all.
func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
	cookie, err := r.Cookie(s.config.CSRFCookieName)
	if err != nil || cookie.Value == "" {
		return false
	}

	headerToken := r.Header.Get(s.config.CSRFHeaderName)
	if headerToken == "" {
		return false
	}

	// Constant-time comparison so timing does not reveal how many leading
	// bytes of the token matched.
	return subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(headerToken)) == 1
}

// CSRFMiddleware wraps next so that every request whose method is not one of
// GET, HEAD, OPTIONS or TRACE must pass ValidateCSRFToken. Requests that fail
// the check get a 403 "Invalid CSRF token" response and never reach next.
func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			// Safe methods are exempt from the token check.
		default:
			if !s.ValidateCSRFToken(r) {
				http.Error(w, "Invalid CSRF token", http.StatusForbidden)
				return
			}
		}
		next.ServeHTTP(w, r)
	})
}

// RateLimiter is an in-memory sliding-window rate limiter keyed by an arbitrary
// string (RateLimitMiddleware keys it by client IP). It is safe for concurrent
// use by multiple goroutines.
type RateLimiter struct {
	mu        sync.Mutex             // guards requests and lastSweep
	requests  map[string][]time.Time // per-key request times, oldest first
	config    *Config
	lastSweep time.Time // when idle keys were last pruned
}

// NewRateLimiter creates a RateLimiter that allows config.RateLimitRequests
// requests per key within any config.RateLimitWindow. A nil config falls back
// to DefaultConfig(), matching NewSecurityService.
func NewRateLimiter(config *Config) *RateLimiter {
	if config == nil {
		config = DefaultConfig()
	}
	return &RateLimiter{
		requests: make(map[string][]time.Time),
		config:   config,
	}
}

// Allow reports whether a request for key is permitted right now. An allowed
// request is recorded and counts against the key's limit for the next
// RateLimitWindow; a denied request is not recorded.
func (r *RateLimiter) Allow(key string) bool {
	now := time.Now()
	windowStart := now.Add(-r.config.RateLimitWindow)

	// The map is shared by all concurrent HTTP handlers; unsynchronized access
	// is a data race and can crash the process with "concurrent map writes".
	r.mu.Lock()
	defer r.mu.Unlock()

	r.sweepLocked(now, windowStart)

	// Drop timestamps that fell out of the window, reusing the backing array.
	// Because this mutates the stored slice in place, the result is always
	// written back below.
	requests := r.requests[key]
	recent := requests[:0]
	for _, t := range requests {
		if t.After(windowStart) {
			recent = append(recent, t)
		}
	}

	if len(recent) >= r.config.RateLimitRequests {
		if len(recent) == 0 {
			delete(r.requests, key)
		} else {
			r.requests[key] = recent
		}
		return false
	}

	r.requests[key] = append(recent, now)
	return true
}

// sweepLocked deletes keys that have no request inside the current window so
// the map does not grow without bound as new clients appear. It scans all keys
// at most once per RateLimitWindow, keeping the cost amortized. r.mu must be
// held by the caller.
func (r *RateLimiter) sweepLocked(now, windowStart time.Time) {
	if now.Sub(r.lastSweep) < r.config.RateLimitWindow {
		return
	}
	r.lastSweep = now
	for k, ts := range r.requests {
		// Timestamps are appended in order under the lock, so the last one
		// is the newest.
		if len(ts) == 0 || !ts[len(ts)-1].After(windowStart) {
			delete(r.requests, k)
		}
	}
}

// RateLimitMiddleware returns middleware that consults limiter, keyed by
// getClientIP, before each request. Allowed requests are passed to the wrapped
// handler; the rest receive 429 "Too many requests".
func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			if limiter.Allow(getClientIP(r)) {
				next.ServeHTTP(w, r)
				return
			}
			http.Error(w, "Too many requests", http.StatusTooManyRequests)
		}
		return http.HandlerFunc(handler)
	}
	return wrap
}

// getClientIP returns the address used as the rate-limiting key for r.
//
// Precedence: the first non-empty entry of X-Forwarded-For, then X-Real-IP,
// then the host part of r.RemoteAddr. The port is stripped from RemoteAddr
// because it differs per connection. Keeping it would give each new connection
// from the same client its own rate-limit bucket.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are supplied by the client
// unless a trusted reverse proxy overwrites them. Without such a proxy, a
// client can send arbitrary values and evade per-IP rate limiting. Consider
// honoring these headers only when RemoteAddr belongs to a known proxy.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The header may list several hops ("client, proxy1, proxy2"); the
		// first entry is the originating client as reported by the chain.
		first := xff
		if i := strings.IndexByte(xff, ','); i >= 0 {
			first = xff[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if xRealIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); xRealIP != "" {
		return xRealIP
	}

	// RemoteAddr is normally "host:port" (with brackets for IPv6); keep only
	// the host, falling back to the raw value if it has no port.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// HashPassword derives an Argon2id key from password, using a fresh random salt
// and the Argon2 parameters in s.config. It returns the result in the
// self-describing form
//
//	$argon2id$v=19$m=<memory>,t=<passes>,p=<threads>$<salt>$<hash>
//
// where salt and hash are standard (padded) base64. VerifyPassword parses this
// format. An error is returned only if the random salt cannot be read.
func (s *SecurityService) HashPassword(password string) (string, error) {
	cfg := s.config

	salt := make([]byte, cfg.Argon2SaltLen)
	_, err := rand.Read(salt)
	if err != nil {
		return "", fmt.Errorf("failed to generate salt: %w", err)
	}

	key := argon2.IDKey([]byte(password), salt, cfg.Argon2Time, cfg.Argon2Memory, cfg.Argon2Threads, cfg.Argon2KeyLen)

	enc := base64.StdEncoding
	encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
		cfg.Argon2Memory, cfg.Argon2Time, cfg.Argon2Threads,
		enc.EncodeToString(salt), enc.EncodeToString(key))
	return encoded, nil
}

// VerifyPassword reports whether password matches encodedHash, which must be a
// PHC-style Argon2id string such as the one HashPassword produces:
//
//	$argon2id$v=19$m=<memory KiB>,t=<passes>,p=<threads>$<salt>$<hash>
//
// The Argon2 parameters are read from encodedHash rather than from s.config,
// so hashes created under older settings keep verifying. Salt and hash are
// accepted as standard base64 with or without "=" padding. HashPassword pads,
// but many other Argon2 implementations do not.
//
// A wrong password yields (false, nil). A non-nil error is returned, rather
// than a panic, when encodedHash is malformed, names another algorithm or
// Argon2 version, or carries parameters that argon2.IDKey cannot accept.
//
// NOTE(review): m is used as given, so a hash string from an untrusted source
// could request a very large allocation. Stored hashes are presumably trusted.
// Confirm before exposing this to external input.
func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
	parts := strings.Split(encodedHash, "$")
	if len(parts) != 6 || parts[0] != "" {
		return false, fmt.Errorf("invalid hash format")
	}
	if parts[1] != "argon2id" {
		return false, fmt.Errorf("unsupported hash algorithm")
	}

	// argon2.IDKey only implements argon2.Version, so a hash labelled with any
	// other version could never match; report it instead of returning false.
	var version int
	if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
		return false, fmt.Errorf("failed to parse hash version: %w", err)
	}
	if version != argon2.Version {
		return false, fmt.Errorf("unsupported argon2 version %d", version)
	}

	// "passes" rather than "time" so the time package is not shadowed.
	var memory, passes uint32
	var threads uint8
	if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &passes, &threads); err != nil {
		return false, fmt.Errorf("failed to parse hash parameters: %w", err)
	}
	// argon2.IDKey panics when the pass count or the parallelism is zero.
	if passes == 0 || threads == 0 {
		return false, fmt.Errorf("invalid hash parameters: t=%d, p=%d", passes, threads)
	}

	// Accept both padded and unpadded base64.
	decode := func(field string) ([]byte, error) {
		return base64.RawStdEncoding.DecodeString(strings.TrimRight(field, "="))
	}
	salt, err := decode(parts[4])
	if err != nil {
		return false, fmt.Errorf("failed to decode salt: %w", err)
	}
	hash, err := decode(parts[5])
	if err != nil {
		return false, fmt.Errorf("failed to decode hash: %w", err)
	}
	// A zero-length digest is never a real Argon2 output, and
	// ConstantTimeCompare reports two empty slices as equal. Reject it up front.
	if len(hash) == 0 {
		return false, fmt.Errorf("invalid hash: empty digest")
	}

	// Re-derive with the stored parameters and the stored digest length.
	computed := argon2.IDKey([]byte(password), salt, passes, memory, threads, uint32(len(hash)))

	return subtle.ConstantTimeCompare(hash, computed) == 1, nil
}

// SecureHeadersMiddleware sets a fixed set of security-related response headers
// before calling next. The headers are: nosniff, frame denial, a
// strict-origin-when-cross-origin referrer policy, a same-origin-only CSP, a
// one-year HSTS policy covering subdomains, and an explicitly disabled legacy
// XSS auditor.
func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		// The XSS auditor is gone from modern browsers. Where it still exists,
		// "1; mode=block" can be abused for cross-site information leaks, so
		// OWASP recommends disabling it ("0") and relying on CSP instead.
		h.Set("X-XSS-Protection", "0")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Content-Security-Policy", "default-src 'self'")
		// Browsers ignore HSTS received over plain HTTP (RFC 6797 §8.1), so
		// sending it unconditionally is harmless.
		h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")

		next.ServeHTTP(w, r)
	})
}

// contextKey is an unexported key type, so context values stored by this
// package cannot collide with keys defined in other packages.
type contextKey string

// userIDKey is the context key under which WithUserID stores the user ID.
const userIDKey contextKey = "user_id"

// WithUserID returns a copy of ctx that carries userID.
func WithUserID(ctx context.Context, userID string) context.Context {
	derived := context.WithValue(ctx, userIDKey, userID)
	return derived
}

// GetUserID returns the user ID stored by WithUserID. The boolean is false,
// and the ID empty, when ctx has no string value under the user ID key.
func GetUserID(ctx context.Context) (string, bool) {
	raw := ctx.Value(userIDKey)
	if id, isString := raw.(string); isString {
		return id, true
	}
	return "", false
}