// Package security provides CSRF protection, per-client request rate
// limiting, Argon2id password hashing, and security-header middleware for
// net/http servers.
package security

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/argon2"
)

// SecurityService provides security functionality (CSRF, rate-limit
// middleware, password hashing, secure headers) driven by a Config.
type SecurityService struct {
	config *Config
}

// Config represents security configuration.
type Config struct {
	// CSRF protection.
	CSRFTokenLength    int           // number of random bytes per token (before base64)
	CSRFTokenExpiry    time.Duration // lifetime of the CSRF cookie
	CSRFCookieName     string
	CSRFHeaderName     string
	CSRFCookieSecure   bool
	CSRFCookieHTTPOnly bool
	CSRFCookieSameSite http.SameSite

	// Rate limiting: at most RateLimitRequests requests per key within any
	// sliding RateLimitWindow.
	RateLimitRequests int
	RateLimitWindow   time.Duration

	// Password hashing (Argon2id parameters; memory is in KiB).
	Argon2Time    uint32
	Argon2Memory  uint32
	Argon2Threads uint8
	Argon2KeyLen  uint32
	Argon2SaltLen uint32
}

// DefaultConfig returns the default security configuration.
func DefaultConfig() *Config {
	return &Config{
		// CSRF protection
		CSRFTokenLength:    32,
		CSRFTokenExpiry:    24 * time.Hour,
		CSRFCookieName:     "csrf_token",
		CSRFHeaderName:     "X-CSRF-Token",
		CSRFCookieSecure:   true,
		CSRFCookieHTTPOnly: true,
		CSRFCookieSameSite: http.SameSiteStrictMode,

		// Rate limiting
		RateLimitRequests: 100,
		RateLimitWindow:   time.Minute,

		// Password hashing
		Argon2Time:    1,
		Argon2Memory:  64 * 1024,
		Argon2Threads: 4,
		Argon2KeyLen:  32,
		Argon2SaltLen: 16,
	}
}

// NewSecurityService creates a new SecurityService. A nil config selects
// DefaultConfig().
func NewSecurityService(config *Config) *SecurityService {
	if config == nil {
		config = DefaultConfig()
	}
	return &SecurityService{
		config: config,
	}
}

// GenerateCSRFToken returns a new random CSRF token: CSRFTokenLength bytes
// from crypto/rand, encoded with standard base64.
func (s *SecurityService) GenerateCSRFToken() (string, error) {
	// A non-positive length would either panic in make or yield an empty
	// token that ValidateCSRFToken always rejects.
	if s.config.CSRFTokenLength <= 0 {
		return "", fmt.Errorf("invalid CSRF token length %d", s.config.CSRFTokenLength)
	}

	buf := make([]byte, s.config.CSRFTokenLength)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}

	return base64.StdEncoding.EncodeToString(buf), nil
}

// SetCSRFCookie sets the CSRF cookie on w using the configured name, expiry
// and cookie attributes.
func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
	http.SetCookie(w, &http.Cookie{
		Name:     s.config.CSRFCookieName,
		Value:    token,
		Path:     "/",
		Expires:  time.Now().Add(s.config.CSRFTokenExpiry),
		Secure:   s.config.CSRFCookieSecure,
		HttpOnly: s.config.CSRFCookieHTTPOnly,
		SameSite: s.config.CSRFCookieSameSite,
	})
}

// ValidateCSRFToken reports whether the request carries a non-empty CSRF
// token in its header that matches the CSRF cookie (double-submit check).
func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
	cookie, err := r.Cookie(s.config.CSRFCookieName)
	if err != nil || cookie.Value == "" {
		return false
	}

	headerToken := r.Header.Get(s.config.CSRFHeaderName)
	// Without this check an empty cookie plus a missing header would pass,
	// because ConstantTimeCompare of two empty slices returns 1.
	if headerToken == "" {
		return false
	}

	return subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(headerToken)) == 1
}

