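// Package security provides CSRF protection, in-memory rate limiting,
// Argon2id password hashing, security response headers, and helpers for
// carrying a user ID in a context.Context.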
package security

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/argon2"
)

// SecurityService provides security functionality
|
|
type SecurityService struct {
|
|
config *Config
|
|
}
|
|
|
|
// Config holds the settings used by SecurityService and RateLimiter.
// DefaultConfig returns a populated instance.
type Config struct {
	// CSRF protection.

	// CSRFTokenLength is the number of random bytes in a CSRF token before
	// it is base64-encoded.
	CSRFTokenLength int
	// CSRFTokenExpiry is added to the current time to produce the CSRF
	// cookie's Expires attribute.
	CSRFTokenExpiry time.Duration
	// CSRFCookieName is the name of the cookie that carries the token.
	CSRFCookieName string
	// CSRFHeaderName is the request header whose value must match the cookie.
	CSRFHeaderName string
	// CSRFCookieSecure is copied to the cookie's Secure attribute.
	CSRFCookieSecure bool
	// CSRFCookieHTTPOnly is copied to the cookie's HttpOnly attribute.
	CSRFCookieHTTPOnly bool
	// CSRFCookieSameSite is copied to the cookie's SameSite attribute.
	CSRFCookieSameSite http.SameSite

	// Rate limiting.

	// RateLimitRequests is the maximum number of requests a single key may
	// make within RateLimitWindow.
	RateLimitRequests int
	// RateLimitWindow is the length of the sliding window.
	RateLimitWindow time.Duration

	// Password hashing. These values are passed straight to argon2.IDKey.

	// Argon2Time is the number of passes over memory.
	Argon2Time uint32
	// Argon2Memory is the memory cost; argon2.IDKey interprets it in KiB.
	Argon2Memory uint32
	// Argon2Threads is the degree of parallelism.
	Argon2Threads uint8
	// Argon2KeyLen is the length in bytes of the derived hash.
	Argon2KeyLen uint32
	// Argon2SaltLen is the length in bytes of the random salt.
	Argon2SaltLen uint32
}
// DefaultConfig returns the default security configuration
|
|
func DefaultConfig() *Config {
|
|
return &Config{
|
|
// CSRF protection
|
|
CSRFTokenLength: 32,
|
|
CSRFTokenExpiry: 24 * time.Hour,
|
|
CSRFCookieName: "csrf_token",
|
|
CSRFHeaderName: "X-CSRF-Token",
|
|
CSRFCookieSecure: true,
|
|
CSRFCookieHTTPOnly: true,
|
|
CSRFCookieSameSite: http.SameSiteStrictMode,
|
|
|
|
// Rate limiting
|
|
RateLimitRequests: 100,
|
|
RateLimitWindow: time.Minute,
|
|
|
|
// Password hashing
|
|
Argon2Time: 1,
|
|
Argon2Memory: 64 * 1024,
|
|
Argon2Threads: 4,
|
|
Argon2KeyLen: 32,
|
|
Argon2SaltLen: 16,
|
|
}
|
|
}
|
|
|
|
// NewSecurityService creates a new SecurityService
|
|
func NewSecurityService(config *Config) *SecurityService {
|
|
if config == nil {
|
|
config = DefaultConfig()
|
|
}
|
|
return &SecurityService{
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
// GenerateCSRFToken generates a CSRF token
|
|
func (s *SecurityService) GenerateCSRFToken() (string, error) {
|
|
// Generate random bytes
|
|
bytes := make([]byte, s.config.CSRFTokenLength)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
|
}
|
|
|
|
// Encode as base64
|
|
token := base64.StdEncoding.EncodeToString(bytes)
|
|
return token, nil
|
|
}
|
|
|
|
// SetCSRFCookie sets a CSRF cookie
|
|
func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: s.config.CSRFCookieName,
|
|
Value: token,
|
|
Path: "/",
|
|
Expires: time.Now().Add(s.config.CSRFTokenExpiry),
|
|
Secure: s.config.CSRFCookieSecure,
|
|
HttpOnly: s.config.CSRFCookieHTTPOnly,
|
|
SameSite: s.config.CSRFCookieSameSite,
|
|
})
|
|
}
|
|
|
|
// ValidateCSRFToken validates a CSRF token
|
|
func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
|
|
// Get token from cookie
|
|
cookie, err := r.Cookie(s.config.CSRFCookieName)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
cookieToken := cookie.Value
|
|
|
|
// Get token from header
|
|
headerToken := r.Header.Get(s.config.CSRFHeaderName)
|
|
|
|
// Compare tokens
|
|
return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1
|
|
}
|
|
|
|
// CSRFMiddleware creates a middleware that validates CSRF tokens
|
|
func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Skip for GET, HEAD, OPTIONS, TRACE
|
|
if r.Method == http.MethodGet ||
|
|
r.Method == http.MethodHead ||
|
|
r.Method == http.MethodOptions ||
|
|
r.Method == http.MethodTrace {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Validate CSRF token
|
|
if !s.ValidateCSRFToken(r) {
|
|
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// RateLimiter represents a rate limiter
|
|
type RateLimiter struct {
|
|
requests map[string][]time.Time
|
|
config *Config
|
|
}
|
|
|
|
// NewRateLimiter creates a new RateLimiter
|
|
func NewRateLimiter(config *Config) *RateLimiter {
|
|
return &RateLimiter{
|
|
requests: make(map[string][]time.Time),
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
// Allow checks if a request is allowed
|
|
func (r *RateLimiter) Allow(key string) bool {
|
|
now := time.Now()
|
|
windowStart := now.Add(-r.config.RateLimitWindow)
|
|
|
|
// Get requests for key
|
|
requests := r.requests[key]
|
|
|
|
// Filter out old requests
|
|
var newRequests []time.Time
|
|
for _, t := range requests {
|
|
if t.After(windowStart) {
|
|
newRequests = append(newRequests, t)
|
|
}
|
|
}
|
|
|
|
// Check if rate limit is exceeded
|
|
if len(newRequests) >= r.config.RateLimitRequests {
|
|
return false
|
|
}
|
|
|
|
// Add current request
|
|
newRequests = append(newRequests, now)
|
|
r.requests[key] = newRequests
|
|
|
|
return true
|
|
}
|
|
|
|
// RateLimitMiddleware creates a middleware that limits request rate
|
|
func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Get client IP
|
|
ip := getClientIP(r)
|
|
|
|
// Check if request is allowed
|
|
if !limiter.Allow(ip) {
|
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// getClientIP returns the best available client address for r, trying in
// order:
//
//  1. the first entry of the X-Forwarded-For header,
//  2. the X-Real-IP header,
//  3. the host part of r.RemoteAddr.
//
// Header values are trimmed, and an empty value falls through to the next
// source. The port is stripped from RemoteAddr so that every connection from
// the same host yields the same key; if RemoteAddr is not in host:port form
// it is returned unchanged.
//
// NOTE(security): both headers are set by the client unless a trusted proxy
// overwrites them, so the result is only as trustworthy as the deployment in
// front of this server.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// The header may list several addresses; only the first is used.
		first := forwarded
		if i := strings.IndexByte(forwarded, ','); i >= 0 {
			first = forwarded[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// HashPassword hashes a password using Argon2
|
|
func (s *SecurityService) HashPassword(password string) (string, error) {
|
|
// Generate salt
|
|
salt := make([]byte, s.config.Argon2SaltLen)
|
|
if _, err := rand.Read(salt); err != nil {
|
|
return "", fmt.Errorf("failed to generate salt: %w", err)
|
|
}
|
|
|
|
// Hash password
|
|
hash := argon2.IDKey(
|
|
[]byte(password),
|
|
salt,
|
|
s.config.Argon2Time,
|
|
s.config.Argon2Memory,
|
|
s.config.Argon2Threads,
|
|
s.config.Argon2KeyLen,
|
|
)
|
|
|
|
// Encode as base64
|
|
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
|
hashBase64 := base64.StdEncoding.EncodeToString(hash)
|
|
|
|
// Format as $argon2id$v=19$m=65536,t=1,p=4$<salt>$<hash>
|
|
return fmt.Sprintf(
|
|
"$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
|
|
s.config.Argon2Memory,
|
|
s.config.Argon2Time,
|
|
s.config.Argon2Threads,
|
|
saltBase64,
|
|
hashBase64,
|
|
), nil
|
|
}
|
|
|
|
// VerifyPassword verifies a password against a hash
|
|
func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
|
|
// Parse encoded hash
|
|
parts := strings.Split(encodedHash, "$")
|
|
if len(parts) != 6 {
|
|
return false, fmt.Errorf("invalid hash format")
|
|
}
|
|
|
|
// Extract parameters
|
|
if parts[1] != "argon2id" {
|
|
return false, fmt.Errorf("unsupported hash algorithm")
|
|
}
|
|
|
|
var memory uint32
|
|
var time uint32
|
|
var threads uint8
|
|
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to parse hash parameters: %w", err)
|
|
}
|
|
|
|
// Decode salt and hash
|
|
salt, err := base64.StdEncoding.DecodeString(parts[4])
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to decode salt: %w", err)
|
|
}
|
|
|
|
hash, err := base64.StdEncoding.DecodeString(parts[5])
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to decode hash: %w", err)
|
|
}
|
|
|
|
// Hash password with same parameters
|
|
newHash := argon2.IDKey(
|
|
[]byte(password),
|
|
salt,
|
|
time,
|
|
memory,
|
|
threads,
|
|
uint32(len(hash)),
|
|
)
|
|
|
|
// Compare hashes
|
|
return subtle.ConstantTimeCompare(hash, newHash) == 1, nil
|
|
}
|
|
|
|
// SecureHeadersMiddleware adds security headers to responses
|
|
func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Add security headers
|
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
w.Header().Set("X-Frame-Options", "DENY")
|
|
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
w.Header().Set("Content-Security-Policy", "default-src 'self'")
|
|
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// contextKey is the unexported type used for this package's context keys.
// A dedicated type keeps the keys from colliding with keys that other
// packages store in the same context.
type contextKey string

// userIDKey is the context key under which WithUserID stores the user ID
// and from which GetUserID reads it.
const userIDKey contextKey = "user_id"
// WithUserID adds a user ID to the context
|
|
func WithUserID(ctx context.Context, userID string) context.Context {
|
|
return context.WithValue(ctx, userIDKey, userID)
|
|
}
|
|
|
|
// GetUserID gets the user ID from the context
|
|
func GetUserID(ctx context.Context) (string, bool) {
|
|
userID, ok := ctx.Value(userIDKey).(string)
|
|
return userID, ok
|
|
}
|