otpm/security/security.go
“xHuPo” bcd986e3f7 beta
2025-05-23 18:57:11 +08:00

332 lines
8.2 KiB
Go

package security
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/argon2"
)
// SecurityService bundles the package's HTTP security helpers: CSRF token
// generation/validation and middleware, rate-limit middleware, Argon2id
// password hashing, and security-header middleware. The CSRF and
// password-hashing methods read their settings from config; construct it with
// NewSecurityService.
type SecurityService struct {
	// config holds the settings used by the methods; never nil when built via
	// NewSecurityService.
	config *Config
}
// Config holds the tunable settings used by SecurityService and RateLimiter.
// DefaultConfig returns a fully populated instance.
type Config struct {
	// CSRF protection

	// CSRFTokenLength is the number of random bytes in a token produced by
	// GenerateCSRFToken (before base64 encoding).
	CSRFTokenLength int
	// CSRFTokenExpiry is added to the current time to set the CSRF cookie's
	// Expires attribute in SetCSRFCookie.
	CSRFTokenExpiry time.Duration
	// CSRFCookieName is the name of the cookie that carries the CSRF token.
	CSRFCookieName string
	// CSRFHeaderName is the request header whose value must equal the CSRF
	// cookie's value for ValidateCSRFToken to succeed.
	CSRFHeaderName string
	// CSRFCookieSecure, CSRFCookieHTTPOnly and CSRFCookieSameSite are copied
	// onto the CSRF cookie's Secure, HttpOnly and SameSite attributes.
	CSRFCookieSecure   bool
	CSRFCookieHTTPOnly bool
	CSRFCookieSameSite http.SameSite

	// Rate limiting. These are read by RateLimiter, i.e. from the Config passed
	// to NewRateLimiter, not from the one held by SecurityService.

	// RateLimitRequests is the maximum number of requests allowed per key
	// within any RateLimitWindow.
	RateLimitRequests int
	// RateLimitWindow is the length of the sliding window.
	RateLimitWindow time.Duration

	// Password hashing (Argon2id; see HashPassword)

	// Argon2Time is argon2.IDKey's time (iterations) argument.
	Argon2Time uint32
	// Argon2Memory is argon2.IDKey's memory argument, which the
	// golang.org/x/crypto/argon2 documentation specifies in KiB.
	Argon2Memory uint32
	// Argon2Threads is argon2.IDKey's threads (parallelism) argument.
	Argon2Threads uint8
	// Argon2KeyLen is the length, in bytes, of the derived hash.
	Argon2KeyLen uint32
	// Argon2SaltLen is the length, in bytes, of the random salt.
	Argon2SaltLen uint32
}
// DefaultConfig returns a newly allocated Config holding the package defaults:
//
//   - CSRF: 32 random bytes per token, valid for 24 hours, stored in a Secure,
//     HttpOnly, SameSite=Strict cookie named "csrf_token" and echoed in the
//     "X-CSRF-Token" request header.
//   - Rate limiting: 100 requests per key per minute.
//   - Argon2id: t=1, m=64*1024, p=4, 32-byte key, 16-byte salt.
//
// Each call returns a distinct value, so callers may modify the result freely.
func DefaultConfig() *Config {
	cfg := new(Config)

	cfg.CSRFTokenLength = 32
	cfg.CSRFTokenExpiry = 24 * time.Hour
	cfg.CSRFCookieName = "csrf_token"
	cfg.CSRFHeaderName = "X-CSRF-Token"
	cfg.CSRFCookieSecure = true
	cfg.CSRFCookieHTTPOnly = true
	cfg.CSRFCookieSameSite = http.SameSiteStrictMode

	cfg.RateLimitRequests = 100
	cfg.RateLimitWindow = time.Minute

	cfg.Argon2Time = 1
	cfg.Argon2Memory = 64 * 1024
	cfg.Argon2Threads = 4
	cfg.Argon2KeyLen = 32
	cfg.Argon2SaltLen = 16

	return cfg
}
// NewSecurityService returns a SecurityService that uses config. A nil config
// is replaced by DefaultConfig().
func NewSecurityService(config *Config) *SecurityService {
	svc := &SecurityService{config: config}
	if svc.config == nil {
		svc.config = DefaultConfig()
	}
	return svc
}
// GenerateCSRFToken returns a new CSRF token made of config.CSRFTokenLength
// bytes from crypto/rand, encoded with standard (padded) base64. The only
// error source is the random reader.
func (s *SecurityService) GenerateCSRFToken() (string, error) {
	raw := make([]byte, s.config.CSRFTokenLength)
	_, err := rand.Read(raw)
	if err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return base64.StdEncoding.EncodeToString(raw), nil
}
// SetCSRFCookie writes a Set-Cookie header carrying token. The cookie is
// scoped to path "/", expires config.CSRFTokenExpiry from now, and takes its
// name and Secure/HttpOnly/SameSite attributes from config.
func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
	cfg := s.config
	cookie := &http.Cookie{
		Name:     cfg.CSRFCookieName,
		Value:    token,
		Path:     "/",
		Expires:  time.Now().Add(cfg.CSRFTokenExpiry),
		Secure:   cfg.CSRFCookieSecure,
		HttpOnly: cfg.CSRFCookieHTTPOnly,
		SameSite: cfg.CSRFCookieSameSite,
	}
	http.SetCookie(w, cookie)
}
// ValidateCSRFToken performs a double-submit check. It reports whether r
// carries both the CSRF cookie (config.CSRFCookieName) and the CSRF header
// (config.CSRFHeaderName) with the same non-empty value. The values are
// compared in constant time.
//
// Empty values are rejected explicitly. Otherwise a request with an empty
// cookie value and no header would compare two empty byte slices, which
// subtle.ConstantTimeCompare reports as equal.
//
// NOTE(review): with CSRFCookieHTTPOnly=true (the default), browser scripts
// cannot read the cookie to copy it into the header. The token has to reach
// the client some other way (e.g. rendered into the page) — confirm with
// callers.
func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
	cookie, err := r.Cookie(s.config.CSRFCookieName)
	if err != nil || cookie.Value == "" {
		return false
	}

	headerToken := r.Header.Get(s.config.CSRFHeaderName)
	if headerToken == "" {
		return false
	}

	return subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(headerToken)) == 1
}
// CSRFMiddleware wraps next with CSRF enforcement. GET, HEAD, OPTIONS and
// TRACE requests pass straight through. Any other method must satisfy
// ValidateCSRFToken, or the request is answered with 403 "Invalid CSRF token"
// and next is not called.
func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			// Exempt from the token check.
		default:
			if !s.ValidateCSRFToken(r) {
				http.Error(w, "Invalid CSRF token", http.StatusForbidden)
				return
			}
		}
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(handler)
}
// RateLimiter is a sliding-window rate limiter keyed by an arbitrary string
// (RateLimitMiddleware uses the client IP). It allows at most
// config.RateLimitRequests requests per key within any config.RateLimitWindow.
// It is safe for concurrent use; net/http invokes middleware from many
// goroutines at once.
//
// NOTE(review): a key's history is only pruned when that key is seen again, so
// keys that stop sending requests keep their map entry indefinitely. Consider
// a periodic sweep if the key space is large or attacker-influenced.
type RateLimiter struct {
	mu       sync.Mutex // guards requests
	requests map[string][]time.Time
	config   *Config
}

