316 lines
8.8 KiB
Go
316 lines
8.8 KiB
Go
package validator
|
||
|
||
import (
	"encoding/json"
	"fmt"
	"net/http"
	"reflect"
	"regexp"
	"sort"
	"strings"

	"github.com/go-playground/validator/v10"
)
|
||
|
||
var (
|
||
validate *validator.Validate
|
||
|
||
// 自定义验证规则
|
||
customValidations = map[string]validator.Func{
|
||
"otpsecret": validateOTPSecret,
|
||
"password": validatePassword,
|
||
"issuer": validateIssuer,
|
||
"otpauth_uri": validateOTPAuthURI,
|
||
"no_xss": validateNoXSS,
|
||
}
|
||
|
||
// 常见的弱密码列表(实际使用时应该使用更完整的列表)
|
||
commonPasswords = map[string]bool{
|
||
"password123": true,
|
||
"12345678": true,
|
||
"qwerty123": true,
|
||
"admin123": true,
|
||
"letmein": true,
|
||
"welcome": true,
|
||
"password": true,
|
||
"admin": true,
|
||
}
|
||
|
||
// 预编译的XSS检测正则表达式
|
||
xssPatterns = []*regexp.Regexp{
|
||
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
|
||
regexp.MustCompile(`(?i)javascript:`),
|
||
regexp.MustCompile(`(?i)data:text/html`),
|
||
regexp.MustCompile(`(?i)on\w+\s*=`),
|
||
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
|
||
regexp.MustCompile(`(?i)<\s*iframe`),
|
||
regexp.MustCompile(`(?i)<\s*object`),
|
||
regexp.MustCompile(`(?i)<\s*embed`),
|
||
regexp.MustCompile(`(?i)<\s*style`),
|
||
regexp.MustCompile(`(?i)<\s*form`),
|
||
regexp.MustCompile(`(?i)<\s*applet`),
|
||
regexp.MustCompile(`(?i)<\s*meta`),
|
||
regexp.MustCompile(`(?i)expression\s*\(`),
|
||
regexp.MustCompile(`(?i)url\s*\(`),
|
||
}
|
||
|
||
// 预编译的正则表达式
|
||
base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
|
||
issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
|
||
otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`)
|
||
upperRegex = regexp.MustCompile(`[A-Z]`)
|
||
lowerRegex = regexp.MustCompile(`[a-z]`)
|
||
numberRegex = regexp.MustCompile(`[0-9]`)
|
||
specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
|
||
)
|
||
|
||
func init() {
|
||
validate = validator.New()
|
||
|
||
// 注册自定义验证规则
|
||
for tag, fn := range customValidations {
|
||
if err := validate.RegisterValidation(tag, fn); err != nil {
|
||
panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
|
||
}
|
||
}
|
||
|
||
// 使用json tag作为字段名
|
||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
|
||
if name == "-" {
|
||
return ""
|
||
}
|
||
return name
|
||
})
|
||
}
|
||
|
||
// ValidateRequest validates a request body against a struct
|
||
func ValidateRequest(r *http.Request, v interface{}) error {
|
||
if err := json.NewDecoder(r.Body).Decode(v); err != nil {
|
||
return fmt.Errorf("invalid request body: %w", err)
|
||
}
|
||
|
||
if err := validate.Struct(v); err != nil {
|
||
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
||
return NewValidationError(validationErrors)
|
||
}
|
||
return fmt.Errorf("validation error: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ValidationError reports the fields that failed validation together with a
// human-readable message for each of them.
type ValidationError struct {
	// Fields maps a field name to the message describing its failure.
	Fields map[string]string `json:"fields"`
}

// Error implements the error interface. It renders the failures as
// "validation failed: field: message; field: message".
//
// Fields are listed in ascending name order. Ranging over the map directly
// would produce a different order on every call, which makes the message
// unstable in logs and tests.
func (e *ValidationError) Error() string {
	names := make([]string, 0, len(e.Fields))
	for name := range e.Fields {
		names = append(names, name)
	}
	sort.Strings(names)

	parts := make([]string, 0, len(names))
	for _, name := range names {
		parts = append(parts, name+": "+e.Fields[name])
	}
	return "validation failed: " + strings.Join(parts, "; ")
}
|
||
|
||
// NewValidationError creates a new ValidationError from validator.ValidationErrors
|
||
func NewValidationError(errors validator.ValidationErrors) *ValidationError {
|
||
fields := make(map[string]string)
|
||
for _, err := range errors {
|
||
fields[err.Field()] = getErrorMessage(err)
|
||
}
|
||
return &ValidationError{Fields: fields}
|
||
}
|
||
|
||
// getErrorMessage returns a human-readable error message for a validation error
|
||
func getErrorMessage(err validator.FieldError) string {
|
||
switch err.Tag() {
|
||
case "required":
|
||
return "此字段为必填项"
|
||
case "email":
|
||
return "请输入有效的电子邮件地址"
|
||
case "min":
|
||
if err.Type().Kind() == reflect.String {
|
||
return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param())
|
||
}
|
||
return fmt.Sprintf("必须大于或等于 %s", err.Param())
|
||
case "max":
|
||
if err.Type().Kind() == reflect.String {
|
||
return fmt.Sprintf("长度不能超过 %s 个字符", err.Param())
|
||
}
|
||
return fmt.Sprintf("必须小于或等于 %s", err.Param())
|
||
case "len":
|
||
return fmt.Sprintf("长度必须为 %s 个字符", err.Param())
|
||
case "oneof":
|
||
return fmt.Sprintf("必须是以下值之一: %s", err.Param())
|
||
case "otpsecret":
|
||
return "OTP密钥格式无效,必须是有效的Base32编码"
|
||
case "password":
|
||
return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符"
|
||
case "issuer":
|
||
return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
|
||
case "otpauth_uri":
|
||
return "OTP认证URI格式无效"
|
||
case "no_xss":
|
||
return "输入包含潜在的不安全内容"
|
||
case "numeric":
|
||
return "必须是数字"
|
||
default:
|
||
return fmt.Sprintf("验证失败: %s", err.Tag())
|
||
}
|
||
}
|
||
|
||
// Custom validation functions
|
||
|
||
// validateOTPSecret validates an OTP secret
|
||
func validateOTPSecret(fl validator.FieldLevel) bool {
|
||
secret := fl.Field().String()
|
||
|
||
if secret == "" {
|
||
return false
|
||
}
|
||
|
||
// OTP secret should be base32 encoded
|
||
if !base32Regex.MatchString(secret) {
|
||
return false
|
||
}
|
||
|
||
// Check length (typical OTP secrets are 16-64 characters)
|
||
validLength := len(secret) >= 16 && len(secret) <= 128
|
||
|
||
return validLength
|
||
}
|
||
|
||
// validatePassword validates a password
|
||
func validatePassword(fl validator.FieldLevel) bool {
|
||
password := fl.Field().String()
|
||
|
||
// At least 10 characters long
|
||
if len(password) < 10 {
|
||
return false
|
||
}
|
||
|
||
// Check if it's a common password
|
||
if commonPasswords[strings.ToLower(password)] {
|
||
return false
|
||
}
|
||
|
||
// Check character types
|
||
hasUpper := upperRegex.MatchString(password)
|
||
hasLower := lowerRegex.MatchString(password)
|
||
hasNumber := numberRegex.MatchString(password)
|
||
hasSpecial := specialRegex.MatchString(password)
|
||
|
||
// Ensure password has enough complexity
|
||
complexity := 0
|
||
if hasUpper {
|
||
complexity++
|
||
}
|
||
if hasLower {
|
||
complexity++
|
||
}
|
||
if hasNumber {
|
||
complexity++
|
||
}
|
||
if hasSpecial {
|
||
complexity++
|
||
}
|
||
|
||
return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial)
|
||
}
|
||
|
||
// validateIssuer validates an issuer name
|
||
func validateIssuer(fl validator.FieldLevel) bool {
|
||
issuer := fl.Field().String()
|
||
|
||
if issuer == "" {
|
||
return false
|
||
}
|
||
|
||
// Issuer should not contain special characters that could cause problems in URLs
|
||
if !issuerRegex.MatchString(issuer) {
|
||
return false
|
||
}
|
||
|
||
// Check length
|
||
validLength := len(issuer) >= 1 && len(issuer) <= 100
|
||
|
||
return validLength
|
||
}
|
||
|
||
// validateOTPAuthURI validates an otpauth URI
|
||
func validateOTPAuthURI(fl validator.FieldLevel) bool {
|
||
uri := fl.Field().String()
|
||
|
||
if uri == "" {
|
||
return false
|
||
}
|
||
|
||
// Basic format check for otpauth URI
|
||
// Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
|
||
return otpauthRegex.MatchString(uri)
|
||
}
|
||
|
||
// validateNoXSS checks if a string contains potential XSS payloads
|
||
func validateNoXSS(fl validator.FieldLevel) bool {
|
||
value := fl.Field().String()
|
||
|
||
// 检查基本的HTML编码
|
||
if strings.Contains(value, "&#") ||
|
||
strings.Contains(value, "<") ||
|
||
strings.Contains(value, ">") {
|
||
return false
|
||
}
|
||
|
||
// 检查十六进制编码
|
||
if strings.Contains(strings.ToLower(value), "\\x3c") || // <
|
||
strings.Contains(strings.ToLower(value), "\\x3e") { // >
|
||
return false
|
||
}
|
||
|
||
// 检查Unicode编码
|
||
if strings.Contains(strings.ToLower(value), "\\u003c") || // <
|
||
strings.Contains(strings.ToLower(value), "\\u003e") { // >
|
||
return false
|
||
}
|
||
|
||
// 使用预编译的正则表达式检查XSS模式
|
||
for _, pattern := range xssPatterns {
|
||
if pattern.MatchString(value) {
|
||
return false
|
||
}
|
||
}
|
||
|
||
return true
|
||
}
|
||
|
||
// Request validation structs
|
||
|
||
// LoginRequest is the payload of a login request.
type LoginRequest struct {
	// Code is required, numeric, and exactly 6 or 8 characters long.
	Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}
|
||
|
||
// CreateOTPRequest is the payload of a request that creates an OTP entry.
// Every field is required.
type CreateOTPRequest struct {
	// Name has a length between 1 and 100 and must pass the no_xss rule.
	Name string `json:"name" validate:"required,min=1,max=100,no_xss"`

	// Issuer must pass both the issuer and the no_xss rules.
	Issuer string `json:"issuer" validate:"required,issuer,no_xss"`

	// Secret must pass the otpsecret rule.
	Secret string `json:"secret" validate:"required,otpsecret"`

	// Algorithm is one of SHA1, SHA256 or SHA512.
	Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`

	// Digits is either 6 or 8.
	Digits int `json:"digits" validate:"required,oneof=6 8"`

	// Period is either 30 or 60.
	Period int `json:"period" validate:"required,oneof=30 60"`
}
|
||
|
||
// UpdateOTPRequest is the payload of a request that updates an OTP entry.
// Every field carries the omitempty rule ahead of its other rules.
type UpdateOTPRequest struct {
	// Name has a length between 1 and 100 and must pass the no_xss rule.
	Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`

	// Issuer must pass both the issuer and the no_xss rules.
	Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`

	// Algorithm is one of SHA1, SHA256 or SHA512.
	Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`

	// Digits is either 6 or 8.
	Digits int `json:"digits" validate:"omitempty,oneof=6 8"`

	// Period is either 30 or 60.
	Period int `json:"period" validate:"omitempty,oneof=30 60"`
}
|
||
|
||
// VerifyOTPRequest is the payload of a request that verifies an OTP code.
type VerifyOTPRequest struct {
	// Code is required, numeric, and exactly 6 or 8 characters long.
	Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}
|