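// Package validator decodes JSON request bodies and validates them with
// github.com/go-playground/validator, adding the custom "otpsecret" and
// "password" rules and reporting field errors under their JSON names.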
package validator
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"net/http"
	"reflect"
	"regexp"
	"sort"
	"strings"

	"github.com/go-playground/validator/v10"
)
|
|
|
|
var (
|
|
validate *validator.Validate
|
|
// 自定义验证规则
|
|
customValidations = map[string]validator.Func{
|
|
"otpsecret": validateOTPSecret,
|
|
"password": validatePassword,
|
|
}
|
|
)
|
|
|
|
func init() {
|
|
validate = validator.New()
|
|
|
|
// 注册自定义验证规则
|
|
for tag, fn := range customValidations {
|
|
if err := validate.RegisterValidation(tag, fn); err != nil {
|
|
panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
|
|
}
|
|
}
|
|
|
|
// 使用json tag作为字段名
|
|
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
|
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
|
|
if name == "-" {
|
|
return ""
|
|
}
|
|
return name
|
|
})
|
|
}
|
|
|
|
// ValidateRequest validates a request body against a struct
|
|
func ValidateRequest(r *http.Request, v interface{}) error {
|
|
if err := json.NewDecoder(r.Body).Decode(v); err != nil {
|
|
return fmt.Errorf("invalid request body: %w", err)
|
|
}
|
|
|
|
if err := validate.Struct(v); err != nil {
|
|
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
|
return NewValidationError(validationErrors)
|
|
}
|
|
return fmt.Errorf("validation error: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidationError reports per-field validation failures. Fields maps each
// failing field's name to a human-readable message describing the failure.
type ValidationError struct {
	Fields map[string]string `json:"fields"`
}

// Error implements the error interface. It lists every "field: message" pair
// joined by "; ", ordered by field name so the text is deterministic despite
// Go's randomized map iteration order.
func (e *ValidationError) Error() string {
	names := make([]string, 0, len(e.Fields))
	for name := range e.Fields {
		names = append(names, name)
	}
	sort.Strings(names)

	parts := make([]string, 0, len(names))
	for _, name := range names {
		parts = append(parts, fmt.Sprintf("%s: %s", name, e.Fields[name]))
	}
	return fmt.Sprintf("validation failed: %s", strings.Join(parts, "; "))
}
|
|
|
|
// NewValidationError creates a new ValidationError from validator.ValidationErrors
|
|
func NewValidationError(errors validator.ValidationErrors) *ValidationError {
|
|
fields := make(map[string]string)
|
|
for _, err := range errors {
|
|
fields[err.Field()] = getErrorMessage(err)
|
|
}
|
|
return &ValidationError{Fields: fields}
|
|
}
|
|
|
|
// getErrorMessage returns a human-readable error message for a validation error
|
|
func getErrorMessage(err validator.FieldError) string {
|
|
switch err.Tag() {
|
|
case "required":
|
|
return "This field is required"
|
|
case "email":
|
|
return "Invalid email address"
|
|
case "min":
|
|
return fmt.Sprintf("Must be at least %s characters long", err.Param())
|
|
case "max":
|
|
return fmt.Sprintf("Must be at most %s characters long", err.Param())
|
|
case "otpsecret":
|
|
return "Invalid OTP secret format"
|
|
case "password":
|
|
return "Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character"
|
|
default:
|
|
return fmt.Sprintf("Failed validation on tag: %s", err.Tag())
|
|
}
|
|
}
|
|
|
|
// Custom validation functions
|
|
|
|
// validateOTPSecret validates an OTP secret
|
|
func validateOTPSecret(fl validator.FieldLevel) bool {
|
|
secret := fl.Field().String()
|
|
// OTP secret should be base32 encoded
|
|
matched, _ := regexp.MatchString(`^[A-Z2-7]+=*$`, secret)
|
|
return matched
|
|
}
|
|
|
|
// validatePassword validates a password
|
|
func validatePassword(fl validator.FieldLevel) bool {
|
|
password := fl.Field().String()
|
|
// At least 8 characters long
|
|
if len(password) < 8 {
|
|
return false
|
|
}
|
|
|
|
var (
|
|
hasUpper = regexp.MustCompile(`[A-Z]`).MatchString(password)
|
|
hasLower = regexp.MustCompile(`[a-z]`).MatchString(password)
|
|
hasNumber = regexp.MustCompile(`[0-9]`).MatchString(password)
|
|
hasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password)
|
|
)
|
|
|
|
return hasUpper && hasLower && hasNumber && hasSpecial
|
|
}
|
|
|
|
// Request validation structs
|
|
|
|
// LoginRequest is the JSON body of a login request. Its only field, code, is
// required (must be non-empty); its format is not validated here.
// NOTE(review): presumably a one-time code — confirm against the login handler.
type LoginRequest struct {
	Code string `json:"code" validate:"required"`
}
|
|
|
|
// CreateOTPRequest is the JSON body for creating an OTP entry. Every field is
// required, and the validate tags additionally enforce:
//   - name, issuer: 1-100 characters;
//   - secret: the custom "otpsecret" rule;
//   - algorithm: one of SHA1, SHA256, SHA512;
//   - digits: 6 or 8;
//   - period: 30 or 60 (presumably seconds — confirm against the OTP generator).
type CreateOTPRequest struct {
	Name      string `json:"name" validate:"required,min=1,max=100"`
	Issuer    string `json:"issuer" validate:"required,min=1,max=100"`
	Secret    string `json:"secret" validate:"required,otpsecret"`
	Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
	Digits    int    `json:"digits" validate:"required,oneof=6 8"`
	Period    int    `json:"period" validate:"required,oneof=30 60"`
}
|
|
|
|
// UpdateOTPRequest is the JSON body for updating an OTP entry. Every field is
// optional: because of omitempty, a zero value ("" or 0, which is also what an
// omitted JSON field decodes to) skips validation for that field. Non-zero
// values must satisfy:
//   - name, issuer: 1-100 characters;
//   - algorithm: one of SHA1, SHA256, SHA512;
//   - digits: 6 or 8;
//   - period: 30 or 60.
//
// There is no secret field, so the secret cannot be changed via this request.
type UpdateOTPRequest struct {
	Name      string `json:"name" validate:"omitempty,min=1,max=100"`
	Issuer    string `json:"issuer" validate:"omitempty,min=1,max=100"`
	Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
	Digits    int    `json:"digits" validate:"omitempty,oneof=6 8"`
	Period    int    `json:"period" validate:"omitempty,oneof=30 60"`
}
|
|
|
|
// VerifyOTPRequest is the JSON body for verifying an OTP code. Code is
// required, must be exactly 6 or 8 characters long, and must consist of ASCII
// digits only.
//
// The "number" tag is used rather than "numeric": validator's "numeric"
// also accepts a leading sign and a decimal point, so values such as
// "+1234567" or "-12345.6" would pass the length-and-format check.
type VerifyOTPRequest struct {
	Code string `json:"code" validate:"required,len=6|len=8,number"`
}
|