// NewRateLimiter creates a RateLimiter driven by config's RateLimitRequests and
// RateLimitWindow. A nil config falls back to DefaultConfig(), mirroring
// NewSecurityService.
func NewRateLimiter(config *Config) *RateLimiter {
	if config == nil {
		config = DefaultConfig()
	}
	return &RateLimiter{
		requests: make(map[string][]time.Time),
		config:   config,
	}
}

// Allow reports whether another request for key fits within the limit, and
// records it if so. Denied requests are not recorded, so they do not extend
// the time a key stays blocked.
func (r *RateLimiter) Allow(key string) bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Read the clock under the lock so each key's timestamps are appended in
	// chronological order.
	now := time.Now()
	windowStart := now.Add(-r.config.RateLimitWindow)

	// Drop timestamps that have left the window. Filtering in place reuses the
	// existing backing array instead of allocating a new slice on every call.
	history := r.requests[key]
	recent := history[:0]
	for _, t := range history {
		if t.After(windowStart) {
			recent = append(recent, t)
		}
	}

	if len(recent) >= r.config.RateLimitRequests {
		// Persist the pruned history so expired timestamps don't linger.
		if len(recent) == 0 {
			delete(r.requests, key)
		} else {
			r.requests[key] = recent
		}
		return false
	}

	r.requests[key] = append(recent, now)
	return true
}
// RateLimitMiddleware returns middleware that consults limiter once per
// request, keyed by getClientIP(r). Allowed requests are passed to the wrapped
// handler. Rejected ones get 429 "Too many requests". The limiter's own
// Config governs the limits; s.config is not consulted.
func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if limiter.Allow(getClientIP(r)) {
				next.ServeHTTP(w, r)
				return
			}
			http.Error(w, "Too many requests", http.StatusTooManyRequests)
		})
	}
	return wrap
}
// getClientIP returns the string used to identify the client for rate
// limiting. It checks sources in this order:
//
//   - the first (left-most), non-empty entry of X-Forwarded-For,
//   - X-Real-IP (whitespace-trimmed),
//   - the host part of r.RemoteAddr.
//
// net/http sets RemoteAddr to "host:port", and the port is stripped here.
// Keying on the full value would give every new TCP connection from the same
// client a fresh ephemeral port, and therefore a fresh rate-limit bucket. If
// RemoteAddr has no port, it is returned unchanged.
//
// SECURITY NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled
// unless a trusted reverse proxy overwrites them. If the service is reachable
// directly, a client can send a different value on every request and bypass
// the rate limiter. Only honor these headers behind a proxy you trust.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The header may list several hops ("client, proxy1, proxy2").
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if xRealIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); xRealIP != "" {
		return xRealIP
	}

	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// HashPassword derives an Argon2id hash of password using the parameters in
