otpm/validator/validator.go
“xHuPo” 5d370e1077 error
2025-05-27 17:44:24 +08:00

316 lines
8.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package validator
import (
	"encoding/json"
	"fmt"
	"net/http"
	"reflect"
	"regexp"
	"sort"
	"strings"

	"github.com/go-playground/validator/v10"
)
var (
	// validate is the shared validator instance; init creates and configures it.
	validate *validator.Validate

	// customValidations maps each custom struct-tag name to the function
	// that implements the rule. init registers every entry.
	customValidations = map[string]validator.Func{
		"otpsecret":   validateOTPSecret,
		"password":    validatePassword,
		"issuer":      validateIssuer,
		"otpauth_uri": validateOTPAuthURI,
		"no_xss":      validateNoXSS,
	}

	// commonPasswords is a small deny-list of well-known weak passwords.
	// Keys are lower case because lookups lower-case the candidate first.
	// A real deployment should use a more complete list.
	commonPasswords = map[string]bool{
		"password123": true,
		"12345678":    true,
		"qwerty123":   true,
		"admin123":    true,
		"letmein":     true,
		"welcome":     true,
		"password":    true,
		"admin":       true,
	}

	// xssPatterns are precompiled, case-insensitive patterns; a value that
	// matches any of them is rejected by the no_xss rule.
	xssPatterns = []*regexp.Regexp{
		// Match the opening tag on its own. Requiring a closing </script>
		// on the same line let unclosed and multi-line script tags through,
		// because "." does not match a newline without the (?s) flag.
		regexp.MustCompile(`(?i)<\s*script`),
		regexp.MustCompile(`(?i)javascript:`),
		regexp.MustCompile(`(?i)data:text/html`),
		regexp.MustCompile(`(?i)on\w+\s*=`),
		regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
		regexp.MustCompile(`(?i)<\s*iframe`),
		regexp.MustCompile(`(?i)<\s*object`),
		regexp.MustCompile(`(?i)<\s*embed`),
		regexp.MustCompile(`(?i)<\s*style`),
		regexp.MustCompile(`(?i)<\s*form`),
		regexp.MustCompile(`(?i)<\s*applet`),
		regexp.MustCompile(`(?i)<\s*meta`),
		regexp.MustCompile(`(?i)expression\s*\(`),
		regexp.MustCompile(`(?i)url\s*\(`),
	}

	// Precompiled regular expressions used by the custom validators.

	// base32Regex matches an upper-case Base32 string with optional '=' padding.
	base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
	// issuerRegex matches ASCII letters, digits, whitespace, '-', '_' and '.'.
	issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
	// otpauthRegex matches otpauth://totp/LABEL:ACCOUNT?secret=BASE32 where the
	// secret is either followed by further parameters or ends the URI. The
	// previous pattern demanded a trailing '&' and so rejected URIs whose only
	// parameter is the secret.
	otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*(?:&|$)`)
	// Character-class probes used by the password rule.
	upperRegex   = regexp.MustCompile(`[A-Z]`)
	lowerRegex   = regexp.MustCompile(`[a-z]`)
	numberRegex  = regexp.MustCompile(`[0-9]`)
	specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
)
// init builds the package-level validator, installs every rule listed in
// customValidations, and makes reported field names follow the json tag.
func init() {
	validate = validator.New()

	// Install the custom rules; a registration failure panics.
	for name, rule := range customValidations {
		err := validate.RegisterValidation(name, rule)
		if err != nil {
			panic(fmt.Sprintf("failed to register validation %s: %v", name, err))
		}
	}

	// Use the json tag name (the part before the first comma) as the field
	// name; a json tag of "-" yields an empty name.
	validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
		parts := strings.SplitN(fld.Tag.Get("json"), ",", 2)
		if parts[0] == "-" {
			return ""
		}
		return parts[0]
	})
}
// ValidateRequest decodes the JSON body of r into v and then validates v
// against its struct tags.
//
// It returns a wrapped decode error when the body is not valid JSON for v,
// a *ValidationError when one or more fields fail validation, a wrapped
// error for any other validator failure, and nil on success.
func ValidateRequest(r *http.Request, v interface{}) error {
	decodeErr := json.NewDecoder(r.Body).Decode(v)
	if decodeErr != nil {
		return fmt.Errorf("invalid request body: %w", decodeErr)
	}

	err := validate.Struct(v)
	if err == nil {
		return nil
	}

	// Field-level failures get the structured, per-field representation.
	fieldErrs, ok := err.(validator.ValidationErrors)
	if !ok {
		return fmt.Errorf("validation error: %w", err)
	}
	return NewValidationError(fieldErrs)
}
// ValidationError reports the fields of a request that failed validation.
type ValidationError struct {
	// Fields maps a field name to the message describing why it failed.
	Fields map[string]string `json:"fields"`
}

// Error implements the error interface. It renders every entry of Fields as
// "field: message", joined by "; " and prefixed with "validation failed: ".
//
// Entries are emitted in ascending field-name order. Go randomizes map
// iteration, so without sorting the same error would produce a different
// string on each call, which breaks log de-duplication and test assertions.
func (e *ValidationError) Error() string {
	names := make([]string, 0, len(e.Fields))
	for name := range e.Fields {
		names = append(names, name)
	}
	sort.Strings(names)

	parts := make([]string, 0, len(names))
	for _, name := range names {
		parts = append(parts, name+": "+e.Fields[name])
	}
	return "validation failed: " + strings.Join(parts, "; ")
}
// NewValidationError converts validator.ValidationErrors into a
// *ValidationError keyed by field name. When several entries share a field
// name, the message of the last one is kept.
func NewValidationError(errors validator.ValidationErrors) *ValidationError {
	result := &ValidationError{Fields: make(map[string]string)}
	for i := range errors {
		fieldErr := errors[i]
		result.Fields[fieldErr.Field()] = getErrorMessage(fieldErr)
	}
	return result
}
// getErrorMessage maps a single field error to its user-facing message,
// selected by the failed rule's tag. Tags without a dedicated message fall
// back to a generic text that names the tag.
func getErrorMessage(err validator.FieldError) string {
	tag, param := err.Tag(), err.Param()

	// min/max describe a length for string fields and a value otherwise.
	isString := func() bool { return err.Type().Kind() == reflect.String }

	switch tag {
	// Messages that embed the rule parameter.
	case "min":
		format := "必须大于或等于 %s"
		if isString() {
			format = "长度必须至少为 %s 个字符"
		}
		return fmt.Sprintf(format, param)
	case "max":
		format := "必须小于或等于 %s"
		if isString() {
			format = "长度不能超过 %s 个字符"
		}
		return fmt.Sprintf(format, param)
	case "len":
		return fmt.Sprintf("长度必须为 %s 个字符", param)
	case "oneof":
		return fmt.Sprintf("必须是以下值之一: %s", param)

	// Fixed messages for built-in rules.
	case "required":
		return "此字段为必填项"
	case "email":
		return "请输入有效的电子邮件地址"
	case "numeric":
		return "必须是数字"

	// Fixed messages for this package's custom rules.
	case "otpsecret":
		return "OTP密钥格式无效必须是有效的Base32编码"
	case "password":
		return "密码必须至少10个字符并包含大写字母、小写字母以及数字或特殊字符"
	case "issuer":
		return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
	case "otpauth_uri":
		return "OTP认证URI格式无效"
	case "no_xss":
		return "输入包含潜在的不安全内容"
	}

	return fmt.Sprintf("验证失败: %s", tag)
}
// Custom validation functions

// validateOTPSecret reports whether the field is an acceptable OTP secret:
// between 16 and 128 bytes long (inclusive) and matching base32Regex, i.e.
// upper-case Base32 with optional '=' padding. An empty value is rejected.
func validateOTPSecret(fl validator.FieldLevel) bool {
	candidate := fl.Field().String()

	// The length bounds also rule out the empty string.
	if n := len(candidate); n < 16 || n > 128 {
		return false
	}
	return base32Regex.MatchString(candidate)
}
// validatePassword reports whether the field is an acceptable password.
//
// A password is accepted when it:
//   - is at least 10 characters long,
//   - is not in commonPasswords (compared case-insensitively), and
//   - contains an upper-case letter, a lower-case letter, and at least one
//     digit or one of the special characters matched by specialRegex.
func validatePassword(fl validator.FieldLevel) bool {
	password := fl.Field().String()

	// Count characters (runes), not bytes: the user-facing message promises
	// a 10-character minimum, and len() on a string would let a short
	// password made of multi-byte characters slip through.
	if len([]rune(password)) < 10 {
		return false
	}

	// Reject well-known weak passwords regardless of letter case.
	if commonPasswords[strings.ToLower(password)] {
		return false
	}

	// Both letter cases are mandatory.
	if !upperRegex.MatchString(password) || !lowerRegex.MatchString(password) {
		return false
	}

	// A third character class is required: a digit or a special character.
	return numberRegex.MatchString(password) || specialRegex.MatchString(password)
}
// validateIssuer reports whether the field is an acceptable issuer name:
// non-empty, at most 100 bytes long, and made up only of the characters
// allowed by issuerRegex.
func validateIssuer(fl validator.FieldLevel) bool {
	name := fl.Field().String()

	switch {
	case name == "", len(name) > 100:
		return false
	default:
		// Restricting the character set keeps the issuer safe to embed in URLs.
		return issuerRegex.MatchString(name)
	}
}
// validateOTPAuthURI reports whether the field is a non-empty string that
// matches otpauthRegex, the package's precompiled otpauth://totp pattern.
//
// This is only a basic format check. The expected shape is:
// otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
func validateOTPAuthURI(fl validator.FieldLevel) bool {
	candidate := fl.Field().String()
	return candidate != "" && otpauthRegex.MatchString(candidate)
}
// validateNoXSS reports whether the field is free of the XSS markers this
// package looks for. It returns false as soon as any check matches.
func validateNoXSS(fl validator.FieldLevel) bool {
	raw := fl.Field().String()

	// HTML entity encodings (matched case-sensitively).
	for _, entity := range []string{"&#", "&lt;", "&gt;"} {
		if strings.Contains(raw, entity) {
			return false
		}
	}

	// Hex and Unicode escape sequences for '<' and '>', matched
	// case-insensitively by lower-casing the input once.
	lowered := strings.ToLower(raw)
	for _, escape := range []string{"\\x3c", "\\x3e", "\\u003c", "\\u003e"} {
		if strings.Contains(lowered, escape) {
			return false
		}
	}

	// Precompiled tag, attribute and URL-scheme patterns.
	for _, re := range xssPatterns {
		if re.MatchString(raw) {
			return false
		}
	}
	return true
}
// Request validation structs

// LoginRequest is the body of a login request.
type LoginRequest struct {
	// Code is required, must be numeric, and must have length 6 or 8.
	Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}
// CreateOTPRequest is the body of a request to create an OTP entry.
// Every field is required.
type CreateOTPRequest struct {
	// Name is 1 to 100 characters long and must pass the no_xss rule.
	Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
	// Issuer must pass the issuer and no_xss rules.
	Issuer string `json:"issuer" validate:"required,issuer,no_xss"`
	// Secret must pass the otpsecret rule.
	Secret string `json:"secret" validate:"required,otpsecret"`
	// Algorithm is one of SHA1, SHA256 or SHA512.
	Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
	// Digits is 6 or 8.
	Digits int `json:"digits" validate:"required,oneof=6 8"`
	// Period is 30 or 60.
	Period int `json:"period" validate:"required,oneof=30 60"`
}
// UpdateOTPRequest is the body of a request to update an OTP entry.
// Every field is optional (omitempty); a field that is set is validated.
type UpdateOTPRequest struct {
	// Name, when set, is 1 to 100 characters long and must pass no_xss.
	Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`
	// Issuer, when set, must pass the issuer and no_xss rules.
	Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`
	// Algorithm, when set, is one of SHA1, SHA256 or SHA512.
	Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
	// Digits, when set, is 6 or 8.
	Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
	// Period, when set, is 30 or 60.
	Period int `json:"period" validate:"omitempty,oneof=30 60"`
}
// VerifyOTPRequest is the body of a request to verify an OTP code.
type VerifyOTPRequest struct {
	// Code is required, must be numeric, and must have length 6 or 8.
	Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}