// CSRFMiddleware returns a handler that rejects state-changing requests
// (anything other than GET, HEAD, OPTIONS, TRACE) lacking a valid CSRF token
// with 403 Forbidden.
func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			// Safe methods are exempt from the token check.
		default:
			if !s.ValidateCSRFToken(r) {
				http.Error(w, "Invalid CSRF token", http.StatusForbidden)
				return
			}
		}
		next.ServeHTTP(w, r)
	})
}

// RateLimiter is a sliding-window rate limiter keyed by an arbitrary string.
// It is safe for concurrent use. Create it with NewRateLimiter; the zero
// value is not usable.
type RateLimiter struct {
	mu        sync.Mutex
	requests  map[string][]time.Time // per-key request timestamps, oldest first
	config    *Config
	lastSweep time.Time // when stale keys were last purged
}

// NewRateLimiter creates a new RateLimiter. A nil config selects
// DefaultConfig().
func NewRateLimiter(config *Config) *RateLimiter {
	if config == nil {
		config = DefaultConfig()
	}
	return &RateLimiter{
		requests: make(map[string][]time.Time),
		config:   config,
	}
}

// Allow reports whether a request for key is allowed and, if so, records it.
// A request is allowed when fewer than RateLimitRequests requests for the
// same key were recorded within the last RateLimitWindow.
func (rl *RateLimiter) Allow(key string) bool {
	// Handlers call Allow concurrently; the map must be guarded.
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Read the clock under the lock so each key's timestamps stay in order,
	// which the sweep relies on.
	now := time.Now()
	windowStart := now.Add(-rl.config.RateLimitWindow)

	rl.sweepLocked(now, windowStart)

	// Filter in place: the write index never passes the read index, and
	// reusing the backing array keeps its capacity bounded by the limit.
	history := rl.requests[key]
	recent := history[:0]
	for _, t := range history {
		if t.After(windowStart) {
			recent = append(recent, t)
		}
	}

	if len(recent) >= rl.config.RateLimitRequests {
		rl.requests[key] = recent
		return false
	}

	rl.requests[key] = append(recent, now)
	return true
}

// sweepLocked deletes keys whose newest request is outside the window, so the
// map cannot grow without bound as new keys appear. It runs at most once per
// window to amortize its O(keys) cost. rl.mu must be held.
func (rl *RateLimiter) sweepLocked(now, windowStart time.Time) {
	if now.Sub(rl.lastSweep) < rl.config.RateLimitWindow {
		return
	}
	rl.lastSweep = now
	for key, history := range rl.requests {
		if len(history) == 0 || !history[len(history)-1].After(windowStart) {
			delete(rl.requests, key)
		}
	}
}

