// Package validator decodes JSON request bodies and validates them with
// go-playground/validator, adding OTP-specific and XSS-screening rules.
package validator

import (
	"encoding/json"
	"fmt"
	"net/http"
	"reflect"
	"regexp"
	"strings"

	"github.com/go-playground/validator/v10"
)

var (
	// validate is the shared validator instance, configured once in init.
	validate *validator.Validate

	// customValidations maps struct-tag names to the custom validation
	// functions that init registers on validate.
	customValidations = map[string]validator.Func{
		"otpsecret":   validateOTPSecret,
		"password":    validatePassword,
		"issuer":      validateIssuer,
		"otpauth_uri": validateOTPAuthURI,
		"no_xss":      validateNoXSS,
	}

	// commonPasswords is a small deny-list of weak passwords, looked up with
	// the lower-cased input. Real deployments should use a far more complete
	// list. NOTE(review): validatePassword enforces a 10-character minimum
	// before consulting this map, so entries shorter than 10 characters can
	// never match.
	commonPasswords = map[string]bool{
		"password123": true,
		"12345678":    true,
		"qwerty123":   true,
		"admin123":    true,
		"letmein":     true,
		"welcome":     true,
		"password":    true,
		"admin":       true,
	}

	// xssPatterns are precompiled, case-insensitive patterns that flag common
	// XSS payloads: script tags, javascript:/data:text/html URLs, inline event
	// handler attributes, dangerous elements, and CSS expression()/url().
	xssPatterns = []*regexp.Regexp{
		regexp.MustCompile(`(?i)<\s*script[^>]*>`),
		regexp.MustCompile(`(?i)javascript:`),
		regexp.MustCompile(`(?i)data:text/html`),
		regexp.MustCompile(`(?i)on\w+\s*=`),
		regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
		regexp.MustCompile(`(?i)<\s*iframe`),
		regexp.MustCompile(`(?i)<\s*object`),
		regexp.MustCompile(`(?i)<\s*embed`),
		regexp.MustCompile(`(?i)<\s*style`),
		regexp.MustCompile(`(?i)<\s*form`),
		regexp.MustCompile(`(?i)<\s*applet`),
		regexp.MustCompile(`(?i)<\s*meta`),
		regexp.MustCompile(`(?i)expression\s*\(`),
		regexp.MustCompile(`(?i)url\s*\(`),
	}

	// base32Regex matches an upper-case Base32 string (A-Z, 2-7) with
	// optional trailing '=' padding.
	base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)

	// issuerRegex allows ASCII letters, digits, spaces, '-', '_' and '.'.
	// A literal space is used rather than \s so that tabs and line breaks,
	// which have no place in an issuer name, are rejected.
	issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9 \-_.]+$`)

	// otpauthRegex checks the prefix of a TOTP key URI:
	// otpauth://totp/ISSUER:ACCOUNT?secret=BASE32 followed by either more
	// parameters ('&...') or the end of the string, since secret may be the
	// only parameter.
	otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*(?:&|$)`)

	// Character-class probes used by validatePassword.
	upperRegex  = regexp.MustCompile(`[A-Z]`)
	lowerRegex  = regexp.MustCompile(`[a-z]`)
	numberRegex = regexp.MustCompile(`[0-9]`)
	// specialRegex matches any ASCII punctuation character
	// (!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~).
	specialRegex = regexp.MustCompile(`[[:punct:]]`)
)