// s.config and a fresh salt of config.Argon2SaltLen random bytes. It returns
// the result in the encoded form
//
//	$argon2id$v=19$m=<memory>,t=<time>,p=<threads>$<base64 salt>$<base64 hash>
//
// using standard (padded) base64, which is the form VerifyPassword parses.
//
// It returns an error if the random source fails, or if config.Argon2Time or
// config.Argon2Threads is zero. argon2.IDKey panics on either, so they are
// rejected up front rather than crashing the caller.
func (s *SecurityService) HashPassword(password string) (string, error) {
	cfg := s.config
	if cfg.Argon2Time == 0 || cfg.Argon2Threads == 0 {
		return "", fmt.Errorf("invalid argon2 parameters: time=%d, threads=%d", cfg.Argon2Time, cfg.Argon2Threads)
	}

	salt := make([]byte, cfg.Argon2SaltLen)
	if _, err := rand.Read(salt); err != nil {
		return "", fmt.Errorf("failed to generate salt: %w", err)
	}

	hash := argon2.IDKey(
		[]byte(password),
		salt,
		cfg.Argon2Time,
		cfg.Argon2Memory,
		cfg.Argon2Threads,
		cfg.Argon2KeyLen,
	)

	// v=19 (0x13) is the Argon2 version implemented by argon2.IDKey.
	return fmt.Sprintf(
		"$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
		cfg.Argon2Memory,
		cfg.Argon2Time,
		cfg.Argon2Threads,
		base64.StdEncoding.EncodeToString(salt),
		base64.StdEncoding.EncodeToString(hash),
	), nil
}
// VerifyPassword reports whether password matches encodedHash, a string in the
// form produced by HashPassword:
//
//	$argon2id$v=19$m=<memory>,t=<time>,p=<threads>$<base64 salt>$<base64 hash>
//
// The Argon2 parameters come from encodedHash itself rather than s.config, so
// hashes created under an older configuration still verify. The recomputed
// hash is compared in constant time.
//
// It returns an error if encodedHash is malformed or names an unsupported
// algorithm or version. Inputs that would make argon2.IDKey panic (t=0 or
// p=0) are reported as errors. So are an empty salt and an empty hash, since
// an empty stored hash would otherwise match every password.
//
// NOTE(review): the memory parameter is not bounded. If encodedHash can come
// from an untrusted source, a huge m= value makes verification allocate that
// much memory; consider capping it.
func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
	// "$argon2id$v=..$m=..$salt$hash" splits into a leading empty field plus
	// five segments.
	parts := strings.Split(encodedHash, "$")
	if len(parts) != 6 {
		return false, fmt.Errorf("invalid hash format")
	}

	if parts[1] != "argon2id" {
		return false, fmt.Errorf("unsupported hash algorithm")
	}

	var version int
	if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
		return false, fmt.Errorf("failed to parse hash version: %w", err)
	}
	if version != argon2.Version {
		return false, fmt.Errorf("unsupported argon2 version %d", version)
	}

	var (
		memory     uint32
		iterations uint32
		threads    uint8
	)
	if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, &threads); err != nil {
		return false, fmt.Errorf("failed to parse hash parameters: %w", err)
	}
	if iterations == 0 || threads == 0 {
		return false, fmt.Errorf("invalid hash parameters: t=%d, p=%d", iterations, threads)
	}

	salt, err := base64.StdEncoding.DecodeString(parts[4])
	if err != nil {
		return false, fmt.Errorf("failed to decode salt: %w", err)
	}
	hash, err := base64.StdEncoding.DecodeString(parts[5])
	if err != nil {
		return false, fmt.Errorf("failed to decode hash: %w", err)
	}
	if len(salt) == 0 || len(hash) == 0 {
		return false, fmt.Errorf("invalid hash format: empty salt or hash")
	}

	// Recompute with the stored parameters and the stored hash length.
	newHash := argon2.IDKey(
		[]byte(password),
		salt,
		iterations,
		memory,
		threads,
		uint32(len(hash)),
	)

	return subtle.ConstantTimeCompare(hash, newHash) == 1, nil
}
// SecureHeadersMiddleware wraps next so that every response carries a fixed set
// of security headers. The headers are set before next runs, so next can still
// override them: X-Content-Type-Options, X-Frame-Options, X-XSS-Protection,
// Referrer-Policy, Content-Security-Policy and Strict-Transport-Security.
//
// NOTE(review): X-XSS-Protection is a legacy header; consider whether it is
// still wanted.
func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
	headers := [...][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Content-Security-Policy", "default-src 'self'"},
		{"Strict-Transport-Security", "max-age=31536000; includeSubDomains"},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		for _, kv := range headers {
			h.Set(kv[0], kv[1])
		}
		next.ServeHTTP(w, r)
	})
}
// contextKey is an unexported named type for context.Context keys. Keys from
// other packages cannot be equal to it, even when they hold the same string.
type contextKey string

// userIDKey is the context key under which WithUserID stores the user ID.
const userIDKey contextKey = "user_id"
// WithUserID returns a context derived from ctx that carries userID under the
// package-private userIDKey. Read it back with GetUserID.
func WithUserID(ctx context.Context, userID string) context.Context {
	derived := context.WithValue(ctx, userIDKey, userID)
	return derived
}
// GetUserID returns the user ID stored in ctx by WithUserID. The boolean is
// false, and the ID empty, when no string value is stored under userIDKey.
func GetUserID(ctx context.Context) (string, bool) {
	if id, ok := ctx.Value(userIDKey).(string); ok {
		return id, true
	}
	return "", false
}