// RateLimitMiddleware creates a middleware that answers 429 Too Many Requests
// when limiter denies the request's client key (see getClientIP).
func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !limiter.Allow(getClientIP(r)) {
				http.Error(w, "Too many requests", http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// getClientIP returns the key identifying the client for rate limiting.
// Sources are tried in order: the first X-Forwarded-For entry, then
// X-Real-IP, then the host part of RemoteAddr.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted reverse proxy overwrites them. A direct client can send arbitrary
// values to pick its own rate-limit bucket. Only honor these headers behind
// such a proxy.
func getClientIP(r *http.Request) string {
	if xForwardedFor := r.Header.Get("X-Forwarded-For"); xForwardedFor != "" {
		// X-Forwarded-For may list several hops; the first is the origin.
		if first := strings.TrimSpace(strings.SplitN(xForwardedFor, ",", 2)[0]); first != "" {
			return first
		}
	}

	if xRealIP := r.Header.Get("X-Real-IP"); xRealIP != "" {
		return xRealIP
	}

	// RemoteAddr is "host:port". The port is ephemeral and differs per
	// connection, so keying on it would let a client evade the limiter
	// simply by reconnecting.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// HashPassword hashes password with Argon2id using a fresh random salt.
// The result has the form
// $argon2id$v=19$m=<memory KiB>,t=<time>,p=<threads>$<base64 salt>$<base64 hash>.
func (s *SecurityService) HashPassword(password string) (string, error) {
	// argon2.IDKey panics on t<1 or p<1, and a zero key length yields an
	// unusable empty hash. Report these as errors instead.
	if s.config.Argon2Time < 1 || s.config.Argon2Threads < 1 || s.config.Argon2KeyLen < 1 {
		return "", fmt.Errorf("invalid argon2 parameters: t=%d, p=%d, keyLen=%d",
			s.config.Argon2Time, s.config.Argon2Threads, s.config.Argon2KeyLen)
	}

	salt := make([]byte, s.config.Argon2SaltLen)
	if _, err := rand.Read(salt); err != nil {
		return "", fmt.Errorf("failed to generate salt: %w", err)
	}

	hash := argon2.IDKey(
		[]byte(password),
		salt,
		s.config.Argon2Time,
		s.config.Argon2Memory,
		s.config.Argon2Threads,
		s.config.Argon2KeyLen,
	)

	return fmt.Sprintf(
		"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
		argon2.Version,
		s.config.Argon2Memory,
		s.config.Argon2Time,
		s.config.Argon2Threads,
		base64.StdEncoding.EncodeToString(salt),
		base64.StdEncoding.EncodeToString(hash),
	), nil
}

// VerifyPassword reports whether password matches encodedHash, a string in
// the format produced by HashPassword. The parameters embedded in the hash
// are used, so hashes created under an older Config still verify. A
// malformed hash returns an error.
func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
	// Expected fields: "", "argon2id", "v=19", "m=..,t=..,p=..", salt, hash.
	parts := strings.Split(encodedHash, "$")
	if len(parts) != 6 || parts[0] != "" {
		return false, fmt.Errorf("invalid hash format")
	}

	if parts[1] != "argon2id" {
		return false, fmt.Errorf("unsupported hash algorithm")
	}

	var version int
	if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
		return false, fmt.Errorf("failed to parse hash version: %w", err)
	}
	if version != argon2.Version {
		return false, fmt.Errorf("unsupported argon2 version %d", version)
	}

	var (
		memory  uint32
		passes  uint32
		threads uint8
	)
	if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &passes, &threads); err != nil {
		return false, fmt.Errorf("failed to parse hash parameters: %w", err)
	}
	// argon2.IDKey panics on these; a corrupt stored hash must not crash us.
	if passes < 1 || threads < 1 {
		return false, fmt.Errorf("invalid hash parameters: t=%d, p=%d", passes, threads)
	}

	salt, err := base64.StdEncoding.DecodeString(parts[4])
	if err != nil {
		return false, fmt.Errorf("failed to decode salt: %w", err)
	}

	hash, err := base64.StdEncoding.DecodeString(parts[5])
	if err != nil {
		return false, fmt.Errorf("failed to decode hash: %w", err)
	}
	// An empty stored hash would compare equal to an empty derived key and
	// accept every password.
	if len(hash) == 0 {
		return false, fmt.Errorf("invalid hash format: empty hash")
	}

	derived := argon2.IDKey([]byte(password), salt, passes, memory, threads, uint32(len(hash)))

	return subtle.ConstantTimeCompare(hash, derived) == 1, nil
}

// SecureHeadersMiddleware adds a fixed set of security headers to every
// response before calling next.
func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-XSS-Protection", "1; mode=block")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Content-Security-Policy", "default-src 'self'")
		h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")

		next.ServeHTTP(w, r)
	})
}

// contextKey is an unexported type for context keys, preventing collisions
// with keys defined in other packages.
type contextKey string

// userIDKey is the key for the user ID in a context.
const userIDKey = contextKey("user_id")

// WithUserID returns a copy of ctx carrying userID.
func WithUserID(ctx context.Context, userID string) context.Context {
	return context.WithValue(ctx, userIDKey, userID)
}

// GetUserID returns the user ID stored by WithUserID, and whether one was
// present.
func GetUserID(ctx context.Context) (string, bool) {
	userID, ok := ctx.Value(userIDKey).(string)
	return userID, ok
}