// init builds the shared validator, registers the custom rules, and makes
// validation errors report fields by their json tag names.
func init() {
	validate = validator.New()

	// Register the custom rules. A registration failure is a programming
	// error, so fail loudly at package load instead of at request time.
	for tag, fn := range customValidations {
		if err := validate.RegisterValidation(tag, fn); err != nil {
			panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
		}
	}

	// Use the json tag name (the part before any ",option") as the field
	// name in validation errors; "-" yields "" so no json name is used.
	validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
		name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
		if name == "-" {
			return ""
		}
		return name
	})
}
ValidateRequest validates a request body against a struct func ValidateRequest(r *http.Request, v interface{}) error { if err := json.NewDecoder(r.Body).Decode(v); err != nil { return fmt.Errorf("invalid request body: %w", err) } if err := validate.Struct(v); err != nil { if validationErrors, ok := err.(validator.ValidationErrors); ok { return NewValidationError(validationErrors) } return fmt.Errorf("validation error: %w", err) } return nil } // ValidationError represents a validation error type ValidationError struct { Fields map[string]string `json:"fields"` } // Error implements the error interface func (e *ValidationError) Error() string { var errors []string for field, msg := range e.Fields { errors = append(errors, fmt.Sprintf("%s: %s", field, msg)) } return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; ")) } // NewValidationError creates a new ValidationError from validator.ValidationErrors func NewValidationError(errors validator.ValidationErrors) *ValidationError { fields := make(map[string]string) for _, err := range errors { fields[err.Field()] = getErrorMessage(err) } return &ValidationError{Fields: fields} } // getErrorMessage returns a human-readable error message for a validation error func getErrorMessage(err validator.FieldError) string { switch err.Tag() { case "required": return "此字段为必填项" case "email": return "请输入有效的电子邮件地址" case "min": if err.Type().Kind() == reflect.String { return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param()) } return fmt.Sprintf("必须大于或等于 %s", err.Param()) case "max": if err.Type().Kind() == reflect.String { return fmt.Sprintf("长度不能超过 %s 个字符", err.Param()) } return fmt.Sprintf("必须小于或等于 %s", err.Param()) case "len": return fmt.Sprintf("长度必须为 %s 个字符", err.Param()) case "oneof": return fmt.Sprintf("必须是以下值之一: %s", err.Param()) case "otpsecret": return "OTP密钥格式无效,必须是有效的Base32编码" case "password": return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符" case "issuer": return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号" case "otpauth_uri": return 
"OTP认证URI格式无效" case "no_xss": return "输入包含潜在的不安全内容" case "numeric": return "必须是数字" default: return fmt.Sprintf("验证失败: %s", err.Tag()) } } // Custom validation functions // validateOTPSecret validates an OTP secret func validateOTPSecret(fl validator.FieldLevel) bool { secret := fl.Field().String() if secret == "" { return false } // OTP secret should be base32 encoded if !base32Regex.MatchString(secret) { return false } // Check length (typical OTP secrets are 16-64 characters) validLength := len(secret) >= 16 && len(secret) <= 128 return validLength } // validatePassword validates a password func validatePassword(fl validator.FieldLevel) bool { password := fl.Field().String() // At least 10 characters long if len(password) < 10 { return false } // Check if it's a common password if commonPasswords[strings.ToLower(password)] { return false } // Check character types hasUpper := upperRegex.MatchString(password) hasLower := lowerRegex.MatchString(password) hasNumber := numberRegex.MatchString(password) hasSpecial := specialRegex.MatchString(password) // Ensure password has enough complexity complexity := 0 if hasUpper { complexity++ } if hasLower { complexity++ } if hasNumber { complexity++ } if hasSpecial { complexity++ } return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial) } // validateIssuer validates an issuer name func validateIssuer(fl validator.FieldLevel) bool { issuer := fl.Field().String() if issuer == "" { return false } // Issuer should not contain special characters that could cause problems in URLs if !issuerRegex.MatchString(issuer) { return false } // Check length validLength := len(issuer) >= 1 && len(issuer) <= 100 return validLength } // validateOTPAuthURI validates an otpauth URI func validateOTPAuthURI(fl validator.FieldLevel) bool { uri := fl.Field().String() if uri == "" { return false } // Basic format check for otpauth URI // Format: 
otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD return otpauthRegex.MatchString(uri) } // validateNoXSS checks if a string contains potential XSS payloads func validateNoXSS(fl validator.FieldLevel) bool { value := fl.Field().String() // 检查基本的HTML编码 if strings.Contains(value, "&#") || strings.Contains(value, "<") || strings.Contains(value, ">") { return false } // 检查十六进制编码 if strings.Contains(strings.ToLower(value), "\\x3c") || // < strings.Contains(strings.ToLower(value), "\\x3e") { // > return false } // 检查Unicode编码 if strings.Contains(strings.ToLower(value), "\\u003c") || // < strings.Contains(strings.ToLower(value), "\\u003e") { // > return false } // 使用预编译的正则表达式检查XSS模式 for _, pattern := range xssPatterns { if pattern.MatchString(value) { return false } } return true } // Request validation structs // LoginRequest represents a login request type LoginRequest struct { Code string `json:"code" validate:"required,len=6|len=8,numeric"` } // CreateOTPRequest represents a request to create an OTP type CreateOTPRequest struct { Name string `json:"name" validate:"required,min=1,max=100,no_xss"` Issuer string `json:"issuer" validate:"required,issuer,no_xss"` Secret string `json:"secret" validate:"required,otpsecret"` Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"` Digits int `json:"digits" validate:"required,oneof=6 8"` Period int `json:"period" validate:"required,oneof=30 60"` } // UpdateOTPRequest represents a request to update an OTP type UpdateOTPRequest struct { Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"` Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"` Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"` Digits int `json:"digits" validate:"omitempty,oneof=6 8"` Period int `json:"period" validate:"omitempty,oneof=30 60"` } // VerifyOTPRequest represents a request to verify an OTP code type VerifyOTPRequest 
struct { Code string `json:"code" validate:"required,len=6|len=8,numeric"` }