diff --git a/api/response.go b/api/response.go
new file mode 100644
index 0000000..8e24daf
--- /dev/null
+++ b/api/response.go
@@ -0,0 +1,149 @@
+package api
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+)
+
+// Common error codes
+const (
+ CodeSuccess = 0
+ CodeInvalidParams = 400
+ CodeUnauthorized = 401
+ CodeForbidden = 403
+ CodeNotFound = 404
+ CodeInternalError = 500
+ CodeServiceUnavail = 503
+)
+
+// Error represents an API error
+type Error struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
+
+// Error implements the error interface
+func (e *Error) Error() string {
+ return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message)
+}
+
+// NewError creates a new API error
+func NewError(code int, message string) *Error {
+ return &Error{
+ Code: code,
+ Message: message,
+ }
+}
+
+// Response represents a standard API response
+type Response struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data interface{} `json:"data,omitempty"`
+}
+
+// ResponseWriter wraps common response writing functions
+type ResponseWriter struct {
+ http.ResponseWriter
+}
+
+// NewResponseWriter creates a new ResponseWriter
+func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
+ return &ResponseWriter{w}
+}
+
+// WriteJSON writes a JSON response
+func (w *ResponseWriter) WriteJSON(code int, data interface{}) error {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(code)
+ return json.NewEncoder(w).Encode(data)
+}
+
+// WriteSuccess writes a success response
+func (w *ResponseWriter) WriteSuccess(data interface{}) error {
+ return w.WriteJSON(http.StatusOK, Response{
+ Code: CodeSuccess,
+ Message: "success",
+ Data: data,
+ })
+}
+
+// WriteError writes an error response
+func (w *ResponseWriter) WriteError(err error) error {
+ var apiErr *Error
+ if errors.As(err, &apiErr) {
+ return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{
+ Code: apiErr.Code,
+ Message: apiErr.Message,
+ })
+ }
+
+ // Handle unknown errors
+ return w.WriteJSON(http.StatusInternalServerError, Response{
+ Code: CodeInternalError,
+ Message: "Internal Server Error",
+ })
+}
+
+// WriteErrorWithCode writes an error response with a specific code
+func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error {
+ return w.WriteJSON(getHTTPStatus(code), Response{
+ Code: code,
+ Message: message,
+ })
+}
+
+// getHTTPStatus maps API error codes to HTTP status codes
+func getHTTPStatus(code int) int {
+ switch code {
+ case CodeSuccess:
+ return http.StatusOK
+ case CodeInvalidParams:
+ return http.StatusBadRequest
+ case CodeUnauthorized:
+ return http.StatusUnauthorized
+ case CodeForbidden:
+ return http.StatusForbidden
+ case CodeNotFound:
+ return http.StatusNotFound
+ case CodeServiceUnavail:
+ return http.StatusServiceUnavailable
+ default:
+ return http.StatusInternalServerError
+ }
+}
+
+// Common errors
+var (
+ ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters")
+ ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized")
+ ErrForbidden = NewError(CodeForbidden, "Forbidden")
+ ErrNotFound = NewError(CodeNotFound, "Resource not found")
+ ErrInternalError = NewError(CodeInternalError, "Internal server error")
+ ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable")
+)
+
+// ValidationError creates an error for invalid parameters
+func ValidationError(message string) *Error {
+ return NewError(CodeInvalidParams, message)
+}
+
+// NotFoundError creates an error for not found resources
+func NotFoundError(resource string) *Error {
+ return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource))
+}
+
+// ForbiddenError creates an error for forbidden actions
+func ForbiddenError(message string) *Error {
+ return NewError(CodeForbidden, message)
+}
+
+// InternalError creates an error for internal server errors
+func InternalError(err error) *Error {
+ if err == nil {
+ return ErrInternalError
+ }
+ return NewError(CodeInternalError, err.Error())
+}
diff --git a/api/validator.go b/api/validator.go
new file mode 100644
index 0000000..a18a552
--- /dev/null
+++ b/api/validator.go
@@ -0,0 +1,152 @@
+package api
+
+import (
+ "regexp"
+ "strings"
+
+ "github.com/go-playground/validator/v10"
+)
+
+// Validate is a global validator instance
+var Validate = validator.New()
+
+// RegisterCustomValidations registers custom validation functions
+func RegisterCustomValidations() {
+ // Register custom validation for issuer
+ Validate.RegisterValidation("issuer", validateIssuer)
+
+ // Register custom validation for XSS prevention
+ Validate.RegisterValidation("no_xss", validateNoXSS)
+
+ // Register custom validation for OTP secret
+ Validate.RegisterValidation("otpsecret", validateOTPSecret)
+}
+
+// validateOTPSecret validates that the OTP secret is in valid base32 format
+func validateOTPSecret(fl validator.FieldLevel) bool {
+ secret := fl.Field().String()
+
+ // Check if the secret is not empty
+ if secret == "" {
+ return false
+ }
+
+ // Check if the secret is in base32 format (A-Z, 2-7)
+ base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
+ if !base32Regex.MatchString(secret) {
+ return false
+ }
+
+ // Check if the length is valid (must be at least 16 characters)
+ if len(secret) < 16 || len(secret) > 128 {
+ return false
+ }
+
+ return true
+}
+
+// validateIssuer validates that the issuer field contains only allowed characters
+func validateIssuer(fl validator.FieldLevel) bool {
+ issuer := fl.Field().String()
+
+ // Empty issuer is valid (since it's optional)
+ if issuer == "" {
+ return true
+ }
+
+ // Allow alphanumeric characters, spaces, and common punctuation
+ issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api
+
+import (
+ "regexp"
+ "strings"
+
+ "github.com/go-playground/validator/v10"
+)
+
+// Validate is a global validator instance
+var Validate = validator.New()
+
+// RegisterCustomValidations registers custom validation functions
+func RegisterCustomValidations() {
+ // Register custom validation for issuer
+ Validate.RegisterValidation("issuer", validateIssuer)
+
+ // Register custom validation for XSS prevention
+ Validate.RegisterValidation("no_xss", validateNoXSS)
+
+ // Register custom validation for OTP secret
+ Validate.RegisterValidation("otpsecret", validateOTPSecret)
+}
+
+// validateOTPSecret validates that the OTP secret is in valid base32 format
+func validateOTPSecret(fl validator.FieldLevel) bool {
+ secret := fl.Field().String()
+
+ // Check if the secret is not empty
+ if secret == "" {
+ return false
+ }
+
+ // Check if the secret is in base32 format (A-Z, 2-7)
+ base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
+ if !base32Regex.MatchString(secret) {
+ return false
+ }
+
+ // Check if the length is valid (must be at least 16 characters)
+ if len(secret) < 16 || len(secret) > 128 {
+ return false
+ }
+
+ return true
+}
+
+)
+ if !issuerRegex.MatchString(issuer) {
+ return false
+ }
+
+ // Check length
+ if len(issuer) > 100 {
+ return false
+ }
+
+ return true
+}
+
+// validateNoXSS validates that the field doesn't contain potential XSS payloads
+func validateNoXSS(fl validator.FieldLevel) bool {
+ value := fl.Field().String()
+
+ // Check for HTML encoding
+ if strings.Contains(value, "") ||
+ strings.Contains(value, "<") ||
+ strings.Contains(value, ">") {
+ return false
+ }
+
+ // Check for common XSS patterns
+ suspiciousPatterns := []*regexp.Regexp{
+ regexp.MustCompile(`(?i)`),
+ regexp.MustCompile(`(?i)javascript:`),
+ regexp.MustCompile(`(?i)data:text/html`),
+ regexp.MustCompile(`(?i)on\w+\s*=`),
+ regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
+ regexp.MustCompile(`(?i)<\s*iframe`),
+ regexp.MustCompile(`(?i)<\s*object`),
+ regexp.MustCompile(`(?i)<\s*embed`),
+ regexp.MustCompile(`(?i)<\s*style`),
+ regexp.MustCompile(`(?i)<\s*form`),
+ regexp.MustCompile(`(?i)<\s*applet`),
+ regexp.MustCompile(`(?i)<\s*meta`),
+ }
+
+ for _, pattern := range suspiciousPatterns {
+ if pattern.MatchString(value) {
+ return false
+ }
+ }
+
+ return true
+}
\ No newline at end of file
diff --git a/cache/cache.go b/cache/cache.go
new file mode 100644
index 0000000..08d3041
--- /dev/null
+++ b/cache/cache.go
@@ -0,0 +1,206 @@
+package cache
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "sync"
+ "time"
+)
+
+// Item represents a cache item
+type Item struct {
+ Value []byte
+ Expiration int64
+}
+
+// Expired returns true if the item has expired
+func (item Item) Expired() bool {
+ if item.Expiration == 0 {
+ return false
+ }
+ return time.Now().UnixNano() > item.Expiration
+}
+
+// Cache represents an in-memory cache
+type Cache struct {
+ items map[string]Item
+ mu sync.RWMutex
+ defaultExpiration time.Duration
+ cleanupInterval time.Duration
+ stopCleanup chan bool
+}
+
+// New creates a new cache with the given default expiration and cleanup interval
+func New(defaultExpiration, cleanupInterval time.Duration) *Cache {
+ cache := &Cache{
+ items: make(map[string]Item),
+ defaultExpiration: defaultExpiration,
+ cleanupInterval: cleanupInterval,
+ stopCleanup: make(chan bool),
+ }
+
+ // Start cleanup goroutine if cleanup interval > 0
+ if cleanupInterval > 0 {
+ go cache.startCleanup()
+ }
+
+ return cache
+}
+
+// startCleanup starts the cleanup process
+func (c *Cache) startCleanup() {
+ ticker := time.NewTicker(c.cleanupInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ c.DeleteExpired()
+ case <-c.stopCleanup:
+ return
+ }
+ }
+}
+
+// Set adds an item to the cache with the given key and expiration
+func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error {
+ // Convert value to bytes
+ var valueBytes []byte
+ var err error
+
+ switch v := value.(type) {
+ case []byte:
+ valueBytes = v
+ case string:
+ valueBytes = []byte(v)
+ default:
+ valueBytes, err = json.Marshal(value)
+ if err != nil {
+ return fmt.Errorf("failed to marshal value: %w", err)
+ }
+ }
+
+ // Calculate expiration
+ var exp int64
+ if expiration == 0 {
+ if c.defaultExpiration > 0 {
+ exp = time.Now().Add(c.defaultExpiration).UnixNano()
+ }
+ } else if expiration > 0 {
+ exp = time.Now().Add(expiration).UnixNano()
+ }
+
+ c.mu.Lock()
+ c.items[key] = Item{
+ Value: valueBytes,
+ Expiration: exp,
+ }
+ c.mu.Unlock()
+
+ return nil
+}
+
+// Get gets an item from the cache
+func (c *Cache) Get(key string, value interface{}) (bool, error) {
+ c.mu.RLock()
+ item, found := c.items[key]
+ c.mu.RUnlock()
+
+ if !found {
+ return false, nil
+ }
+
+ // Check if item has expired
+ if item.Expired() {
+ c.mu.Lock()
+ delete(c.items, key)
+ c.mu.Unlock()
+ return false, nil
+ }
+
+ // Unmarshal value
+ switch v := value.(type) {
+ case *[]byte:
+ *v = item.Value
+ case *string:
+ *v = string(item.Value)
+ default:
+ if err := json.Unmarshal(item.Value, value); err != nil {
+ return true, fmt.Errorf("failed to unmarshal value: %w", err)
+ }
+ }
+
+ return true, nil
+}
+
+// Delete deletes an item from the cache
+func (c *Cache) Delete(key string) {
+ c.mu.Lock()
+ delete(c.items, key)
+ c.mu.Unlock()
+}
+
+// DeleteExpired deletes all expired items from the cache
+func (c *Cache) DeleteExpired() {
+ now := time.Now().UnixNano()
+
+ c.mu.Lock()
+ for k, v := range c.items {
+ if v.Expiration > 0 && now > v.Expiration {
+ delete(c.items, k)
+ }
+ }
+ c.mu.Unlock()
+}
+
+// Clear deletes all items from the cache
+func (c *Cache) Clear() {
+ c.mu.Lock()
+ c.items = make(map[string]Item)
+ c.mu.Unlock()
+}
+
+// Close stops the cleanup goroutine
+func (c *Cache) Close() {
+ if c.cleanupInterval > 0 {
+ c.stopCleanup <- true
+ }
+}
+
+// CacheService provides caching functionality
+type CacheService struct {
+ cache *Cache
+}
+
+// NewCacheService creates a new CacheService
+func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService {
+ return &CacheService{
+ cache: New(defaultExpiration, cleanupInterval),
+ }
+}
+
+// Set adds an item to the cache
+func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
+ return s.cache.Set(key, value, expiration)
+}
+
+// Get gets an item from the cache
+func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) {
+ return s.cache.Get(key, value)
+}
+
+// Delete deletes an item from the cache
+func (s *CacheService) Delete(ctx context.Context, key string) {
+ s.cache.Delete(key)
+}
+
+// Clear deletes all items from the cache
+func (s *CacheService) Clear(ctx context.Context) {
+ s.cache.Clear()
+}
+
+// Close closes the cache
+func (s *CacheService) Close() {
+ s.cache.Close()
+}
diff --git a/cmd/root.go b/cmd/root.go
index 3752544..8821b91 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -1,83 +1,144 @@
package cmd
import (
+ "context"
"fmt"
"log"
"net/http"
"os"
+ "os/signal"
+ "syscall"
+
+ "github.com/spf13/viper"
+
+ "otpm/api"
+ "otpm/config"
"otpm/database"
"otpm/handlers"
- "otpm/utils"
-
- "github.com/jmoiron/sqlx"
- "github.com/julienschmidt/httprouter"
- "github.com/spf13/cobra"
- "github.com/spf13/viper"
+ "otpm/models"
+ "otpm/server"
+ "otpm/services"
)
-var rootCmd = &cobra.Command{
- Use: "otpm",
- Short: "otp backend for microapp on wechat",
- Run: func(cmd *cobra.Command, args []string) {
- startApp()
- },
-}
-
-func Execute() {
- if err := rootCmd.Execute(); err != nil {
- fmt.Println(err)
- os.Exit(1)
- }
-}
-
func init() {
- cobra.OnInitialize(initConfig)
+ // Set config file with multi-environment support
+ viper.SetConfigName("config")
+ viper.SetConfigType("yaml")
+ viper.AddConfigPath(".")
- rootCmd.PersistentFlags().StringP("config", "c", "", "config file (default is $HOME/config.yaml)")
- rootCmd.PersistentFlags().StringP("driver", "d", "sqlite3", "database driver (sqlite3, postgres, mysql)")
- rootCmd.PersistentFlags().StringP("dsn", "s", "", "database connection string")
- rootCmd.PersistentFlags().StringP("port", "p", "8080", "port to listen on")
-
- viper.BindPFlag("database.driver", rootCmd.PersistentFlags().Lookup("driver"))
- viper.BindPFlag("database.dsn", rootCmd.PersistentFlags().Lookup("dsn"))
- viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port"))
-}
-
-func initConfig() {
- if cfgFile := viper.GetString("config"); cfgFile != "" {
- viper.SetConfigFile(cfgFile)
- } else {
- viper.AddConfigPath(".")
- viper.SetConfigName("config")
- viper.SetConfigType("yaml")
+ // Set environment specific config (e.g. config.production.yaml)
+ env := os.Getenv("OTPM_ENV")
+ if env != "" {
+ viper.SetConfigName(fmt.Sprintf("config.%s", env))
}
- if err := viper.ReadInConfig(); err != nil {
- log.Fatalf("Error reading config file: %v", err)
- }
+ // Set default values
+ viper.SetDefault("server.port", "8080")
+ viper.SetDefault("server.timeout.read", "15s")
+ viper.SetDefault("server.timeout.write", "15s")
+ viper.SetDefault("server.timeout.idle", "60s")
+ viper.SetDefault("database.max_open_conns", 25)
+ viper.SetDefault("database.max_idle_conns", 5)
+ viper.SetDefault("database.conn_max_lifetime", "5m")
+
+ // Set environment variable prefix
+ viper.SetEnvPrefix("OTPM")
+ viper.AutomaticEnv()
+
+ // Bind environment variables
+ viper.BindEnv("database.url", "OTPM_DB_URL")
+ viper.BindEnv("database.password", "OTPM_DB_PASSWORD")
}
-func initApp(db *sqlx.DB) {
- if err := database.MigrateDB(db); err != nil {
- log.Fatalf("Error migrating the database: %v", err)
- }
-}
-
-func startApp() {
- port := viper.GetInt("port")
- db, err := database.InitDB()
+// Execute is the entry point for the application
+func Execute() error {
+ // Load configuration
+ cfg, err := config.LoadConfig()
if err != nil {
- log.Fatalf("Error connecting to the database: %v", err)
+ return fmt.Errorf("failed to load config: %w", err)
+ }
+
+ // Create context with cancellation
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ // Setup signal handling
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+ go func() {
+ sig := <-sigChan
+ log.Printf("Received signal: %v", sig)
+ cancel()
+ }()
+
+ // Initialize database
+ db, err := database.New(&cfg.Database)
+ if err != nil {
+ return fmt.Errorf("failed to initialize database: %w", err)
}
defer db.Close()
- initApp(db)
- handler := &handlers.Handler{DB: db}
- router := httprouter.New()
- router.POST("/login", utils.AdaptHandler(handler.Login))
- router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp))
- router.GET("/get", utils.AdaptHandler(handler.GetOtp))
+ // Run database migrations
+ if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil {
+ return fmt.Errorf("failed to run migrations: %w", err)
+ }
- log.Println("Starting server on :8080")
- log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), router))
+ // Initialize repositories
+ userRepo := models.NewUserRepository(db.DB)
+ otpRepo := models.NewOTPRepository(db.DB)
+
+ // Initialize services
+ authService := services.NewAuthService(cfg, userRepo)
+ otpService := services.NewOTPService(otpRepo)
+
+ // Register custom validations
+ api.RegisterCustomValidations()
+
+ // Initialize handlers
+ authHandler := handlers.NewAuthHandler(authService)
+ otpHandler := handlers.NewOTPHandler(otpService)
+
+ // Create and configure server
+ srv := server.New(cfg)
+
+ // Register health check endpoint
+ srv.RegisterHealthCheck()
+
+ // Register public routes with type conversion
+ authRoutes := make(map[string]http.Handler)
+ for path, handler := range authHandler.Routes() {
+ authRoutes[path] = http.HandlerFunc(handler)
+ }
+ srv.RegisterRoutes(authRoutes)
+
+ // Register authenticated routes with type conversion
+ otpRoutes := make(map[string]http.Handler)
+ for path, handler := range otpHandler.Routes() {
+ otpRoutes[path] = http.HandlerFunc(handler)
+ }
+ srv.RegisterAuthRoutes(otpRoutes)
+
+ // Start server in goroutine
+ serverErr := make(chan error, 1)
+ go func() {
+ log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port)
+ if err := srv.Start(); err != nil {
+ serverErr <- fmt.Errorf("server error: %w", err)
+ }
+ }()
+
+ // Wait for shutdown signal or server error
+ select {
+ case err := <-serverErr:
+ return err
+ case <-ctx.Done():
+ // Graceful shutdown with timeout
+ log.Println("Shutting down server...")
+ if err := srv.Shutdown(); err != nil {
+ return fmt.Errorf("server shutdown error: %w", err)
+ }
+ log.Println("Server stopped gracefully")
+ }
+
+ return nil
}
diff --git a/config.yaml b/config.yaml
index f23f8fb..da2013f 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,8 +1,23 @@
+server:
+ port: 8080
+ read_timeout: 15s
+ write_timeout: 15s
+ shutdown_timeout: 5s
+
database:
- driver: sqlite
+ driver: sqlite3
dsn: otpm.sqlite
-port: 8080
+ max_open_conns: 25
+ max_idle_conns: 25
+ max_lifetime: 5m
+ skip_migration: false
+
+jwt:
+ secret: "your-jwt-secret-key-change-this-in-production"
+ expire_delta: 24h
+ refresh_delta: 168h
+ signing_method: HS256
wechat:
- appid: "wx57d1033974eb5250"
- secret: "be494c2a81df685a40b9a74e1736b15d"
+ app_id: "your-wechat-app-id"
+ app_secret: "your-wechat-app-secret"
\ No newline at end of file
diff --git a/config/config.go b/config/config.go
new file mode 100644
index 0000000..4cad24d
--- /dev/null
+++ b/config/config.go
@@ -0,0 +1,128 @@
+package config
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/spf13/viper"
+)
+
+// Config holds all configuration for the application
+type Config struct {
+ Server ServerConfig `mapstructure:"server"`
+ Database DatabaseConfig `mapstructure:"database"`
+ JWT JWTConfig `mapstructure:"jwt"`
+ WeChat WeChatConfig `mapstructure:"wechat"`
+}
+
+// ServerConfig holds all server related configuration
+type ServerConfig struct {
+ Host string `mapstructure:"host"`
+ Port int `mapstructure:"port"`
+ ReadTimeout time.Duration `mapstructure:"read_timeout"`
+ WriteTimeout time.Duration `mapstructure:"write_timeout"`
+ ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
+ Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout
+}
+
+// DatabaseConfig holds all database related configuration
+type DatabaseConfig struct {
+ Driver string `mapstructure:"driver"`
+ DSN string `mapstructure:"dsn"`
+ MaxOpenConns int `mapstructure:"max_open_conns"`
+ MaxIdleConns int `mapstructure:"max_idle_conns"`
+ MaxLifetime time.Duration `mapstructure:"max_lifetime"`
+ SkipMigration bool `mapstructure:"skip_migration"`
+}
+
+// JWTConfig holds all JWT related configuration
+type JWTConfig struct {
+ Secret string `mapstructure:"secret"`
+ ExpireDelta time.Duration `mapstructure:"expire_delta"`
+ RefreshDelta time.Duration `mapstructure:"refresh_delta"`
+ SigningMethod string `mapstructure:"signing_method"`
+ Issuer string `mapstructure:"issuer"`
+ Audience string `mapstructure:"audience"`
+}
+
+// WeChatConfig holds all WeChat related configuration
+type WeChatConfig struct {
+ AppID string `mapstructure:"app_id"`
+ AppSecret string `mapstructure:"app_secret"`
+}
+
+// LoadConfig loads the configuration from file and environment variables
+func LoadConfig() (*Config, error) {
+ // Set default values
+ setDefaults()
+
+ // Read config file
+ if err := viper.ReadInConfig(); err != nil {
+ return nil, fmt.Errorf("failed to read config file: %w", err)
+ }
+
+ var config Config
+ if err := viper.Unmarshal(&config); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal config: %w", err)
+ }
+
+ // Validate config
+ if err := validateConfig(&config); err != nil {
+ return nil, fmt.Errorf("invalid configuration: %w", err)
+ }
+
+ return &config, nil
+}
+
+// setDefaults sets default values for configuration
+func setDefaults() {
+ // Server defaults
+ viper.SetDefault("server.port", 8080)
+ viper.SetDefault("server.read_timeout", "15s")
+ viper.SetDefault("server.write_timeout", "15s")
+ viper.SetDefault("server.shutdown_timeout", "5s")
+ viper.SetDefault("server.timeout", "30s") // Default request processing timeout
+
+ // Database defaults
+ viper.SetDefault("database.driver", "sqlite3")
+ viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection
+ viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection
+ viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling
+ viper.SetDefault("database.skip_migration", false)
+
+ // JWT defaults
+ viper.SetDefault("jwt.expire_delta", "24h")
+ viper.SetDefault("jwt.refresh_delta", "168h") // 7 days
+ viper.SetDefault("jwt.signing_method", "HS256")
+ viper.SetDefault("jwt.issuer", "otpm")
+ viper.SetDefault("jwt.audience", "otpm-client")
+}
+
+// validateConfig validates the configuration
+func validateConfig(config *Config) error {
+ if config.Server.Port < 1 || config.Server.Port > 65535 {
+ return fmt.Errorf("invalid port number: %d", config.Server.Port)
+ }
+
+ if config.Database.Driver == "" {
+ return fmt.Errorf("database driver is required")
+ }
+
+ if config.Database.DSN == "" {
+ return fmt.Errorf("database DSN is required")
+ }
+
+ if config.JWT.Secret == "" {
+ return fmt.Errorf("JWT secret is required")
+ }
+
+ if config.WeChat.AppID == "" {
+ return fmt.Errorf("WeChat AppID is required")
+ }
+
+ if config.WeChat.AppSecret == "" {
+ return fmt.Errorf("WeChat AppSecret is required")
+ }
+
+ return nil
+}
diff --git a/database/database.go b/database/database.go
deleted file mode 100644
index bf4ade8..0000000
--- a/database/database.go
+++ /dev/null
@@ -1,49 +0,0 @@
-package database
-
-import (
- _ "embed"
- "log"
-
- _ "github.com/go-sql-driver/mysql"
- "github.com/jmoiron/sqlx"
- _ "github.com/lib/pq"
- "github.com/spf13/viper"
- _ "modernc.org/sqlite"
-)
-
-var (
- //go:embed init/users.sql
- userTable string
- //go:embed init/otp.sql
- otpTable string
-)
-
-func InitDB() (*sqlx.DB, error) {
- driver := viper.GetString("database.driver")
- dsn := viper.GetString("database.dsn")
-
- db, err := sqlx.Open(driver, dsn)
- if err != nil {
- return nil, err
- }
-
- if err := db.Ping(); err != nil {
- return nil, err
- }
-
- log.Println("Connected to database!")
- return db, nil
-}
-
-func MigrateDB(db *sqlx.DB) error {
- _, err := db.Exec(userTable)
- if err != nil {
- return err
- }
-
- _, err = db.Exec(otpTable)
- if err != nil {
- return err
- }
- return nil
-}
diff --git a/database/db.go b/database/db.go
new file mode 100644
index 0000000..e882713
--- /dev/null
+++ b/database/db.go
@@ -0,0 +1,212 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "log"
+ "strings"
+ "time"
+
+ "otpm/config"
+
+ "github.com/jmoiron/sqlx"
+)
+
+// DB wraps sqlx.DB to provide additional functionality
+type DB struct {
+ *sqlx.DB
+}
+
+// New creates a new database connection
+func New(cfg *config.DatabaseConfig) (*DB, error) {
+ db, err := sqlx.Open(cfg.Driver, cfg.DSN)
+ if err != nil {
+ return nil, fmt.Errorf("failed to connect to database: %w", err)
+ }
+
+ // Configure connection pool based on database type
+ if cfg.Driver == "sqlite3" {
+ // SQLite is a file-based database - simpler connection settings
+ db.SetMaxOpenConns(1)
+ db.SetMaxIdleConns(1)
+ db.SetConnMaxLifetime(0) // Connections don't need to be recycled
+ db.SetConnMaxIdleTime(0)
+ } else {
+ // For other databases (MySQL, PostgreSQL etc.)
+ db.SetMaxOpenConns(cfg.MaxOpenConns)
+ db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
+ db.SetConnMaxLifetime(30 * time.Minute)
+ db.SetConnMaxIdleTime(5 * time.Minute)
+ }
+
+ // Verify connection with timeout
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ if err := db.PingContext(ctx); err != nil {
+ return nil, fmt.Errorf("failed to ping database: %w", err)
+ }
+
+ log.Println("Successfully connected to database")
+ return &DB{db}, nil
+}
+
+// WithTx executes a function within a transaction with retry logic
+func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
+ var maxRetries int
+ var lastErr error
+
+ // Adjust retry settings based on database type
+ if db.DriverName() == "sqlite3" {
+ maxRetries = 5 // SQLite needs more retries due to busy timeouts
+ } else {
+ maxRetries = 3
+ }
+
+ // Default transaction options
+ opts := &sql.TxOptions{
+ Isolation: sql.LevelReadCommitted,
+ }
+
+ for attempt := 1; attempt <= maxRetries; attempt++ {
+ start := time.Now()
+
+ tx, err := db.BeginTxx(ctx, opts)
+ if err != nil {
+ if isRetryableError(err) && attempt < maxRetries {
+ lastErr = err
+ time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
+ continue
+ }
+ return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
+ }
+
+ defer func() {
+ if p := recover(); p != nil {
+ tx.Rollback()
+ panic(p)
+ }
+ }()
+
+ if err := fn(tx); err != nil {
+ if rbErr := tx.Rollback(); rbErr != nil {
+ log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
+ }
+
+ if isRetryableError(err) && attempt < maxRetries {
+ lastErr = err
+ time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
+ continue
+ }
+ return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
+ }
+
+ // Log long-running transactions
+ if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
+ log.Printf("Transaction completed in %v", elapsed)
+ }
+
+ if err := tx.Commit(); err != nil {
+ if isRetryableError(err) && attempt < maxRetries {
+ lastErr = err
+ time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
+ continue
+ }
+ return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
+ }
+
+ return nil
+ }
+
+ return lastErr
+}
+
+// isRetryableError checks if an error is likely to succeed on retry
+func isRetryableError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ errStr := strings.ToLower(err.Error())
+ return strings.Contains(errStr, "deadlock") ||
+ strings.Contains(errStr, "timeout") ||
+ strings.Contains(errStr, "try again") ||
+ strings.Contains(errStr, "connection reset") ||
+ strings.Contains(errStr, "busy") ||
+ strings.Contains(errStr, "locked")
+}
+
+// ExecContext executes a query with adaptive timeout
+func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
+ // Set timeout based on query complexity
+ timeout := 5 * time.Second
+ if strings.Contains(strings.ToLower(query), "insert") ||
+ strings.Contains(strings.ToLower(query), "update") ||
+ strings.Contains(strings.ToLower(query), "delete") {
+ timeout = 10 * time.Second
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+
+ start := time.Now()
+ _, err := db.DB.ExecContext(ctx, query, args...)
+ elapsed := time.Since(start)
+
+ // Log slow queries
+ if elapsed > timeout/2 {
+ log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
+ }
+
+ if err != nil {
+ return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
+ }
+
+ return nil
+}
+
+// QueryRowContext executes a query that returns a single row with timeout
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+
+ return db.DB.QueryRowxContext(ctx, query, args...)
+}
+
+// QueryContext executes a query that returns multiple rows with adaptive timeout
+func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
+ // Set timeout based on query complexity
+ timeout := 5 * time.Second
+ if strings.Contains(strings.ToLower(query), "join") ||
+ strings.Contains(strings.ToLower(query), "group by") ||
+ strings.Contains(strings.ToLower(query), "order by") {
+ timeout = 15 * time.Second
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+
+ start := time.Now()
+ rows, err := db.DB.QueryxContext(ctx, query, args...)
+ elapsed := time.Since(start)
+
+ // Log slow queries
+ if elapsed > timeout/2 {
+ log.Printf("Slow query detected: %s (took %v)", query, elapsed)
+ }
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
+ }
+
+ return rows, nil
+}
+
+// Close closes the database connection
+func (db *DB) Close() error {
+ if err := db.DB.Close(); err != nil {
+ return fmt.Errorf("failed to close database connection: %w", err)
+ }
+ return nil
+}
diff --git a/database/init/otp.sql b/database/init/otp.sql
index a9b2442..e0e1ba9 100644
--- a/database/init/otp.sql
+++ b/database/init/otp.sql
@@ -1,6 +1,26 @@
CREATE TABLE IF NOT EXISTS otp (
- id SERIAL PRIMARY KEY,
- openid VARCHAR(255),
- num INTEGER,
- token VARCHAR(255)
-);
\ No newline at end of file
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id VARCHAR(255) NOT NULL,
+ openid VARCHAR(255) NOT NULL,
+ name VARCHAR(100) NOT NULL,
+ issuer VARCHAR(255),
+ secret VARCHAR(255) NOT NULL,
+ algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1',
+ digits INTEGER NOT NULL DEFAULT 6,
+ period INTEGER NOT NULL DEFAULT 30,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, name),
+ UNIQUE(openid)
+);
+
+-- Add index for faster lookups
+CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id);
+CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid);
+
+-- Trigger to update the updated_at timestamp
+CREATE TRIGGER IF NOT EXISTS update_otp_timestamp
+ AFTER UPDATE ON otp
+BEGIN
+ UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
+END;
\ No newline at end of file
diff --git a/database/init/users.sql b/database/init/users.sql
index 4dbedae..ae4532a 100644
--- a/database/init/users.sql
+++ b/database/init/users.sql
@@ -1,5 +1,6 @@
CREATE TABLE IF NOT EXISTS users (
- id SERIAL PRIMARY KEY,
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
openid VARCHAR(255) UNIQUE NOT NULL,
session_key VARCHAR(255) UNIQUE NOT NULL
-);
\ No newline at end of file
+);
+CREATE UNIQUE INDEX idx_users_openid ON users(openid);
\ No newline at end of file
diff --git a/database/migration.go b/database/migration.go
new file mode 100644
index 0000000..df15a97
--- /dev/null
+++ b/database/migration.go
@@ -0,0 +1,160 @@
+package database
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "time"
+
+ _ "embed"
+
+ "github.com/jmoiron/sqlx"
+)
+
+var (
+ //go:embed init/users.sql
+ userTable string
+ //go:embed init/otp.sql
+ otpTable string
+)
+
+// Migration represents a database migration
+type Migration struct {
+ Name string
+ SQL string
+ Version int
+}
+
+// Migrations is a list of all migrations
+var Migrations = []Migration{
+ {
+ Name: "Create users table",
+ SQL: userTable,
+ Version: 1,
+ },
+ {
+ Name: "Create OTP table",
+ SQL: otpTable,
+ Version: 2,
+ },
+}
+
+// MigrationRecord represents a record in the migrations table
+type MigrationRecord struct {
+ ID int `db:"id"`
+ Version int `db:"version"`
+ Name string `db:"name"`
+ AppliedAt time.Time `db:"applied_at"`
+}
+
+// ensureMigrationsTable ensures that the migrations table exists
+func ensureMigrationsTable(db *sqlx.DB) error {
+ query := `
+ CREATE TABLE IF NOT EXISTS migrations (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ version INTEGER NOT NULL,
+ name TEXT NOT NULL,
+ applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+ );
+ `
+
+ _, err := db.Exec(query)
+ if err != nil {
+ return fmt.Errorf("failed to create migrations table: %w", err)
+ }
+
+ return nil
+}
+
+// getAppliedMigrations gets all applied migrations
+func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
+ var records []MigrationRecord
+ if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
+ return nil, fmt.Errorf("failed to get applied migrations: %w", err)
+ }
+
+ result := make(map[int]MigrationRecord)
+ for _, record := range records {
+ result[record.Version] = record
+ }
+
+ return result, nil
+}
+
+// Migrate runs all pending migrations
+func Migrate(db *sqlx.DB, skipMigration bool) error {
+ if skipMigration {
+ log.Println("Skipping database migration as configured")
+ return nil
+ }
+
+ // Ensure migrations table exists
+ if err := ensureMigrationsTable(db); err != nil {
+ return err
+ }
+
+ // Get applied migrations
+ appliedMigrations, err := getAppliedMigrations(db)
+ if err != nil {
+ return err
+ }
+
+ // Apply pending migrations
+ for _, migration := range Migrations {
+ if _, ok := appliedMigrations[migration.Version]; ok {
+ log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
+ continue
+ }
+
+ log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
+
+ // Start a transaction for this migration
+ tx, err := db.Beginx()
+ if err != nil {
+ return fmt.Errorf("failed to begin transaction: %w", err)
+ }
+
+ // Execute migration
+ if _, err := tx.Exec(migration.SQL); err != nil {
+ tx.Rollback()
+ return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
+ }
+
+ // Record migration
+ if _, err := tx.Exec(
+ "INSERT INTO migrations (version, name) VALUES (?, ?)",
+ migration.Version, migration.Name,
+ ); err != nil {
+ tx.Rollback()
+ return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
+ }
+
+ // Commit transaction
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
+ }
+
+ log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
+ }
+
+ return nil
+}
+
+// MigrateWithContext runs all pending migrations with context
+func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
+ // Create a channel to signal completion
+ done := make(chan error, 1)
+
+ // Run migration in a goroutine
+ go func() {
+ done <- Migrate(db, skipMigration)
+ }()
+
+ // Wait for migration to complete or context to be canceled
+ select {
+ case err := <-done:
+ return err
+ case <-ctx.Done():
+ return fmt.Errorf("migration canceled: %w", ctx.Err())
+ }
+}
diff --git a/docs/swagger.go b/docs/swagger.go
new file mode 100644
index 0000000..ccf05ac
--- /dev/null
+++ b/docs/swagger.go
@@ -0,0 +1,663 @@
+// Package docs provides API documentation using Swagger/OpenAPI
+package docs
+
+import (
+ "encoding/json"
+ "net/http"
+)
+
+// SwaggerInfo holds the API information
+var SwaggerInfo = struct {
+ Title string
+ Description string
+ Version string
+ Host string
+ BasePath string
+ Schemes []string
+}{
+ Title: "OTPM API",
+ Description: "API for One-Time Password Manager",
+ Version: "1.0.0",
+ Host: "localhost:8080",
+ BasePath: "/",
+ Schemes: []string{"http", "https"},
+}
+
+// SwaggerJSON returns the Swagger JSON
+func SwaggerJSON() []byte {
+ swagger := map[string]interface{}{
+ "swagger": "2.0",
+ "info": map[string]interface{}{
+ "title": SwaggerInfo.Title,
+ "description": SwaggerInfo.Description,
+ "version": SwaggerInfo.Version,
+ },
+ "host": SwaggerInfo.Host,
+ "basePath": SwaggerInfo.BasePath,
+ "schemes": SwaggerInfo.Schemes,
+ "paths": getPaths(),
+ "definitions": map[string]interface{}{
+ "LoginRequest": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "code": map[string]interface{}{
+ "type": "string",
+ "description": "WeChat authorization code",
+ },
+ },
+ "required": []string{"code"},
+ },
+ "LoginResponse": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "token": map[string]interface{}{
+ "type": "string",
+ "description": "JWT token",
+ },
+ "openid": map[string]interface{}{
+ "type": "string",
+ "description": "WeChat OpenID",
+ },
+ },
+ },
+ "CreateOTPRequest": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "name": map[string]interface{}{
+ "type": "string",
+ "description": "OTP name",
+ },
+ "issuer": map[string]interface{}{
+ "type": "string",
+ "description": "OTP issuer",
+ },
+ "secret": map[string]interface{}{
+ "type": "string",
+ "description": "OTP secret",
+ },
+ "algorithm": map[string]interface{}{
+ "type": "string",
+ "description": "OTP algorithm",
+ "enum": []string{"SHA1", "SHA256", "SHA512"},
+ },
+ "digits": map[string]interface{}{
+ "type": "integer",
+ "description": "OTP digits",
+ "enum": []int{6, 8},
+ },
+ "period": map[string]interface{}{
+ "type": "integer",
+ "description": "OTP period in seconds",
+ "enum": []int{30, 60},
+ },
+ },
+ "required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"},
+ },
+ "OTP": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "id": map[string]interface{}{
+ "type": "string",
+ "description": "OTP ID",
+ },
+ "user_id": map[string]interface{}{
+ "type": "string",
+ "description": "User ID",
+ },
+ "name": map[string]interface{}{
+ "type": "string",
+ "description": "OTP name",
+ },
+ "issuer": map[string]interface{}{
+ "type": "string",
+ "description": "OTP issuer",
+ },
+ "algorithm": map[string]interface{}{
+ "type": "string",
+ "description": "OTP algorithm",
+ },
+ "digits": map[string]interface{}{
+ "type": "integer",
+ "description": "OTP digits",
+ },
+ "period": map[string]interface{}{
+ "type": "integer",
+ "description": "OTP period in seconds",
+ },
+ "created_at": map[string]interface{}{
+ "type": "string",
+ "format": "date-time",
+ "description": "Creation time",
+ },
+ "updated_at": map[string]interface{}{
+ "type": "string",
+ "format": "date-time",
+ "description": "Last update time",
+ },
+ },
+ },
+ "OTPCodeResponse": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "code": map[string]interface{}{
+ "type": "string",
+ "description": "OTP code",
+ },
+ "expires_in": map[string]interface{}{
+ "type": "integer",
+ "description": "Seconds until expiration",
+ },
+ },
+ },
+ "VerifyOTPRequest": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "code": map[string]interface{}{
+ "type": "string",
+ "description": "OTP code to verify",
+ },
+ },
+ "required": []string{"code"},
+ },
+ "VerifyOTPResponse": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "valid": map[string]interface{}{
+ "type": "boolean",
+ "description": "Whether the code is valid",
+ },
+ },
+ },
+ "UpdateOTPRequest": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "name": map[string]interface{}{
+ "type": "string",
+ "description": "OTP name",
+ },
+ "issuer": map[string]interface{}{
+ "type": "string",
+ "description": "OTP issuer",
+ },
+ "algorithm": map[string]interface{}{
+ "type": "string",
+ "description": "OTP algorithm",
+ "enum": []string{"SHA1", "SHA256", "SHA512"},
+ },
+ "digits": map[string]interface{}{
+ "type": "integer",
+ "description": "OTP digits",
+ "enum": []int{6, 8},
+ },
+ "period": map[string]interface{}{
+ "type": "integer",
+ "description": "OTP period in seconds",
+ "enum": []int{30, 60},
+ },
+ },
+ },
+ "ErrorResponse": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "code": map[string]interface{}{
+ "type": "integer",
+ "description": "Error code",
+ },
+ "message": map[string]interface{}{
+ "type": "string",
+ "description": "Error message",
+ },
+ },
+ },
+ },
+ "securityDefinitions": map[string]interface{}{
+ "Bearer": map[string]interface{}{
+ "type": "apiKey",
+ "name": "Authorization",
+ "in": "header",
+ "description": "JWT token with Bearer prefix",
+ },
+ },
+ }
+
+ data, _ := json.MarshalIndent(swagger, "", " ")
+ return data
+}
+
+// getPaths returns the API paths
+func getPaths() map[string]interface{} {
+ return map[string]interface{}{
+ "/login": map[string]interface{}{
+ "post": map[string]interface{}{
+ "summary": "Login with WeChat",
+ "description": "Login with WeChat authorization code",
+ "tags": []string{"auth"},
+ "parameters": []map[string]interface{}{
+ {
+ "name": "body",
+ "in": "body",
+ "description": "Login request",
+ "required": true,
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/LoginRequest",
+ },
+ },
+ },
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "Successful login",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/LoginResponse",
+ },
+ },
+ "400": map[string]interface{}{
+ "description": "Invalid request",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "500": map[string]interface{}{
+ "description": "Internal server error",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ },
+ },
+ },
+ "/verify-token": map[string]interface{}{
+ "post": map[string]interface{}{
+ "summary": "Verify token",
+ "description": "Verify JWT token",
+ "tags": []string{"auth"},
+ "security": []map[string]interface{}{
+ {
+ "Bearer": []string{},
+ },
+ },
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "Token is valid",
+ "schema": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "valid": map[string]interface{}{
+ "type": "boolean",
+ "description": "Whether the token is valid",
+ },
+ },
+ },
+ },
+ "401": map[string]interface{}{
+ "description": "Unauthorized",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ },
+ },
+ },
+ "/otp": map[string]interface{}{
+ "get": map[string]interface{}{
+ "summary": "List OTPs",
+ "description": "List all OTPs for the authenticated user",
+ "tags": []string{"otp"},
+ "security": []map[string]interface{}{
+ {
+ "Bearer": []string{},
+ },
+ },
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "List of OTPs",
+ "schema": map[string]interface{}{
+ "type": "array",
+ "items": map[string]interface{}{
+ "$ref": "#/definitions/OTP",
+ },
+ },
+ },
+ "401": map[string]interface{}{
+ "description": "Unauthorized",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "500": map[string]interface{}{
+ "description": "Internal server error",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ },
+ },
+ "post": map[string]interface{}{
+ "summary": "Create OTP",
+ "description": "Create a new OTP",
+ "tags": []string{"otp"},
+ "security": []map[string]interface{}{
+ {
+ "Bearer": []string{},
+ },
+ },
+ "parameters": []map[string]interface{}{
+ {
+ "name": "body",
+ "in": "body",
+ "description": "OTP creation request",
+ "required": true,
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/CreateOTPRequest",
+ },
+ },
+ },
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "OTP created",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/OTP",
+ },
+ },
+ "400": map[string]interface{}{
+ "description": "Invalid request",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "401": map[string]interface{}{
+ "description": "Unauthorized",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "500": map[string]interface{}{
+ "description": "Internal server error",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ },
+ },
+ },
+ "/otp/{id}": map[string]interface{}{
+ "put": map[string]interface{}{
+ "summary": "Update OTP",
+ "description": "Update an existing OTP",
+ "tags": []string{"otp"},
+ "security": []map[string]interface{}{
+ {
+ "Bearer": []string{},
+ },
+ },
+ "parameters": []map[string]interface{}{
+ {
+ "name": "id",
+ "in": "path",
+ "description": "OTP ID",
+ "required": true,
+ "type": "string",
+ },
+ {
+ "name": "body",
+ "in": "body",
+ "description": "OTP update request",
+ "required": true,
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/UpdateOTPRequest",
+ },
+ },
+ },
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "OTP updated",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/OTP",
+ },
+ },
+ "400": map[string]interface{}{
+ "description": "Invalid request",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "401": map[string]interface{}{
+ "description": "Unauthorized",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "404": map[string]interface{}{
+ "description": "OTP not found",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "500": map[string]interface{}{
+ "description": "Internal server error",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ },
+ },
+ "delete": map[string]interface{}{
+ "summary": "Delete OTP",
+ "description": "Delete an existing OTP",
+ "tags": []string{"otp"},
+ "security": []map[string]interface{}{
+ {
+ "Bearer": []string{},
+ },
+ },
+ "parameters": []map[string]interface{}{
+ {
+ "name": "id",
+ "in": "path",
+ "description": "OTP ID",
+ "required": true,
+ "type": "string",
+ },
+ },
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "OTP deleted",
+ "schema": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "message": map[string]interface{}{
+ "type": "string",
+ "description": "Success message",
+ },
+ },
+ },
+ },
+ "401": map[string]interface{}{
+ "description": "Unauthorized",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "404": map[string]interface{}{
+ "description": "OTP not found",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "500": map[string]interface{}{
+ "description": "Internal server error",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ },
+ },
+ },
+ "/otp/{id}/code": map[string]interface{}{
+ "get": map[string]interface{}{
+ "summary": "Get OTP code",
+ "description": "Get the current OTP code",
+ "tags": []string{"otp"},
+ "security": []map[string]interface{}{
+ {
+ "Bearer": []string{},
+ },
+ },
+ "parameters": []map[string]interface{}{
+ {
+ "name": "id",
+ "in": "path",
+ "description": "OTP ID",
+ "required": true,
+ "type": "string",
+ },
+ },
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "OTP code",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/OTPCodeResponse",
+ },
+ },
+ "401": map[string]interface{}{
+ "description": "Unauthorized",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "404": map[string]interface{}{
+ "description": "OTP not found",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "500": map[string]interface{}{
+ "description": "Internal server error",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ },
+ },
+ },
+ "/otp/{id}/verify": map[string]interface{}{
+ "post": map[string]interface{}{
+ "summary": "Verify OTP code",
+ "description": "Verify an OTP code",
+ "tags": []string{"otp"},
+ "security": []map[string]interface{}{
+ {
+ "Bearer": []string{},
+ },
+ },
+ "parameters": []map[string]interface{}{
+ {
+ "name": "id",
+ "in": "path",
+ "description": "OTP ID",
+ "required": true,
+ "type": "string",
+ },
+ {
+ "name": "body",
+ "in": "body",
+ "description": "OTP verification request",
+ "required": true,
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/VerifyOTPRequest",
+ },
+ },
+ },
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "OTP verification result",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/VerifyOTPResponse",
+ },
+ },
+ "400": map[string]interface{}{
+ "description": "Invalid request",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "401": map[string]interface{}{
+ "description": "Unauthorized",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "404": map[string]interface{}{
+ "description": "OTP not found",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ "500": map[string]interface{}{
+ "description": "Internal server error",
+ "schema": map[string]interface{}{
+ "$ref": "#/definitions/ErrorResponse",
+ },
+ },
+ },
+ },
+ },
+ "/health": map[string]interface{}{
+ "get": map[string]interface{}{
+ "summary": "Health check",
+ "description": "Check if the API is healthy",
+ "tags": []string{"system"},
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "API is healthy",
+ "schema": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "status": map[string]interface{}{
+ "type": "string",
+ "description": "Health status",
+ },
+ "time": map[string]interface{}{
+ "type": "string",
+ "format": "date-time",
+ "description": "Current time",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ "/metrics": map[string]interface{}{
+ "get": map[string]interface{}{
+ "summary": "Metrics",
+ "description": "Get application metrics",
+ "tags": []string{"system"},
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "Application metrics",
+ },
+ },
+ },
+ },
+ "/swagger.json": map[string]interface{}{
+ "get": map[string]interface{}{
+ "summary": "Swagger JSON",
+ "description": "Get Swagger JSON",
+ "tags": []string{"system"},
+ "responses": map[string]interface{}{
+ "200": map[string]interface{}{
+ "description": "Swagger JSON",
+ },
+ },
+ },
+ },
+ }
+}
+
+// Handler returns an HTTP handler for Swagger JSON
+func Handler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write(SwaggerJSON())
+ }
+}
diff --git a/go.mod b/go.mod
index 57bb068..a55608f 100644
--- a/go.mod
+++ b/go.mod
@@ -1,31 +1,36 @@
module otpm
-go 1.21.1
+go 1.23.0
+
+toolchain go1.23.9
require (
- github.com/go-sql-driver/mysql v1.8.1
+ github.com/go-playground/validator/v10 v10.26.0
+ github.com/golang-jwt/jwt v3.2.2+incompatible
+ github.com/google/uuid v1.6.0
github.com/jmoiron/sqlx v1.4.0
github.com/julienschmidt/httprouter v1.3.0
- github.com/lib/pq v1.10.9
- github.com/spf13/cobra v1.8.1
+ github.com/prometheus/client_golang v1.22.0
github.com/spf13/viper v1.19.0
- modernc.org/sqlite v1.32.0
+ golang.org/x/crypto v0.38.0
)
require (
- filippo.io/edwards25519 v1.1.0 // indirect
- github.com/dustin/go-humanize v1.0.1 // indirect
+ github.com/beorn7/perks v1.0.1 // indirect
+ github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
- github.com/google/uuid v1.6.0 // indirect
- github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
+ github.com/gabriel-vasile/mimetype v1.4.8 // indirect
+ github.com/go-playground/locales v0.14.1 // indirect
+ github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
- github.com/inconshreveable/mousetrap v1.1.0 // indirect
+ github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
- github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
- github.com/ncruces/go-strftime v0.1.9 // indirect
+ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
- github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+ github.com/prometheus/client_model v0.6.1 // indirect
+ github.com/prometheus/common v0.62.0 // indirect
+ github.com/prometheus/procfs v0.15.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
@@ -36,14 +41,10 @@ require (
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
- golang.org/x/sys v0.22.0 // indirect
- golang.org/x/text v0.14.0 // indirect
+ golang.org/x/net v0.34.0 // indirect
+ golang.org/x/sys v0.33.0 // indirect
+ golang.org/x/text v0.25.0 // indirect
+ google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
- modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
- modernc.org/libc v1.55.3 // indirect
- modernc.org/mathutil v1.6.0 // indirect
- modernc.org/memory v1.8.0 // indirect
- modernc.org/strutil v1.2.0 // indirect
- modernc.org/token v1.1.0 // indirect
)
diff --git a/go.sum b/go.sum
index 5dc30bc..c30de80 100644
--- a/go.sum
+++ b/go.sum
@@ -1,60 +1,76 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
-github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
+github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
+github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
-github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
+github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
+github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
+github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
+github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
+github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
+github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
+github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
+github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
+github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
+github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
-github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
-github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
-github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
-github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
+github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
+github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
+github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
+github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
-github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
-github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
-github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
+github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
+github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
+github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
+github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
+github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
-github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
-github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
-github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
-github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
+github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
+github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
-github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
-github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
-github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
+github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
+github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
+github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
+github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
+github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
+github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
+github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
+github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
@@ -65,8 +81,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
-github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
-github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
@@ -79,56 +93,32 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
+golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
+golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
-golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
-golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
-golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
-golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
-golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
-golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
+golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
+golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
+golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
+golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
+golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
+google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
+google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
-gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ=
-modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ=
-modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y=
-modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s=
-modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
-modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
-modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw=
-modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU=
-modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI=
-modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4=
-modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U=
-modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w=
-modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
-modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
-modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E=
-modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU=
-modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
-modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
-modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc=
-modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss=
-modernc.org/sqlite v1.32.0 h1:6BM4uGza7bWypsw4fdLRsLxut6bHe4c58VeqjRgST8s=
-modernc.org/sqlite v1.32.0/go.mod h1:UqoylwmTb9F+IqXERT8bW9zzOWN8qwAIcLdzeBZs4hA=
-modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA=
-modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0=
-modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
-modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
diff --git a/handlers/auth_handler.go b/handlers/auth_handler.go
new file mode 100644
index 0000000..f79979f
--- /dev/null
+++ b/handlers/auth_handler.go
@@ -0,0 +1,156 @@
+package handlers
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "time"
+
+ "otpm/api"
+ "otpm/services"
+
+ "github.com/golang-jwt/jwt"
+ "github.com/julienschmidt/httprouter"
+)
+
+// AuthHandler handles authentication related requests
+type AuthHandler struct {
+ authService *services.AuthService
+}
+
+// NewAuthHandler creates a new AuthHandler
+func NewAuthHandler(authService *services.AuthService) *AuthHandler {
+ return &AuthHandler{
+ authService: authService,
+ }
+}
+
+// LoginRequest represents a login request
+type LoginRequest struct {
+ Code string `json:"code" validate:"required,min=32,max=128"`
+}
+
+// LoginResponse represents a login response
+type LoginResponse struct {
+ Token string `json:"token"`
+ OpenID string `json:"openid"`
+}
+
+// TokenRequest represents a token verification request
+type TokenRequest struct {
+ Token string `validate:"required,min=32"`
+}
+
+// Login handles WeChat login
+func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ start := time.Now()
+
+ // Limit request body size to prevent DOS
+ r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
+
+ // Parse and validate request
+ var req LoginRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
+ fmt.Sprintf("Invalid request body: %v", err))
+ log.Printf("Login request parse error: %v", err)
+ return
+ }
+
+ // Validate using validator
+ if err := api.Validate.Struct(req); err != nil {
+ api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
+ fmt.Sprintf("Invalid request parameters: %v", err))
+ log.Printf("Login request validation failed: %v", err)
+ return
+ }
+
+ // Login with WeChat code
+ token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
+ if err != nil {
+ api.NewResponseWriter(w).WriteError(api.InternalError(err))
+ log.Printf("Login failed for code %s: %v", req.Code, err)
+ return
+ }
+
+ // Log successful login
+ log.Printf("Login successful for code %s (took %v)",
+ req.Code, time.Since(start))
+
+ // Return token
+ api.NewResponseWriter(w).WriteSuccess(LoginResponse{
+ Token: token,
+ })
+}
+
+// VerifyToken handles token verification
+func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ start := time.Now()
+
+ // Get token from Authorization header
+ authHeader := r.Header.Get("Authorization")
+ if authHeader == "" {
+ api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
+ "Authorization header is required")
+ log.Printf("Token verification failed: missing Authorization header")
+ return
+ }
+
+ // Validate token format
+ if len(authHeader) < 7 || authHeader[:7] != "Bearer " {
+ api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
+ "Invalid token format. Expected 'Bearer '")
+ log.Printf("Token verification failed: invalid token format")
+ return
+ }
+
+ token := authHeader[7:]
+
+ // Validate token using validator
+ tokenReq := TokenRequest{Token: token}
+ if err := api.Validate.Struct(tokenReq); err != nil {
+ api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
+ "Invalid token format")
+ log.Printf("Token verification failed: %v", err)
+ return
+ }
+
+ // Validate token
+ claims, err := h.authService.ValidateToken(token)
+ if err != nil {
+ api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
+ log.Printf("Token verification failed for token %s: %v",
+ maskToken(token), err) // Mask token in logs
+ return
+ }
+
+ // Log successful verification
+ userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string)
+ if !ok {
+ log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
+ } else {
+ log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
+ }
+
+ // Token is valid
+ api.NewResponseWriter(w).WriteSuccess(map[string]bool{
+ "valid": true,
+ })
+}
+
+// maskToken masks sensitive parts of token for logging
+func maskToken(token string) string {
+ if len(token) < 8 {
+ return "****"
+ }
+ return token[:4] + "****" + token[len(token)-4:]
+}
+
+// Routes returns all routes for the auth handler
+func (h *AuthHandler) Routes() map[string]httprouter.Handle {
+ return map[string]httprouter.Handle{
+ "/api/login": h.Login,
+ "/api/verify-token": h.VerifyToken,
+ }
+}
diff --git a/handlers/handler.go b/handlers/handler.go
deleted file mode 100644
index db15938..0000000
--- a/handlers/handler.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package handlers
-
-import (
- "github.com/jmoiron/sqlx"
-)
-
-type Handler struct {
- DB *sqlx.DB
-}
diff --git a/handlers/login.go b/handlers/login.go
deleted file mode 100644
index 30812d7..0000000
--- a/handlers/login.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package handlers
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
-
- "github.com/spf13/viper"
-)
-
-var code2Session = "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code"
-
-type LoginRequest struct {
- Code string `json:"code"`
-}
-
-// 封装code2session接口返回数据
-type LoginResponse struct {
- OpenId string `json:"openid"`
- SessionKey string `json:"session_key"`
- UnionId string `json:"unionid"`
- ErrCode int `json:"errcode"`
- ErrMsg string `json:"errmsg"`
-}
-
-func getLoginResponse(code string) (*LoginResponse, error) {
- appid := viper.GetString("wechat.appid")
- secret := viper.GetString("wechat.secret")
- url := fmt.Sprintf(code2Session, appid, secret, code)
- resp, err := http.Get(url)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
-
- var loginResponse LoginResponse
- if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
- return nil, err
- }
-
- if loginResponse.ErrCode != 0 {
- return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
- }
- return &loginResponse, nil
-}
-
-func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
-
- var req LoginRequest
- body, err := io.ReadAll(r.Body)
- if err != nil {
- http.Error(w, "Failed to read request body", http.StatusBadRequest)
- return
- }
- if err := json.Unmarshal(body, &req); err != nil {
- http.Error(w, "Failed to parse request body", http.StatusBadRequest)
- return
- }
-
- loginResponse, err := getLoginResponse(req.Code)
- if err != nil {
- http.Error(w, "Failed to get session key", http.StatusInternalServerError)
- return
- }
-
- // // 插入或更新用户的openid和sessionid
- // query := `
- // INSERT INTO users (openid, sessionid)
- // VALUES ($1, $2)
- // ON CONFLICT (openid) DO UPDATE SET sessionid = $2
- // RETURNING id;
- // `
-
- // var ID int
- // if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil {
- // http.Error(w, "Failed to log in user", http.StatusInternalServerError)
- // return
- // }
-
- data := map[string]interface{}{
- "openid": loginResponse.OpenId,
- "session_key": loginResponse.SessionKey,
- }
- respData, err := json.Marshal(data)
- if err != nil {
- http.Error(w, "Failed to marshal response data", http.StatusInternalServerError)
- return
- }
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- w.Write([]byte(respData))
-}
diff --git a/handlers/otp.go b/handlers/otp.go
deleted file mode 100644
index 61230ec..0000000
--- a/handlers/otp.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package handlers
-
-import (
- "encoding/json"
- "net/http"
-)
-
-type OtpRequest struct {
- OpenID string `json:"openid"`
- Num int `json:"num"`
- Token *[]OTP `json:"token"`
-}
-type OTP struct {
- Issuer string `json:"issuer"`
- Remark string `json:"remark"`
- Secret string `json:"secret"`
-}
-
-func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) {
- var req OtpRequest
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- http.Error(w, "Invalid request payload", http.StatusBadRequest)
- return
- }
-
- num := len(*req.Token)
-
- // 插入或更新 OTP 记录
- query := `
- INSERT INTO otp (openid, num, token)
- VALUES ($1, $2, $3)
- `
-
- _, err := h.DB.Exec(query, req.OpenID, req.Token, num)
- if err != nil {
- http.Error(w, "Failed to update or create OTP", http.StatusInternalServerError)
- return
- }
-
- w.WriteHeader(http.StatusOK)
- w.Write([]byte("OTP updated or created successfully"))
-}
-
-func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
- openid := r.URL.Query().Get("openid")
- if openid == "" {
- http.Error(w, "未登录", http.StatusBadRequest)
- return
- }
-
- var otp OtpRequest
-
- err := h.DB.Get(&otp, "SELECT token, num, openid FROM otp WHERE openid=$1", openid)
- if err != nil {
- http.Error(w, "OTP not found", http.StatusNotFound)
- return
- }
-
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(otp)
-}
diff --git a/handlers/otp_handler.go b/handlers/otp_handler.go
new file mode 100644
index 0000000..c1d99bb
--- /dev/null
+++ b/handlers/otp_handler.go
@@ -0,0 +1,114 @@
+package handlers
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/julienschmidt/httprouter"
+
+ "otpm/api"
+ "otpm/middleware"
+ "otpm/models"
+ "otpm/services"
+)
+
+// OTPHandler handles OTP-related HTTP requests
+type OTPHandler struct {
+ otpService *services.OTPService
+}
+
+// NewOTPHandler creates a new OTPHandler
+func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
+ return &OTPHandler{
+ otpService: otpService,
+ }
+}
+
+// Routes returns the routes for OTP operations
+func (h *OTPHandler) Routes() map[string]httprouter.Handle {
+ return map[string]httprouter.Handle{
+ "POST /api/otp": h.CreateOTP,
+ "GET /api/otps": h.ListOTPs,
+ "GET /api/otp/:id": h.GetOTP,
+ }
+}
+
+// CreateOTP handles the creation of a new OTP
+func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ // Get user ID from context
+ userID, ok := r.Context().Value(middleware.UserIDKey).(string)
+ if !ok {
+ api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
+ return
+ }
+
+ // Parse request body
+ var params models.OTPParams
+ if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
+ api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body"))
+ return
+ }
+
+ // Validate request
+ if err := api.Validate.Struct(params); err != nil {
+ api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
+ return
+ }
+
+ // Create OTP
+ otp, err := h.otpService.CreateOTP(r.Context(), userID, params)
+ if err != nil {
+ api.NewResponseWriter(w).WriteError(api.InternalError(err))
+ return
+ }
+
+ // Return response
+ api.NewResponseWriter(w).WriteSuccess(otp)
+}
+
+// ListOTPs handles listing all OTPs for a user
+func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ // Get user ID from context
+ userID, ok := r.Context().Value(middleware.UserIDKey).(string)
+ if !ok {
+ api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
+ return
+ }
+
+ // Get OTPs
+ otps, err := h.otpService.ListOTPs(r.Context(), userID)
+ if err != nil {
+ api.NewResponseWriter(w).WriteError(api.InternalError(err))
+ return
+ }
+
+ // Return response
+ api.NewResponseWriter(w).WriteSuccess(otps)
+}
+
+// GetOTP handles getting a specific OTP
+func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+ // Get user ID from context
+ userID, ok := r.Context().Value(middleware.UserIDKey).(string)
+ if !ok {
+ api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
+ return
+ }
+
+ // Get OTP ID from URL
+ otpID := ps.ByName("id")
+ if otpID == "" {
+ api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID"))
+ return
+ }
+
+ // Get OTP
+ otp, err := h.otpService.GetOTP(r.Context(), otpID, userID)
+ if err != nil {
+ api.NewResponseWriter(w).WriteError(api.InternalError(err))
+ return
+ }
+
+ // Return response
+ api.NewResponseWriter(w).WriteSuccess(otp)
+}
diff --git a/logger/logger.go b/logger/logger.go
new file mode 100644
index 0000000..d7d76e4
--- /dev/null
+++ b/logger/logger.go
@@ -0,0 +1,204 @@
+package logger
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+// Level represents a log level
+type Level int
+
+const (
+ // DEBUG level
+ DEBUG Level = iota
+ // INFO level
+ INFO
+ // WARN level
+ WARN
+ // ERROR level
+ ERROR
+ // FATAL level
+ FATAL
+)
+
+// String returns the string representation of the log level
+func (l Level) String() string {
+ switch l {
+ case DEBUG:
+ return "DEBUG"
+ case INFO:
+ return "INFO"
+ case WARN:
+ return "WARN"
+ case ERROR:
+ return "ERROR"
+ case FATAL:
+ return "FATAL"
+ default:
+ return "UNKNOWN"
+ }
+}
+
+// Logger represents a logger
+type Logger struct {
+ level Level
+ output io.Writer
+}
+
+// contextKey is a type for context keys
+type contextKey string
+
+// requestIDKey is the key for request ID in context
+const requestIDKey = contextKey("request_id")
+
+// New creates a new logger
+func New(level Level, output io.Writer) *Logger {
+ if output == nil {
+ output = os.Stdout
+ }
+ return &Logger{
+ level: level,
+ output: output,
+ }
+}
+
+// WithLevel creates a new logger with the specified level
+func (l *Logger) WithLevel(level Level) *Logger {
+ return &Logger{
+ level: level,
+ output: l.output,
+ }
+}
+
+// WithOutput creates a new logger with the specified output
+func (l *Logger) WithOutput(output io.Writer) *Logger {
+ return &Logger{
+ level: l.level,
+ output: output,
+ }
+}
+
+// log logs a message with the specified level
+func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) {
+ if level < l.level {
+ return
+ }
+
+ // Get request ID from context
+ requestID := getRequestID(ctx)
+
+ // Get caller information
+ _, file, line, ok := runtime.Caller(2)
+ if !ok {
+ file = "unknown"
+ line = 0
+ }
+ // Extract just the filename
+ if idx := strings.LastIndex(file, "/"); idx >= 0 {
+ file = file[idx+1:]
+ }
+
+ // Format message
+ message := fmt.Sprintf(format, args...)
+
+ // Format log entry
+ timestamp := time.Now().Format(time.RFC3339)
+ logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n",
+ timestamp, level.String(), file, line, requestID, message)
+
+ // Write log entry
+ _, _ = l.output.Write([]byte(logEntry))
+
+ // Exit if fatal
+ if level == FATAL {
+ os.Exit(1)
+ }
+}
+
+// Debug logs a debug message
+func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) {
+ l.log(ctx, DEBUG, format, args...)
+}
+
+// Info logs an info message
+func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) {
+ l.log(ctx, INFO, format, args...)
+}
+
+// Warn logs a warning message
+func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) {
+ l.log(ctx, WARN, format, args...)
+}
+
+// Error logs an error message
+func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) {
+ l.log(ctx, ERROR, format, args...)
+}
+
+// Fatal logs a fatal message and exits
+func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) {
+ l.log(ctx, FATAL, format, args...)
+}
+
+// WithRequestID adds a request ID to the context
+func WithRequestID(ctx context.Context) context.Context {
+ requestID := uuid.New().String()
+ return context.WithValue(ctx, requestIDKey, requestID)
+}
+
+// GetRequestID gets the request ID from the context
+func GetRequestID(ctx context.Context) string {
+ return getRequestID(ctx)
+}
+
+// getRequestID gets the request ID from the context
+func getRequestID(ctx context.Context) string {
+ if ctx == nil {
+ return "-"
+ }
+ requestID, ok := ctx.Value(requestIDKey).(string)
+ if !ok {
+ return "-"
+ }
+ return requestID
+}
+
+// Default logger
+var defaultLogger = New(INFO, os.Stdout)
+
+// SetDefaultLogger sets the default logger
+func SetDefaultLogger(logger *Logger) {
+ defaultLogger = logger
+}
+
+// Debug logs a debug message using the default logger
+func Debug(ctx context.Context, format string, args ...interface{}) {
+ defaultLogger.Debug(ctx, format, args...)
+}
+
+// Info logs an info message using the default logger
+func Info(ctx context.Context, format string, args ...interface{}) {
+ defaultLogger.Info(ctx, format, args...)
+}
+
+// Warn logs a warning message using the default logger
+func Warn(ctx context.Context, format string, args ...interface{}) {
+ defaultLogger.Warn(ctx, format, args...)
+}
+
+// Error logs an error message using the default logger
+func Error(ctx context.Context, format string, args ...interface{}) {
+ defaultLogger.Error(ctx, format, args...)
+}
+
+// Fatal logs a fatal message and exits using the default logger
+func Fatal(ctx context.Context, format string, args ...interface{}) {
+ defaultLogger.Fatal(ctx, format, args...)
+}
diff --git a/metrics/metrics.go b/metrics/metrics.go
new file mode 100644
index 0000000..8e1d18a
--- /dev/null
+++ b/metrics/metrics.go
@@ -0,0 +1,193 @@
+package metrics
+
+import (
+ "fmt"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promhttp"
+)
+
+var (
+ // Default metrics
+ requestDuration = prometheus.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Name: "http_request_duration_seconds",
+ Help: "Duration of HTTP requests in seconds",
+ Buckets: prometheus.DefBuckets,
+ },
+ []string{"method", "path", "status"},
+ )
+
+ requestTotal = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "http_requests_total",
+ Help: "Total number of HTTP requests",
+ },
+ []string{"method", "path", "status"},
+ )
+
+ otpGenerationTotal = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "otp_generation_total",
+ Help: "Total number of OTP generations",
+ },
+ []string{"user_id", "otp_id"},
+ )
+
+ otpVerificationTotal = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "otp_verification_total",
+ Help: "Total number of OTP verifications",
+ },
+ []string{"user_id", "otp_id", "success"},
+ )
+
+ activeUsers = prometheus.NewGauge(
+ prometheus.GaugeOpts{
+ Name: "active_users",
+ Help: "Number of active users",
+ },
+ )
+
+ cacheHits = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "cache_hits_total",
+ Help: "Total number of cache hits",
+ },
+ []string{"cache"},
+ )
+
+ cacheMisses = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "cache_misses_total",
+ Help: "Total number of cache misses",
+ },
+ []string{"cache"},
+ )
+)
+
+func init() {
+ // Register metrics with prometheus
+ prometheus.MustRegister(
+ requestDuration,
+ requestTotal,
+ otpGenerationTotal,
+ otpVerificationTotal,
+ activeUsers,
+ cacheHits,
+ cacheMisses,
+ )
+}
+
+// MetricsService provides metrics functionality
+type MetricsService struct {
+ activeUsersMutex sync.RWMutex
+ activeUserIDs map[string]bool
+}
+
+// NewMetricsService creates a new MetricsService
+func NewMetricsService() *MetricsService {
+ return &MetricsService{
+ activeUserIDs: make(map[string]bool),
+ }
+}
+
+// Handler returns an HTTP handler for metrics
+func (s *MetricsService) Handler() http.Handler {
+ return promhttp.Handler()
+}
+
+// RecordRequest records metrics for an HTTP request
+func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) {
+ labels := prometheus.Labels{
+ "method": method,
+ "path": path,
+ "status": fmt.Sprintf("%d", status),
+ }
+
+ requestDuration.With(labels).Observe(duration.Seconds())
+ requestTotal.With(labels).Inc()
+}
+
+// RecordOTPGeneration records metrics for OTP generation
+func (s *MetricsService) RecordOTPGeneration(userID, otpID string) {
+ otpGenerationTotal.With(prometheus.Labels{
+ "user_id": userID,
+ "otp_id": otpID,
+ }).Inc()
+}
+
+// RecordOTPVerification records metrics for OTP verification
+func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) {
+ otpVerificationTotal.With(prometheus.Labels{
+ "user_id": userID,
+ "otp_id": otpID,
+ "success": fmt.Sprintf("%t", success),
+ }).Inc()
+}
+
+// RecordUserActivity records user activity
+func (s *MetricsService) RecordUserActivity(userID string) {
+ s.activeUsersMutex.Lock()
+ defer s.activeUsersMutex.Unlock()
+
+ if !s.activeUserIDs[userID] {
+ s.activeUserIDs[userID] = true
+ activeUsers.Inc()
+ }
+}
+
+// RecordUserInactivity records user inactivity
+func (s *MetricsService) RecordUserInactivity(userID string) {
+ s.activeUsersMutex.Lock()
+ defer s.activeUsersMutex.Unlock()
+
+ if s.activeUserIDs[userID] {
+ delete(s.activeUserIDs, userID)
+ activeUsers.Dec()
+ }
+}
+
+// RecordCacheHit records a cache hit
+func (s *MetricsService) RecordCacheHit(cache string) {
+ cacheHits.With(prometheus.Labels{
+ "cache": cache,
+ }).Inc()
+}
+
+// RecordCacheMiss records a cache miss
+func (s *MetricsService) RecordCacheMiss(cache string) {
+ cacheMisses.With(prometheus.Labels{
+ "cache": cache,
+ }).Inc()
+}
+
+// Middleware creates a middleware that records request metrics
+func (s *MetricsService) Middleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ start := time.Now()
+
+ // Create response writer that captures status code
+ rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
+
+ // Call next handler
+ next.ServeHTTP(rw, r)
+
+ // Record metrics
+ s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start))
+ })
+}
+
+// responseWriter wraps http.ResponseWriter to capture status code
+type responseWriter struct {
+ http.ResponseWriter
+ status int
+}
+
+func (rw *responseWriter) WriteHeader(code int) {
+ rw.status = code
+ rw.ResponseWriter.WriteHeader(code)
+}
diff --git a/middleware/middleware.go b/middleware/middleware.go
new file mode 100644
index 0000000..6afce13
--- /dev/null
+++ b/middleware/middleware.go
@@ -0,0 +1,353 @@
+package middleware
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "math/rand"
+ "net/http"
+ "runtime/debug"
+ "strings"
+ "time"
+
+ "github.com/golang-jwt/jwt"
+)
+
+// Response represents a standard API response
+type Response struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data interface{} `json:"data,omitempty"`
+}
+
+// ErrorResponse sends a JSON error response
+func ErrorResponse(w http.ResponseWriter, code int, message string) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(code)
+ json.NewEncoder(w).Encode(Response{
+ Code: code,
+ Message: message,
+ })
+}
+
+// SuccessResponse sends a JSON success response
+func SuccessResponse(w http.ResponseWriter, data interface{}) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ json.NewEncoder(w).Encode(Response{
+ Code: http.StatusOK,
+ Message: "success",
+ Data: data,
+ })
+}
+
+// Logger is a middleware that logs request details with structured format
+func Logger(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ start := time.Now()
+ requestID := r.Header.Get("X-Request-ID")
+ if requestID == "" {
+ requestID = generateRequestID()
+ r.Header.Set("X-Request-ID", requestID)
+ }
+
+ // Create a custom response writer to capture status code
+ rw := &responseWriter{
+ ResponseWriter: w,
+ status: http.StatusOK,
+ }
+
+ // Process request
+ next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID)))
+
+ // Log structured request details
+ log.Printf(
+ "method=%s path=%s status=%d duration=%s ip=%s request_id=%s",
+ r.Method,
+ r.URL.Path,
+ rw.status,
+ time.Since(start).String(),
+ r.RemoteAddr,
+ requestID,
+ )
+ })
+}
+
+// generateRequestID creates a unique request identifier
+func generateRequestID() string {
+ return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
+}
+
+// randomString generates a random string of given length
+func randomString(length int) string {
+ const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+ b := make([]byte, length)
+ for i := range b {
+ b[i] = charset[rand.Intn(len(charset))]
+ }
+ return string(b)
+}
+
+// Recover is a middleware that recovers from panics with detailed logging
+func Recover(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() {
+ if err := recover(); err != nil {
+ // Get request ID from context
+ requestID := ""
+ if ctx := r.Context(); ctx != nil {
+ if id, ok := ctx.Value("request_id").(string); ok {
+ requestID = id
+ }
+ }
+
+ // Log error with stack trace and request context
+ log.Printf(
+ "panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s",
+ err,
+ requestID,
+ r.Method,
+ r.URL.Path,
+ r.RemoteAddr,
+ debug.Stack(),
+ )
+
+ // Determine error type
+ var message string
+ var status int
+
+ switch e := err.(type) {
+ case error:
+ message = e.Error()
+ if isClientError(e) {
+ status = http.StatusBadRequest
+ } else {
+ status = http.StatusInternalServerError
+ }
+ case string:
+ message = e
+ status = http.StatusInternalServerError
+ default:
+ message = "Internal Server Error"
+ status = http.StatusInternalServerError
+ }
+
+ ErrorResponse(w, status, message)
+ }
+ }()
+ next.ServeHTTP(w, r)
+ })
+}
+
+// isClientError checks if error should be treated as client error
+func isClientError(err error) bool {
+ // Add more client error types as needed
+ return strings.Contains(err.Error(), "validation") ||
+ strings.Contains(err.Error(), "invalid") ||
+ strings.Contains(err.Error(), "missing")
+}
+
+// CORS is a middleware that handles CORS
+func CORS(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
+
+ if r.Method == "OPTIONS" {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ next.ServeHTTP(w, r)
+ })
+}
+
+// Timeout is a middleware that safely handles request timeouts
+func Timeout(duration time.Duration) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ctx, cancel := context.WithTimeout(r.Context(), duration)
+ defer cancel()
+
+ // Use buffered channels to prevent goroutine leaks
+ done := make(chan struct{}, 1)
+ panicChan := make(chan interface{}, 1)
+
+ // Track request processing in goroutine
+ go func() {
+ defer func() {
+ if p := recover(); p != nil {
+ panicChan <- p
+ }
+ }()
+ next.ServeHTTP(w, r.WithContext(ctx))
+ done <- struct{}{}
+ }()
+
+ // Wait for completion, timeout or panic
+ select {
+ case <-done:
+ return
+ case p := <-panicChan:
+ panic(p) // Re-throw panic to be caught by Recover middleware
+ case <-ctx.Done():
+ // Get request context for logging
+ requestID := ""
+ if ctx := r.Context(); ctx != nil {
+ if id, ok := ctx.Value("request_id").(string); ok {
+ requestID = id
+ }
+ }
+
+ // Log timeout details
+ log.Printf(
+ "request_timeout: request_id=%s method=%s path=%s timeout=%s",
+ requestID,
+ r.Method,
+ r.URL.Path,
+ duration.String(),
+ )
+
+ // Send timeout response
+ ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf(
+ "Request timed out after %s", duration.String(),
+ ))
+ }
+ })
+ }
+}
+
+// Auth is a middleware that validates JWT tokens with enhanced security
+func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Get request ID for logging
+ requestID := ""
+ if ctx := r.Context(); ctx != nil {
+ if id, ok := ctx.Value("request_id").(string); ok {
+ requestID = id
+ }
+ }
+
+ // Get token from Authorization header
+ authHeader := r.Header.Get("Authorization")
+ if authHeader == "" {
+ log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID)
+ ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required")
+ return
+ }
+
+ // Validate header format
+ parts := strings.Split(authHeader, " ")
+ if len(parts) != 2 || parts[0] != "Bearer" {
+ log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID)
+ ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer '")
+ return
+ }
+
+ tokenString := parts[1]
+
+ // Parse and validate token
+ token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
+ // Validate signing method
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+ }
+ return []byte(jwtSecret), nil
+ })
+
+ if err != nil {
+ log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err)
+ ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
+ return
+ }
+
+ if !token.Valid {
+ log.Printf("auth_failed: request_id=%s error=invalid_token", requestID)
+ ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
+ return
+ }
+
+ // Validate claims
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID)
+ ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims")
+ return
+ }
+
+ // Check required claims
+ userID, ok := claims["user_id"].(string)
+ if !ok || userID == "" {
+ log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID)
+ ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token")
+ return
+ }
+
+ // Check token expiration
+ if exp, ok := claims["exp"].(float64); ok {
+ if time.Now().Unix() > int64(exp) {
+ log.Printf("auth_failed: request_id=%s error=token_expired", requestID)
+ ErrorResponse(w, http.StatusUnauthorized, "Token has expired")
+ return
+ }
+ }
+
+ // Check required roles if specified
+ if len(requiredRoles) > 0 {
+ roles, ok := claims["roles"].([]interface{})
+ if !ok {
+ log.Printf("auth_failed: request_id=%s error=missing_roles", requestID)
+ ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles")
+ return
+ }
+
+ hasRequiredRole := false
+ for _, requiredRole := range requiredRoles {
+ for _, role := range roles {
+ if r, ok := role.(string); ok && r == requiredRole {
+ hasRequiredRole = true
+ break
+ }
+ }
+ }
+
+ if !hasRequiredRole {
+ log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID)
+ ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions")
+ return
+ }
+ }
+
+ // Add claims to context
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, "user_id", userID)
+ ctx = context.WithValue(ctx, "claims", claims)
+
+ log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+ }
+}
+
+// responseWriter is a custom response writer that captures the status code
+type responseWriter struct {
+ http.ResponseWriter
+ status int
+}
+
+func (rw *responseWriter) WriteHeader(code int) {
+ rw.status = code
+ rw.ResponseWriter.WriteHeader(code)
+}
+
+// GetUserID gets the user ID from the request context
+func GetUserID(r *http.Request) (string, error) {
+ userID, ok := r.Context().Value("user_id").(string)
+ if !ok {
+ return "", fmt.Errorf("user ID not found in context")
+ }
+ return userID, nil
+}
diff --git a/miniprogram-example/app.js b/miniprogram-example/app.js
new file mode 100644
index 0000000..34eba08
--- /dev/null
+++ b/miniprogram-example/app.js
@@ -0,0 +1,50 @@
+// app.js
+App({
+ onLaunch() {
+ // 检查更新
+ if (wx.canIUse('getUpdateManager')) {
+ const updateManager = wx.getUpdateManager();
+ updateManager.onCheckForUpdate(function (res) {
+ if (res.hasUpdate) {
+ updateManager.onUpdateReady(function () {
+ wx.showModal({
+ title: '更新提示',
+ content: '新版本已经准备好,是否重启应用?',
+ success: function (res) {
+ if (res.confirm) {
+ updateManager.applyUpdate();
+ }
+ }
+ });
+ });
+
+ updateManager.onUpdateFailed(function () {
+ wx.showModal({
+ title: '更新提示',
+ content: '新版本下载失败,请检查网络后重试',
+ showCancel: false
+ });
+ });
+ }
+ });
+ }
+
+ // 获取系统信息
+ try {
+ const systemInfo = wx.getSystemInfoSync();
+ this.globalData.systemInfo = systemInfo;
+
+ // 计算安全区域
+ const { screenHeight, safeArea } = systemInfo;
+ this.globalData.safeAreaBottom = screenHeight - safeArea.bottom;
+ } catch (e) {
+ console.error('获取系统信息失败', e);
+ }
+ },
+
+ globalData: {
+ userInfo: null,
+ systemInfo: {},
+ safeAreaBottom: 0
+ }
+});
\ No newline at end of file
diff --git a/miniprogram-example/app.json b/miniprogram-example/app.json
new file mode 100644
index 0000000..89f79ae
--- /dev/null
+++ b/miniprogram-example/app.json
@@ -0,0 +1,23 @@
+{
+ "pages": [
+ "pages/login/login",
+ "pages/otp-list/index",
+ "pages/otp-add/index"
+ ],
+ "window": {
+ "backgroundTextStyle": "light",
+ "navigationBarBackgroundColor": "#fff",
+ "navigationBarTitleText": "OTPM",
+ "navigationBarTextStyle": "black",
+ "backgroundColor": "#F8F8F8"
+ },
+ "permission": {
+ "scope.camera": {
+ "desc": "需要使用相机扫描二维码"
+ }
+ },
+ "usingComponents": {},
+ "style": "v2",
+ "sitemapLocation": "sitemap.json",
+ "lazyCodeLoading": "requiredComponents"
+}
\ No newline at end of file
diff --git a/miniprogram-example/app.wxss b/miniprogram-example/app.wxss
new file mode 100644
index 0000000..fb73b4a
--- /dev/null
+++ b/miniprogram-example/app.wxss
@@ -0,0 +1,238 @@
+/**app.wxss**/
+page {
+ --primary-color: #1890ff;
+ --danger-color: #ff4d4f;
+ --success-color: #52c41a;
+ --warning-color: #faad14;
+ --text-color: #333333;
+ --text-color-secondary: #666666;
+ --text-color-light: #999999;
+ --border-color: #e8e8e8;
+ --background-color: #f8f8f8;
+ --border-radius: 8rpx;
+ --safe-area-bottom: env(safe-area-inset-bottom);
+
+ font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica,
+ Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei',
+ sans-serif;
+ font-size: 28rpx;
+ line-height: 1.5;
+ color: var(--text-color);
+ background-color: var(--background-color);
+}
+
+/* 清除默认样式 */
+button {
+ padding: 0;
+ margin: 0;
+ background: none;
+ border: none;
+ text-align: left;
+ line-height: inherit;
+ overflow: visible;
+}
+
+button::after {
+ border: none;
+}
+
+/* 通用样式类 */
+.container {
+ min-height: 100vh;
+ box-sizing: border-box;
+}
+
+.safe-area-bottom {
+ padding-bottom: var(--safe-area-bottom);
+}
+
+.flex-center {
+ display: flex;
+ align-items: center;
+ justify-content: center;
+}
+
+.flex-between {
+ display: flex;
+ align-items: center;
+ justify-content: space-between;
+}
+
+.flex-column {
+ display: flex;
+ flex-direction: column;
+}
+
+.text-primary {
+ color: var(--primary-color);
+}
+
+.text-danger {
+ color: var(--danger-color);
+}
+
+.text-success {
+ color: var(--success-color);
+}
+
+.text-warning {
+ color: var(--warning-color);
+}
+
+.text-secondary {
+ color: var(--text-color-secondary);
+}
+
+.text-light {
+ color: var(--text-color-light);
+}
+
+.text-center {
+ text-align: center;
+}
+
+.text-left {
+ text-align: left;
+}
+
+.text-right {
+ text-align: right;
+}
+
+.text-bold {
+ font-weight: bold;
+}
+
+.text-ellipsis {
+ overflow: hidden;
+ text-overflow: ellipsis;
+ white-space: nowrap;
+}
+
+.bg-white {
+ background-color: #ffffff;
+}
+
+.bg-primary {
+ background-color: var(--primary-color);
+}
+
+.bg-danger {
+ background-color: var(--danger-color);
+}
+
+.bg-success {
+ background-color: var(--success-color);
+}
+
+.bg-warning {
+ background-color: var(--warning-color);
+}
+
+.shadow {
+ box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
+}
+
+.rounded {
+ border-radius: var(--border-radius);
+}
+
+.border {
+ border: 2rpx solid var(--border-color);
+}
+
+.border-top {
+ border-top: 2rpx solid var(--border-color);
+}
+
+.border-bottom {
+ border-bottom: 2rpx solid var(--border-color);
+}
+
+/* 动画类 */
+.fade-in {
+ animation: fadeIn 0.3s ease-in-out;
+}
+
+.fade-out {
+ animation: fadeOut 0.3s ease-in-out;
+}
+
+@keyframes fadeIn {
+ from {
+ opacity: 0;
+ }
+ to {
+ opacity: 1;
+ }
+}
+
+@keyframes fadeOut {
+ from {
+ opacity: 1;
+ }
+ to {
+ opacity: 0;
+ }
+}
+
+/* 间距类 */
+.m-0 { margin: 0; }
+.m-1 { margin: 10rpx; }
+.m-2 { margin: 20rpx; }
+.m-3 { margin: 30rpx; }
+.m-4 { margin: 40rpx; }
+
+.mt-0 { margin-top: 0; }
+.mt-1 { margin-top: 10rpx; }
+.mt-2 { margin-top: 20rpx; }
+.mt-3 { margin-top: 30rpx; }
+.mt-4 { margin-top: 40rpx; }
+
+.mb-0 { margin-bottom: 0; }
+.mb-1 { margin-bottom: 10rpx; }
+.mb-2 { margin-bottom: 20rpx; }
+.mb-3 { margin-bottom: 30rpx; }
+.mb-4 { margin-bottom: 40rpx; }
+
+.ml-0 { margin-left: 0; }
+.ml-1 { margin-left: 10rpx; }
+.ml-2 { margin-left: 20rpx; }
+.ml-3 { margin-left: 30rpx; }
+.ml-4 { margin-left: 40rpx; }
+
+.mr-0 { margin-right: 0; }
+.mr-1 { margin-right: 10rpx; }
+.mr-2 { margin-right: 20rpx; }
+.mr-3 { margin-right: 30rpx; }
+.mr-4 { margin-right: 40rpx; }
+
+.p-0 { padding: 0; }
+.p-1 { padding: 10rpx; }
+.p-2 { padding: 20rpx; }
+.p-3 { padding: 30rpx; }
+.p-4 { padding: 40rpx; }
+
+.pt-0 { padding-top: 0; }
+.pt-1 { padding-top: 10rpx; }
+.pt-2 { padding-top: 20rpx; }
+.pt-3 { padding-top: 30rpx; }
+.pt-4 { padding-top: 40rpx; }
+
+.pb-0 { padding-bottom: 0; }
+.pb-1 { padding-bottom: 10rpx; }
+.pb-2 { padding-bottom: 20rpx; }
+.pb-3 { padding-bottom: 30rpx; }
+.pb-4 { padding-bottom: 40rpx; }
+
+.pl-0 { padding-left: 0; }
+.pl-1 { padding-left: 10rpx; }
+.pl-2 { padding-left: 20rpx; }
+.pl-3 { padding-left: 30rpx; }
+.pl-4 { padding-left: 40rpx; }
+
+.pr-0 { padding-right: 0; }
+.pr-1 { padding-right: 10rpx; }
+.pr-2 { padding-right: 20rpx; }
+.pr-3 { padding-right: 30rpx; }
+.pr-4 { padding-right: 40rpx; }
\ No newline at end of file
diff --git a/miniprogram-example/pages/login/login.js b/miniprogram-example/pages/login/login.js
new file mode 100644
index 0000000..e4ad173
--- /dev/null
+++ b/miniprogram-example/pages/login/login.js
@@ -0,0 +1,48 @@
+// login.js
+import { wxLogin } from '../../services/auth';
+
+Page({
+ data: {
+ loading: false
+ },
+
+ onLoad() {
+ // 页面加载时检查是否已经登录
+ const token = wx.getStorageSync('token');
+ if (token) {
+ this.redirectToHome();
+ }
+ },
+
+ // 处理登录按钮点击
+ handleLogin() {
+ if (this.data.loading) return;
+
+ this.setData({ loading: true });
+
+ wxLogin()
+ .then(() => {
+ wx.showToast({
+ title: '登录成功',
+ icon: 'success'
+ });
+ this.redirectToHome();
+ })
+ .catch(err => {
+ wx.showToast({
+ title: err.message || '登录失败',
+ icon: 'none'
+ });
+ })
+ .finally(() => {
+ this.setData({ loading: false });
+ });
+ },
+
+ // 跳转到首页
+ redirectToHome() {
+ wx.reLaunch({
+ url: '/pages/otp-list/index'
+ });
+ }
+});
\ No newline at end of file
diff --git a/miniprogram-example/pages/login/login.json b/miniprogram-example/pages/login/login.json
new file mode 100644
index 0000000..8835af0
--- /dev/null
+++ b/miniprogram-example/pages/login/login.json
@@ -0,0 +1,3 @@
+{
+ "usingComponents": {}
+}
\ No newline at end of file
diff --git a/miniprogram-example/pages/login/login.wxml b/miniprogram-example/pages/login/login.wxml
new file mode 100644
index 0000000..d73a711
--- /dev/null
+++ b/miniprogram-example/pages/login/login.wxml
@@ -0,0 +1,30 @@
+
+
+
+
+ OTPM 小程序
+
+
+
+ 欢迎使用 OTPM
+ 一次性密码管理工具
+
+
+
+
+ 登录即表示您同意
+ 《隐私政策》
+
+
+
\ No newline at end of file
diff --git a/miniprogram-example/pages/login/login.wxss b/miniprogram-example/pages/login/login.wxss
new file mode 100644
index 0000000..26cece6
--- /dev/null
+++ b/miniprogram-example/pages/login/login.wxss
@@ -0,0 +1,97 @@
+/* login.wxss */
+.container {
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: space-between;
+ height: 100vh;
+ padding: 60rpx 40rpx;
+ box-sizing: border-box;
+ background-color: #f8f8f8;
+}
+
+.logo-container {
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ margin-top: 80rpx;
+}
+
+.logo {
+ width: 180rpx;
+ height: 180rpx;
+ margin-bottom: 20rpx;
+}
+
+.app-name {
+ font-size: 36rpx;
+ font-weight: bold;
+ color: #333;
+}
+
+.login-container {
+ width: 100%;
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ margin-bottom: 100rpx;
+}
+
+.login-title {
+ font-size: 48rpx;
+ font-weight: bold;
+ color: #333;
+ margin-bottom: 20rpx;
+}
+
+.login-subtitle {
+ font-size: 28rpx;
+ color: #666;
+ margin-bottom: 80rpx;
+}
+
+.login-button {
+ width: 80%;
+ height: 88rpx;
+ border-radius: 44rpx;
+ font-size: 32rpx;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+}
+
+.login-button.loading {
+ background-color: #8cc4ff;
+}
+
+.loading-container {
+ display: flex;
+ align-items: center;
+ justify-content: center;
+}
+
+.loading-icon {
+ width: 36rpx;
+ height: 36rpx;
+ margin-right: 10rpx;
+ border: 4rpx solid #ffffff;
+ border-radius: 50%;
+ border-top-color: transparent;
+ animation: spin 1s linear infinite;
+}
+
+@keyframes spin {
+ 0% { transform: rotate(0deg); }
+ 100% { transform: rotate(360deg); }
+}
+
+.privacy-policy {
+ margin-top: 40rpx;
+ font-size: 24rpx;
+ color: #999;
+}
+
+.policy-link {
+ color: #1890ff;
+ display: inline;
+}
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-add/index.js b/miniprogram-example/pages/otp-add/index.js
new file mode 100644
index 0000000..e89c623
--- /dev/null
+++ b/miniprogram-example/pages/otp-add/index.js
@@ -0,0 +1,169 @@
+// otp-add/index.js
+import { createOTP } from '../../services/otp';
+
+Page({
+ data: {
+ form: {
+ name: '',
+ issuer: '',
+ secret: '',
+ algorithm: 'SHA1',
+ digits: 6,
+ period: 30
+ },
+ algorithms: ['SHA1', 'SHA256', 'SHA512'],
+ digitOptions: [6, 8],
+ periodOptions: [30, 60],
+ submitting: false,
+ scanMode: false
+ },
+
+ // 处理输入变化
+ handleInputChange(e) {
+ const { field } = e.currentTarget.dataset;
+ const { value } = e.detail;
+
+ this.setData({
+ [`form.${field}`]: value
+ });
+ },
+
+ // 处理选择器变化
+ handlePickerChange(e) {
+ const { field } = e.currentTarget.dataset;
+ const { value } = e.detail;
+
+ const options = this.data[`${field}Options`] || this.data[field];
+ const selectedValue = options[value];
+
+ this.setData({
+ [`form.${field}`]: selectedValue
+ });
+ },
+
+ // 扫描二维码
+ handleScanQRCode() {
+ this.setData({ scanMode: true });
+
+ wx.scanCode({
+ scanType: ['qrCode'],
+ success: (res) => {
+ try {
+ // 解析otpauth://协议的URL
+ const url = res.result;
+ if (url.startsWith('otpauth://totp/')) {
+ const parsedUrl = new URL(url);
+ const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠
+
+ // 解析路径中的issuer和name
+ let issuer = '';
+ let name = path;
+
+ if (path.includes(':')) {
+ const parts = path.split(':');
+ issuer = parts[0];
+ name = parts[1];
+ }
+
+ // 从查询参数中获取其他信息
+ const secret = parsedUrl.searchParams.get('secret') || '';
+ const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1';
+ const digits = parseInt(parsedUrl.searchParams.get('digits') || '6');
+ const period = parseInt(parsedUrl.searchParams.get('period') || '30');
+
+ // 如果查询参数中有issuer,优先使用
+ if (parsedUrl.searchParams.get('issuer')) {
+ issuer = parsedUrl.searchParams.get('issuer');
+ }
+
+ this.setData({
+ form: {
+ name,
+ issuer,
+ secret,
+ algorithm,
+ digits,
+ period
+ }
+ });
+
+ wx.showToast({
+ title: '二维码解析成功',
+ icon: 'success'
+ });
+ } else {
+ wx.showToast({
+ title: '不支持的二维码格式',
+ icon: 'none'
+ });
+ }
+ } catch (err) {
+ wx.showToast({
+ title: '二维码解析失败',
+ icon: 'none'
+ });
+ }
+ },
+ fail: () => {
+ wx.showToast({
+ title: '扫描取消',
+ icon: 'none'
+ });
+ },
+ complete: () => {
+ this.setData({ scanMode: false });
+ }
+ });
+ },
+
+ // 提交表单
+ handleSubmit() {
+ const { form } = this.data;
+
+ // 表单验证
+ if (!form.name) {
+ wx.showToast({
+ title: '请输入名称',
+ icon: 'none'
+ });
+ return;
+ }
+
+ if (!form.secret) {
+ wx.showToast({
+ title: '请输入密钥',
+ icon: 'none'
+ });
+ return;
+ }
+
+ this.setData({ submitting: true });
+
+ createOTP(form)
+ .then(() => {
+ wx.showToast({
+ title: '添加成功',
+ icon: 'success'
+ });
+
+ // 返回上一页
+ setTimeout(() => {
+ wx.navigateBack();
+ }, 1500);
+ })
+ .catch(err => {
+ wx.showToast({
+ title: err.message || '添加失败',
+ icon: 'none'
+ });
+ })
+ .finally(() => {
+ this.setData({ submitting: false });
+ });
+ },
+
+ // 取消
+ handleCancel() {
+ wx.navigateBack();
+ }
+});
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-add/index.json b/miniprogram-example/pages/otp-add/index.json
new file mode 100644
index 0000000..8835af0
--- /dev/null
+++ b/miniprogram-example/pages/otp-add/index.json
@@ -0,0 +1,3 @@
+{
+ "usingComponents": {}
+}
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-add/index.wxml b/miniprogram-example/pages/otp-add/index.wxml
new file mode 100644
index 0000000..e5a2b01
--- /dev/null
+++ b/miniprogram-example/pages/otp-add/index.wxml
@@ -0,0 +1,119 @@
+
+
+
+
+
+
+ 名称 *
+
+
+
+
+ 发行方
+
+
+
+
+ 密钥 *
+
+
+
+ 🔍
+
+
+
+
+
+
+
+
+ 算法
+
+
+ {{form.algorithm}}
+ ▼
+
+
+
+
+
+
+ 位数
+
+
+ {{form.digits}}
+ ▼
+
+
+
+
+
+ 周期(秒)
+
+
+ {{form.period}}
+ ▼
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-add/index.wxss b/miniprogram-example/pages/otp-add/index.wxss
new file mode 100644
index 0000000..3906a97
--- /dev/null
+++ b/miniprogram-example/pages/otp-add/index.wxss
@@ -0,0 +1,176 @@
+/* otp-add/index.wxss */
+.container {
+ min-height: 100vh;
+ background-color: #f8f8f8;
+ padding: 0 0 40rpx 0;
+}
+
+.header {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ padding: 40rpx 32rpx;
+ background-color: #ffffff;
+ position: sticky;
+ top: 0;
+ z-index: 100;
+ box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
+}
+
+.title {
+ font-size: 36rpx;
+ font-weight: bold;
+ color: #333333;
+}
+
+.form-container {
+ background-color: #ffffff;
+ padding: 32rpx;
+ margin: 32rpx;
+ border-radius: 16rpx;
+ box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
+}
+
+.form-group {
+ margin-bottom: 32rpx;
+}
+
+.form-row {
+ display: flex;
+ justify-content: space-between;
+}
+
+.form-group.half {
+ width: 48%;
+}
+
+.form-label {
+ display: block;
+ font-size: 28rpx;
+ color: #666666;
+ margin-bottom: 12rpx;
+}
+
+.required {
+ color: #ff4d4f;
+}
+
+.form-input {
+ width: 100%;
+ height: 80rpx;
+ background-color: #f5f5f5;
+ border-radius: 8rpx;
+ padding: 0 24rpx;
+ font-size: 28rpx;
+ color: #333333;
+ box-sizing: border-box;
+}
+
+.secret-input-container {
+ position: relative;
+}
+
+.scan-button {
+ position: absolute;
+ right: 20rpx;
+ top: 50%;
+ transform: translateY(-50%);
+ width: 60rpx;
+ height: 60rpx;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+}
+
+.scan-icon {
+ font-size: 40rpx;
+ color: #1890ff;
+}
+
+.scanning-indicator {
+ position: absolute;
+ right: 20rpx;
+ top: 50%;
+ transform: translateY(-50%);
+ width: 40rpx;
+ height: 40rpx;
+}
+
+.scanning-spinner {
+ width: 40rpx;
+ height: 40rpx;
+ border: 4rpx solid #f3f3f3;
+ border-top: 4rpx solid #1890ff;
+ border-radius: 50%;
+ animation: spin 1s linear infinite;
+}
+
+@keyframes spin {
+ 0% { transform: rotate(0deg); }
+ 100% { transform: rotate(360deg); }
+}
+
+.picker-view {
+ width: 100%;
+ height: 80rpx;
+ background-color: #f5f5f5;
+ border-radius: 8rpx;
+ padding: 0 24rpx;
+ font-size: 28rpx;
+ color: #333333;
+ display: flex;
+ align-items: center;
+ justify-content: space-between;
+ box-sizing: border-box;
+}
+
+.picker-arrow {
+ font-size: 24rpx;
+ color: #999999;
+}
+
+.button-group {
+ display: flex;
+ justify-content: space-between;
+ padding: 32rpx;
+}
+
+.cancel-button, .submit-button {
+ width: 48%;
+ height: 88rpx;
+ border-radius: 44rpx;
+ font-size: 32rpx;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+}
+
+.cancel-button {
+ background-color: #f5f5f5;
+ color: #666666;
+}
+
+.submit-button {
+ background-color: #1890ff;
+ color: #ffffff;
+}
+
+.submit-button.loading {
+ background-color: #8cc4ff;
+}
+
+.loading-container {
+ display: flex;
+ align-items: center;
+ justify-content: center;
+}
+
+.loading-icon {
+ width: 36rpx;
+ height: 36rpx;
+ margin-right: 10rpx;
+ border: 4rpx solid #ffffff;
+ border-radius: 50%;
+ border-top-color: transparent;
+ animation: spin 1s linear infinite;
+}
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-list/index.js b/miniprogram-example/pages/otp-list/index.js
new file mode 100644
index 0000000..9ed6fa1
--- /dev/null
+++ b/miniprogram-example/pages/otp-list/index.js
@@ -0,0 +1,213 @@
+// otp-list/index.js
+import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp';
+import { checkLoginStatus } from '../../services/auth';
+
+Page({
+ data: {
+ otpList: [],
+ loading: true,
+ refreshing: false
+ },
+
+ onLoad() {
+ this.checkLogin();
+ },
+
+ onShow() {
+ // 每次页面显示时刷新OTP列表
+ if (!this.data.loading) {
+ this.fetchOTPList();
+ }
+ },
+
+ // 下拉刷新
+ onPullDownRefresh() {
+ this.setData({ refreshing: true });
+ this.fetchOTPList().finally(() => {
+ wx.stopPullDownRefresh();
+ this.setData({ refreshing: false });
+ });
+ },
+
+ // 检查登录状态
+ checkLogin() {
+ checkLoginStatus().then(isLoggedIn => {
+ if (isLoggedIn) {
+ this.fetchOTPList();
+ } else {
+ wx.redirectTo({
+ url: '/pages/login/login'
+ });
+ }
+ });
+ },
+
+ // 获取OTP列表
+ fetchOTPList() {
+ this.setData({ loading: true });
+ return getOTPList()
+ .then(res => {
+ if (res.data && Array.isArray(res.data)) {
+ this.setData({
+ otpList: res.data,
+ loading: false
+ });
+
+ // 获取每个OTP的当前验证码
+ this.refreshOTPCodes();
+ }
+ })
+ .catch(err => {
+ wx.showToast({
+ title: '获取OTP列表失败',
+ icon: 'none'
+ });
+ this.setData({ loading: false });
+ });
+ },
+
+ // 刷新所有OTP的验证码
+ refreshOTPCodes() {
+ const { otpList } = this.data;
+
+ // 为每个OTP获取当前验证码
+ const promises = otpList.map(otp => {
+ return getOTPCode(otp.id)
+ .then(res => {
+ if (res.data && res.data.code) {
+ return {
+ id: otp.id,
+ code: res.data.code,
+ expiresIn: res.data.expires_in || 30
+ };
+ }
+ return null;
+ })
+ .catch(() => null);
+ });
+
+ Promise.all(promises).then(results => {
+ const updatedList = [...this.data.otpList];
+
+ results.forEach(result => {
+ if (result) {
+ const index = updatedList.findIndex(otp => otp.id === result.id);
+ if (index !== -1) {
+ updatedList[index] = {
+ ...updatedList[index],
+ currentCode: result.code,
+ expiresIn: result.expiresIn
+ };
+ }
+ }
+ });
+
+ this.setData({ otpList: updatedList });
+
+ // 设置定时器,每秒更新倒计时
+ this.startCountdown();
+ });
+ },
+
+ // 开始倒计时
+ startCountdown() {
+ // 清除之前的定时器
+ if (this.countdownTimer) {
+ clearInterval(this.countdownTimer);
+ }
+
+ // 创建新的定时器,每秒更新一次
+ this.countdownTimer = setInterval(() => {
+ const { otpList } = this.data;
+ let needRefresh = false;
+
+ const updatedList = otpList.map(otp => {
+ if (!otp.countdown) {
+ otp.countdown = otp.expiresIn || 30;
+ }
+
+ otp.countdown -= 1;
+
+ // 如果倒计时结束,标记需要刷新
+ if (otp.countdown <= 0) {
+ needRefresh = true;
+ }
+
+ return otp;
+ });
+
+ this.setData({ otpList: updatedList });
+
+ // 如果有OTP需要刷新,重新获取验证码
+ if (needRefresh) {
+ this.refreshOTPCodes();
+ }
+ }, 1000);
+ },
+
+ // 添加新的OTP
+ handleAddOTP() {
+ wx.navigateTo({
+ url: '/pages/otp-add/index'
+ });
+ },
+
+ // 编辑OTP
+ handleEditOTP(e) {
+ const { id } = e.currentTarget.dataset;
+ wx.navigateTo({
+ url: `/pages/otp-edit/index?id=${id}`
+ });
+ },
+
+ // 删除OTP
+ handleDeleteOTP(e) {
+ const { id, name } = e.currentTarget.dataset;
+
+ wx.showModal({
+ title: '确认删除',
+ content: `确定要删除 ${name} 吗?`,
+ confirmColor: '#ff4d4f',
+ success: (res) => {
+ if (res.confirm) {
+ deleteOTP(id)
+ .then(() => {
+ wx.showToast({
+ title: '删除成功',
+ icon: 'success'
+ });
+ this.fetchOTPList();
+ })
+ .catch(err => {
+ wx.showToast({
+ title: '删除失败',
+ icon: 'none'
+ });
+ });
+ }
+ }
+ });
+ },
+
+ // 复制验证码
+ handleCopyCode(e) {
+ const { code } = e.currentTarget.dataset;
+
+ wx.setClipboardData({
+ data: code,
+ success: () => {
+ wx.showToast({
+ title: '验证码已复制',
+ icon: 'success'
+ });
+ }
+ });
+ },
+
+ onUnload() {
+ // 页面卸载时清除定时器
+ if (this.countdownTimer) {
+ clearInterval(this.countdownTimer);
+ }
+ }
+});
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-list/index.json b/miniprogram-example/pages/otp-list/index.json
new file mode 100644
index 0000000..8835af0
--- /dev/null
+++ b/miniprogram-example/pages/otp-list/index.json
@@ -0,0 +1,3 @@
+{
+ "usingComponents": {}
+}
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-list/index.wxml b/miniprogram-example/pages/otp-list/index.wxml
new file mode 100644
index 0000000..27cb5ea
--- /dev/null
+++ b/miniprogram-example/pages/otp-list/index.wxml
@@ -0,0 +1,59 @@
+
+
+
+
+
+
+
+ 加载中...
+
+
+
+
+
+
+
+
+ {{item.name}}
+ {{item.issuer}}
+
+
+
+ {{item.currentCode || '******'}}
+ 点击复制
+
+
+
+
+ {{item.countdown || 0}}s
+
+
+
+
+
+ ✎
+
+
+ ✕
+
+
+
+
+
+
+
+
+ 暂无OTP,点击右上角添加
+
+
+
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-list/index.wxss b/miniprogram-example/pages/otp-list/index.wxss
new file mode 100644
index 0000000..40436d5
--- /dev/null
+++ b/miniprogram-example/pages/otp-list/index.wxss
@@ -0,0 +1,201 @@
+/* otp-list/index.wxss */
+.container {
+ min-height: 100vh;
+ background-color: #f8f8f8;
+ padding: 0 0 40rpx 0;
+}
+
+.header {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ padding: 40rpx 32rpx;
+ background-color: #ffffff;
+ position: sticky;
+ top: 0;
+ z-index: 100;
+ box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
+}
+
+.title {
+ font-size: 36rpx;
+ font-weight: bold;
+ color: #333333;
+}
+
+.add-button {
+ width: 64rpx;
+ height: 64rpx;
+ border-radius: 32rpx;
+ background-color: #1890ff;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3);
+}
+
+.add-icon {
+ color: #ffffff;
+ font-size: 40rpx;
+ line-height: 1;
+}
+
+/* 加载状态 */
+.loading-container {
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: center;
+ padding: 120rpx 0;
+}
+
+.loading-spinner {
+ width: 64rpx;
+ height: 64rpx;
+ border: 6rpx solid #f3f3f3;
+ border-top: 6rpx solid #1890ff;
+ border-radius: 50%;
+ animation: spin 1s linear infinite;
+}
+
+.loading-text {
+ margin-top: 20rpx;
+ font-size: 28rpx;
+ color: #999999;
+}
+
+@keyframes spin {
+ 0% { transform: rotate(0deg); }
+ 100% { transform: rotate(360deg); }
+}
+
+/* OTP列表 */
+.otp-list {
+ padding: 20rpx 32rpx;
+}
+
+.otp-item {
+ background-color: #ffffff;
+ border-radius: 16rpx;
+ padding: 32rpx;
+ margin-bottom: 20rpx;
+ display: flex;
+ justify-content: space-between;
+ box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
+}
+
+.otp-info {
+ flex: 1;
+ margin-right: 20rpx;
+}
+
+.otp-name-row {
+ display: flex;
+ align-items: center;
+ margin-bottom: 16rpx;
+}
+
+.otp-name {
+ font-size: 32rpx;
+ font-weight: bold;
+ color: #333333;
+ margin-right: 16rpx;
+}
+
+.otp-issuer {
+ font-size: 24rpx;
+ color: #666666;
+ background-color: #f5f5f5;
+ padding: 4rpx 12rpx;
+ border-radius: 8rpx;
+}
+
+.otp-code-row {
+ display: flex;
+ align-items: center;
+ margin-bottom: 20rpx;
+}
+
+.otp-code {
+ font-size: 44rpx;
+ font-family: monospace;
+ font-weight: bold;
+ color: #1890ff;
+ letter-spacing: 4rpx;
+ margin-right: 16rpx;
+}
+
+.copy-hint {
+ font-size: 24rpx;
+ color: #999999;
+}
+
+.otp-countdown {
+ position: relative;
+ width: 100%;
+}
+
+.countdown-text {
+ position: absolute;
+ right: 0;
+ top: -30rpx;
+ font-size: 24rpx;
+ color: #999999;
+}
+
+/* OTP操作按钮 */
+.otp-actions {
+ display: flex;
+ flex-direction: column;
+ justify-content: space-between;
+}
+
+.action-button {
+ width: 56rpx;
+ height: 56rpx;
+ border-radius: 28rpx;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ margin-bottom: 16rpx;
+}
+
+.action-button.edit {
+ background-color: #f0f7ff;
+}
+
+.action-button.delete {
+ background-color: #fff1f0;
+}
+
+.action-icon {
+ font-size: 32rpx;
+}
+
+.edit .action-icon {
+ color: #1890ff;
+}
+
+.delete .action-icon {
+ color: #ff4d4f;
+}
+
+/* 空状态 */
+.empty-state {
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: center;
+ padding: 120rpx 0;
+}
+
+.empty-image {
+ width: 240rpx;
+ height: 240rpx;
+ margin-bottom: 32rpx;
+}
+
+.empty-text {
+ font-size: 28rpx;
+ color: #999999;
+}
\ No newline at end of file
diff --git a/miniprogram-example/project.config.json b/miniprogram-example/project.config.json
new file mode 100644
index 0000000..3fb79ad
--- /dev/null
+++ b/miniprogram-example/project.config.json
@@ -0,0 +1,47 @@
+{
+ "description": "项目配置文件",
+ "packOptions": {
+ "ignore": [],
+ "include": []
+ },
+ "miniprogramRoot": "",
+ "compileType": "miniprogram",
+ "projectname": "OTPM",
+ "setting": {
+ "useCompilerPlugins": [
+ "sass"
+ ],
+ "babelSetting": {
+ "ignore": [],
+ "disablePlugins": [],
+ "outputPath": ""
+ },
+ "es6": true,
+ "enhance": true,
+ "minified": true,
+ "postcss": true,
+ "minifyWXSS": true,
+ "minifyWXML": true,
+ "uglifyFileName": true,
+ "packNpmManually": false,
+ "packNpmRelationList": [],
+ "ignoreUploadUnusedFiles": true,
+ "compileWorklet": false,
+ "uploadWithSourceMap": true,
+ "localPlugins": false,
+ "disableUseStrict": false,
+ "condition": false,
+ "swc": false,
+ "disableSWC": true
+ },
+ "simulatorType": "wechat",
+ "simulatorPluginLibVersion": {},
+ "condition": {},
+ "srcMiniprogramRoot": "",
+ "appid": "wxb6599459668b6b55",
+ "libVersion": "2.30.2",
+ "editorSetting": {
+ "tabIndent": "insertSpaces",
+ "tabSize": 2
+ }
+}
\ No newline at end of file
diff --git a/miniprogram-example/project.private.config.json b/miniprogram-example/project.private.config.json
new file mode 100644
index 0000000..6b2a738
--- /dev/null
+++ b/miniprogram-example/project.private.config.json
@@ -0,0 +1,23 @@
+{
+ "libVersion": "3.8.5",
+ "projectname": "OTPM",
+ "condition": {},
+ "setting": {
+ "urlCheck": true,
+ "coverView": false,
+ "lazyloadPlaceholderEnable": false,
+ "skylineRenderEnable": false,
+ "preloadBackgroundData": false,
+ "autoAudits": false,
+ "useApiHook": true,
+ "useApiHostProcess": true,
+ "showShadowRootInWxmlPanel": false,
+ "useStaticServer": false,
+ "useLanDebug": false,
+ "showES6CompileOption": false,
+ "compileHotReLoad": true,
+ "checkInvalidKey": true,
+ "ignoreDevUnusedFiles": true,
+ "bigPackageSizeSupport": false
+ }
+}
\ No newline at end of file
diff --git a/miniprogram-example/services/auth.js b/miniprogram-example/services/auth.js
new file mode 100644
index 0000000..4c95fa4
--- /dev/null
+++ b/miniprogram-example/services/auth.js
@@ -0,0 +1,84 @@
+// auth.js - 认证相关服务
+
+import request from '../utils/request';
+
+/**
+ * 微信登录
+ * 1. 调用wx.login获取code
+ * 2. 发送code到服务端换取token和openid
+ * 3. 保存token和openid到本地存储
+ */
+export const wxLogin = () => {
+ return new Promise((resolve, reject) => {
+ wx.login({
+ success: (res) => {
+ if (res.code) {
+ // 发送code到服务端
+ request({
+ url: '/login',
+ method: 'POST',
+ data: {
+ code: res.code
+ }
+ }).then(response => {
+ // 保存token和openid
+ if (response.data && response.data.token && response.data.openid) {
+ wx.setStorageSync('token', response.data.token);
+ wx.setStorageSync('openid', response.data.openid);
+ resolve(response.data);
+ } else {
+ reject(new Error('登录失败,服务器返回数据格式错误'));
+ }
+ }).catch(err => {
+ reject(err);
+ });
+ } else {
+ reject(new Error('登录失败,获取code失败: ' + res.errMsg));
+ }
+ },
+ fail: (err) => {
+ reject(new Error('微信登录失败: ' + err.errMsg));
+ }
+ });
+ });
+};
+
+/**
+ * 检查登录状态
+ * 1. 检查本地是否有token和openid
+ * 2. 如果有,验证token是否有效
+ * 3. 如果无效,清除本地存储并返回false
+ */
+export const checkLoginStatus = () => {
+ return new Promise((resolve, reject) => {
+ const token = wx.getStorageSync('token');
+ const openid = wx.getStorageSync('openid');
+
+ if (!token || !openid) {
+ resolve(false);
+ return;
+ }
+
+ // 验证token有效性
+ request({
+ url: '/verify-token',
+ method: 'POST'
+ }).then(() => {
+ resolve(true);
+ }).catch(() => {
+ // token无效,清除本地存储
+ wx.removeStorageSync('token');
+ wx.removeStorageSync('openid');
+ resolve(false);
+ });
+ });
+};
+
+/**
+ * 退出登录
+ */
+export const logout = () => {
+ wx.removeStorageSync('token');
+ wx.removeStorageSync('openid');
+ return Promise.resolve();
+};
\ No newline at end of file
diff --git a/miniprogram-example/services/otp.js b/miniprogram-example/services/otp.js
new file mode 100644
index 0000000..9da9925
--- /dev/null
+++ b/miniprogram-example/services/otp.js
@@ -0,0 +1,119 @@
+// otp.js - OTP相关服务
+
+import request from '../utils/request';
+
+/**
+ * 创建新的OTP
+ * @param {Object} params - 创建OTP的参数
+ * @param {string} params.name - OTP名称
+ * @param {string} params.issuer - 发行方
+ * @param {string} params.secret - 密钥
+ * @param {string} params.algorithm - 算法,默认为SHA1
+ * @param {number} params.digits - 位数,默认为6
+ * @param {number} params.period - 周期,默认为30秒
+ * @returns {Promise} - 返回创建结果
+ */
+export const createOTP = (params) => {
+ if (!params || !params.secret) {
+ return Promise.reject(new Error('缺少必要的参数: secret'));
+ }
+ return request({
+ url: '/otp',
+ method: 'POST',
+ data: {
+ name: params.name || '',
+ issuer: params.issuer || '',
+ secret: params.secret,
+ algorithm: params.algorithm || 'SHA1',
+ digits: params.digits || 6,
+ period: params.period || 30
+ }
+ }).catch(err => {
+ console.error('创建OTP失败:', err);
+ throw new Error('创建OTP失败: ' + (err.message || '未知错误'));
+ });
+};
+
+/**
+ * 获取用户所有OTP列表
+ * @returns {Promise} - 返回OTP列表
+ */
+export const getOTPList = () => {
+ return request({
+ url: '/otp',
+ method: 'GET'
+ }).catch(err => {
+ console.error('获取OTP列表失败:', err);
+ throw new Error('获取OTP列表失败: ' + (err.message || '未知错误'));
+ });
+};
+
+/**
+ * 获取指定OTP的当前验证码
+ * @param {string} id - OTP的ID
+ * @returns {Promise} - 返回当前验证码
+ */
+export const getOTPCode = (id) => {
+ if (!id) {
+ return Promise.reject(new Error('缺少必要的参数: id'));
+ }
+ return request({
+ url: `/otp/${id}/code`,
+ method: 'GET'
+ }).catch(err => {
+ console.error('获取OTP代码失败:', err);
+ throw new Error('获取OTP代码失败: ' + (err.message || '未知错误'));
+ });
+};
+
+/**
+ * 验证OTP
+ * @param {string} id - OTP的ID
+ * @param {string} code - 用户输入的验证码
+ * @returns {Promise} - 返回验证结果
+ */
+export const verifyOTP = (id, code) => {
+ if (!id || !code) {
+ return Promise.reject(new Error('缺少必要的参数: id或code'));
+ }
+ return request({
+ url: `/otp/${id}/verify`,
+ method: 'POST',
+ data: { code }
+ }).catch(err => {
+ console.error('验证OTP失败:', err);
+ throw new Error('验证OTP失败: ' + (err.message || '未知错误'));
+ });
+};
+
+/**
+ * 更新OTP信息
+ * @param {string} id - OTP的ID
+ * @param {Object} params - 更新的参数
+ * @returns {Promise} - 返回更新结果
+ */
+export const updateOTP = (id, params) => {
+ if (!id || !params) {
+ return Promise.reject(new Error('缺少必要的参数: id或params'));
+ }
+ return request({
+ url: `/otp/${id}`,
+ method: 'PUT',
+ data: params
+ }).catch(err => {
+ console.error('更新OTP失败:', err);
+ throw new Error('更新OTP失败: ' + (err.message || '未知错误'));
+ });
+};
+
+/**
+ * 删除OTP
+ * @param {string} id - OTP的ID
+ * @returns {Promise} - 返回删除结果
+ */
+export const deleteOTP = (id) => {
+ return request({
+ url: `/otp/${id}`,
+ method: 'DELETE'
+ });
+};
\ No newline at end of file
diff --git a/miniprogram-example/utils/request.js b/miniprogram-example/utils/request.js
new file mode 100644
index 0000000..9ea7e2e
--- /dev/null
+++ b/miniprogram-example/utils/request.js
@@ -0,0 +1,58 @@
+// request.js - 网络请求工具类
+
+const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名
+
+// 请求拦截器
+const request = (options) => {
+ return new Promise((resolve, reject) => {
+ const token = wx.getStorageSync('token');
+ const header = {
+ 'Content-Type': 'application/json',
+ ...options.header
+ };
+
+ // 如果有token,添加到请求头
+ if (token) {
+ header['Authorization'] = `Bearer ${token}`;
+ }
+
+ wx.request({
+ url: `${BASE_URL}${options.url}`,
+ method: options.method || 'GET',
+ data: options.data,
+ header: header,
+ success: (res) => {
+ // 处理业务错误
+ if (res.data.code !== 0) {
+ // token过期,直接清除并跳转登录
+ if (res.statusCode === 401) {
+ wx.removeStorageSync('token');
+ wx.removeStorageSync('openid');
+ reject(new Error('登录已过期,请重新登录'));
+ return;
+ }
+ reject(new Error(res.data.message || '请求失败'));
+ return;
+ }
+ resolve(res.data);
+ },
+ fail: reject
+ });
+ });
+};
+
+// 刷新token
+const refreshToken = () => {
+ return request({
+ url: '/refresh-token',
+ method: 'POST'
+ }).then(res => {
+ if (res.data && res.data.token) {
+ wx.setStorageSync('token', res.data.token);
+ return res.data.token;
+ }
+ throw new Error('Failed to refresh token');
+ });
+};
+
+export default request;
\ No newline at end of file
diff --git a/models/otp.go b/models/otp.go
new file mode 100644
index 0000000..8eaab88
--- /dev/null
+++ b/models/otp.go
@@ -0,0 +1,66 @@
+package models
+
+import (
+ "context"
+ "time"
+)
+
+// OTP represents a TOTP configuration
+type OTP struct {
+ ID int64 `json:"id" db:"id"`
+ UserID string `json:"user_id" db:"user_id" validate:"required"`
+ OpenID string `json:"openid" db:"openid" validate:"required"`
+ Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"`
+ Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"`
+ Secret string `json:"secret" db:"secret" validate:"required,otpsecret"`
+ Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
+ Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"`
+ Period int `json:"period" db:"period" validate:"required,min=30,max=60"`
+ CreatedAt time.Time `json:"created_at" db:"created_at"`
+ UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
+}
+
+// OTPParams represents common OTP parameters used in creation and update
+type OTPParams struct {
+ Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
+ Issuer string `json:"issuer" validate:"omitempty,issuer"`
+ Secret string `json:"secret" validate:"required,otpsecret"`
+ Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
+ Digits int `json:"digits" validate:"omitempty,min=6,max=8"`
+ Period int `json:"period" validate:"omitempty,min=30,max=60"`
+}
+
+// OTPRepository handles OTP data storage
+type OTPRepository struct {
+ // Add your database connection or ORM here
+}
+
+// Create creates a new OTP record
+func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
+ // Implement database creation logic
+ return nil
+}
+
+// FindByID finds an OTP by ID and user ID
+func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
+ // Implement database lookup logic
+ return nil, nil
+}
+
+// FindAllByUserID finds all OTPs for a user
+func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
+ // Implement database query logic
+ return nil, nil
+}
+
+// Update updates an existing OTP record
+func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
+ // Implement database update logic
+ return nil
+}
+
+// Delete deletes an OTP record
+func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
+ // Implement database deletion logic
+ return nil
+}
diff --git a/models/user.go b/models/user.go
new file mode 100644
index 0000000..f12399b
--- /dev/null
+++ b/models/user.go
@@ -0,0 +1,114 @@
+package models
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "time"
+
+ "github.com/jmoiron/sqlx"
+)
+
+// User represents a user in the system
+type User struct {
+ ID string `db:"id" json:"id"`
+ OpenID string `db:"openid" json:"openid"`
+ SessionKey string `db:"session_key" json:"-"`
+ CreatedAt time.Time `db:"created_at" json:"created_at"`
+ UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
+}
+
+// UserRepository handles user data operations
+type UserRepository struct {
+ db *sqlx.DB
+}
+
+// NewUserRepository creates a new UserRepository
+func NewUserRepository(db *sqlx.DB) *UserRepository {
+ return &UserRepository{db: db}
+}
+
+// FindByID finds a user by ID
+func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) {
+ var user User
+ query := `SELECT * FROM users WHERE id = ?`
+ err := r.db.GetContext(ctx, &user, query, id)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, fmt.Errorf("user not found: %w", err)
+ }
+ return nil, fmt.Errorf("failed to find user: %w", err)
+ }
+ return &user, nil
+}
+
+// FindByOpenID finds a user by OpenID
+func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) {
+ var user User
+ query := `SELECT * FROM users WHERE openid = ?`
+ err := r.db.GetContext(ctx, &user, query, openID)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil // User not found, but not an error
+ }
+ return nil, fmt.Errorf("failed to find user: %w", err)
+ }
+ return &user, nil
+}
+
+// Create creates a new user
+func (r *UserRepository) Create(ctx context.Context, user *User) error {
+ query := `
+ INSERT INTO users (id, openid, session_key, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?)
+ `
+ now := time.Now()
+ user.CreatedAt = now
+ user.UpdatedAt = now
+
+ _, err := r.db.ExecContext(
+ ctx,
+ query,
+ user.ID,
+ user.OpenID,
+ user.SessionKey,
+ user.CreatedAt,
+ user.UpdatedAt,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to create user: %w", err)
+ }
+ return nil
+}
+
+// Update updates an existing user
+func (r *UserRepository) Update(ctx context.Context, user *User) error {
+ query := `
+ UPDATE users
+ SET session_key = ?, updated_at = ?
+ WHERE id = ?
+ `
+ user.UpdatedAt = time.Now()
+
+ _, err := r.db.ExecContext(
+ ctx,
+ query,
+ user.SessionKey,
+ user.UpdatedAt,
+ user.ID,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to update user: %w", err)
+ }
+ return nil
+}
+
+// Delete deletes a user
+func (r *UserRepository) Delete(ctx context.Context, id string) error {
+ query := `DELETE FROM users WHERE id = ?`
+ _, err := r.db.ExecContext(ctx, query, id)
+ if err != nil {
+ return fmt.Errorf("failed to delete user: %w", err)
+ }
+ return nil
+}
diff --git a/security/security.go b/security/security.go
new file mode 100644
index 0000000..7b69b81
--- /dev/null
+++ b/security/security.go
@@ -0,0 +1,332 @@
+package security
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/subtle"
+ "encoding/base64"
+ "fmt"
+ "net/http"
+ "strings"
+ "time"
+
+ "golang.org/x/crypto/argon2"
+)
+
+// SecurityService provides security functionality
+type SecurityService struct {
+ config *Config
+}
+
+// Config represents security configuration
+type Config struct {
+ // CSRF protection
+ CSRFTokenLength int
+ CSRFTokenExpiry time.Duration
+ CSRFCookieName string
+ CSRFHeaderName string
+ CSRFCookieSecure bool
+ CSRFCookieHTTPOnly bool
+ CSRFCookieSameSite http.SameSite
+
+ // Rate limiting
+ RateLimitRequests int
+ RateLimitWindow time.Duration
+
+ // Password hashing
+ Argon2Time uint32
+ Argon2Memory uint32
+ Argon2Threads uint8
+ Argon2KeyLen uint32
+ Argon2SaltLen uint32
+}
+
+// DefaultConfig returns the default security configuration
+func DefaultConfig() *Config {
+ return &Config{
+ // CSRF protection
+ CSRFTokenLength: 32,
+ CSRFTokenExpiry: 24 * time.Hour,
+ CSRFCookieName: "csrf_token",
+ CSRFHeaderName: "X-CSRF-Token",
+ CSRFCookieSecure: true,
+ CSRFCookieHTTPOnly: true,
+ CSRFCookieSameSite: http.SameSiteStrictMode,
+
+ // Rate limiting
+ RateLimitRequests: 100,
+ RateLimitWindow: time.Minute,
+
+ // Password hashing
+ Argon2Time: 1,
+ Argon2Memory: 64 * 1024,
+ Argon2Threads: 4,
+ Argon2KeyLen: 32,
+ Argon2SaltLen: 16,
+ }
+}
+
+// NewSecurityService creates a new SecurityService
+func NewSecurityService(config *Config) *SecurityService {
+ if config == nil {
+ config = DefaultConfig()
+ }
+ return &SecurityService{
+ config: config,
+ }
+}
+
+// GenerateCSRFToken generates a CSRF token
+func (s *SecurityService) GenerateCSRFToken() (string, error) {
+ // Generate random bytes
+ bytes := make([]byte, s.config.CSRFTokenLength)
+ if _, err := rand.Read(bytes); err != nil {
+ return "", fmt.Errorf("failed to generate random bytes: %w", err)
+ }
+
+ // Encode as base64
+ token := base64.StdEncoding.EncodeToString(bytes)
+ return token, nil
+}
+
+// SetCSRFCookie sets a CSRF cookie
+func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
+ http.SetCookie(w, &http.Cookie{
+ Name: s.config.CSRFCookieName,
+ Value: token,
+ Path: "/",
+ Expires: time.Now().Add(s.config.CSRFTokenExpiry),
+ Secure: s.config.CSRFCookieSecure,
+ HttpOnly: s.config.CSRFCookieHTTPOnly,
+ SameSite: s.config.CSRFCookieSameSite,
+ })
+}
+
+// ValidateCSRFToken validates a CSRF token
+func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
+ // Get token from cookie
+ cookie, err := r.Cookie(s.config.CSRFCookieName)
+ if err != nil {
+ return false
+ }
+ cookieToken := cookie.Value
+
+ // Get token from header
+ headerToken := r.Header.Get(s.config.CSRFHeaderName)
+
+ // Compare tokens
+ return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1
+}
+
+// CSRFMiddleware creates a middleware that validates CSRF tokens
+func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Skip for GET, HEAD, OPTIONS, TRACE
+ if r.Method == http.MethodGet ||
+ r.Method == http.MethodHead ||
+ r.Method == http.MethodOptions ||
+ r.Method == http.MethodTrace {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ // Validate CSRF token
+ if !s.ValidateCSRFToken(r) {
+ http.Error(w, "Invalid CSRF token", http.StatusForbidden)
+ return
+ }
+
+ next.ServeHTTP(w, r)
+ })
+}
+
+// RateLimiter represents a rate limiter
+type RateLimiter struct {
+ requests map[string][]time.Time
+ config *Config
+}
+
+// NewRateLimiter creates a new RateLimiter
+func NewRateLimiter(config *Config) *RateLimiter {
+ return &RateLimiter{
+ requests: make(map[string][]time.Time),
+ config: config,
+ }
+}
+
+// Allow checks if a request is allowed
+func (r *RateLimiter) Allow(key string) bool {
+ now := time.Now()
+ windowStart := now.Add(-r.config.RateLimitWindow)
+
+ // Get requests for key
+ requests := r.requests[key]
+
+ // Filter out old requests
+ var newRequests []time.Time
+ for _, t := range requests {
+ if t.After(windowStart) {
+ newRequests = append(newRequests, t)
+ }
+ }
+
+ // Check if rate limit is exceeded
+ if len(newRequests) >= r.config.RateLimitRequests {
+ return false
+ }
+
+ // Add current request
+ newRequests = append(newRequests, now)
+ r.requests[key] = newRequests
+
+ return true
+}
+
+// RateLimitMiddleware creates a middleware that limits request rate
+func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Get client IP
+ ip := getClientIP(r)
+
+ // Check if request is allowed
+ if !limiter.Allow(ip) {
+ http.Error(w, "Too many requests", http.StatusTooManyRequests)
+ return
+ }
+
+ next.ServeHTTP(w, r)
+ })
+ }
+}
+
+// getClientIP gets the client IP address
+func getClientIP(r *http.Request) string {
+ // Check X-Forwarded-For header
+ xForwardedFor := r.Header.Get("X-Forwarded-For")
+ if xForwardedFor != "" {
+ // X-Forwarded-For can contain multiple IPs, use the first one
+ ips := strings.Split(xForwardedFor, ",")
+ return strings.TrimSpace(ips[0])
+ }
+
+ // Check X-Real-IP header
+ xRealIP := r.Header.Get("X-Real-IP")
+ if xRealIP != "" {
+ return xRealIP
+ }
+
+ // Use RemoteAddr
+ return r.RemoteAddr
+}
+
+// HashPassword hashes a password using Argon2
+func (s *SecurityService) HashPassword(password string) (string, error) {
+ // Generate salt
+ salt := make([]byte, s.config.Argon2SaltLen)
+ if _, err := rand.Read(salt); err != nil {
+ return "", fmt.Errorf("failed to generate salt: %w", err)
+ }
+
+ // Hash password
+ hash := argon2.IDKey(
+ []byte(password),
+ salt,
+ s.config.Argon2Time,
+ s.config.Argon2Memory,
+ s.config.Argon2Threads,
+ s.config.Argon2KeyLen,
+ )
+
+ // Encode as base64
+ saltBase64 := base64.StdEncoding.EncodeToString(salt)
+ hashBase64 := base64.StdEncoding.EncodeToString(hash)
+
+ // Format as $argon2id$v=19$m=65536,t=1,p=4$$
+ return fmt.Sprintf(
+ "$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
+ s.config.Argon2Memory,
+ s.config.Argon2Time,
+ s.config.Argon2Threads,
+ saltBase64,
+ hashBase64,
+ ), nil
+}
+
+// VerifyPassword verifies a password against a hash
+func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
+ // Parse encoded hash
+ parts := strings.Split(encodedHash, "$")
+ if len(parts) != 6 {
+ return false, fmt.Errorf("invalid hash format")
+ }
+
+ // Extract parameters
+ if parts[1] != "argon2id" {
+ return false, fmt.Errorf("unsupported hash algorithm")
+ }
+
+ var memory uint32
+ var time uint32
+ var threads uint8
+ _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
+ if err != nil {
+ return false, fmt.Errorf("failed to parse hash parameters: %w", err)
+ }
+
+ // Decode salt and hash
+ salt, err := base64.StdEncoding.DecodeString(parts[4])
+ if err != nil {
+ return false, fmt.Errorf("failed to decode salt: %w", err)
+ }
+
+ hash, err := base64.StdEncoding.DecodeString(parts[5])
+ if err != nil {
+ return false, fmt.Errorf("failed to decode hash: %w", err)
+ }
+
+ // Hash password with same parameters
+ newHash := argon2.IDKey(
+ []byte(password),
+ salt,
+ time,
+ memory,
+ threads,
+ uint32(len(hash)),
+ )
+
+ // Compare hashes
+ return subtle.ConstantTimeCompare(hash, newHash) == 1, nil
+}
+
+// SecureHeadersMiddleware adds security headers to responses
+func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Add security headers
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.Header().Set("X-Frame-Options", "DENY")
+ w.Header().Set("X-XSS-Protection", "1; mode=block")
+ w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
+ w.Header().Set("Content-Security-Policy", "default-src 'self'")
+ w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
+
+ next.ServeHTTP(w, r)
+ })
+}
+
+// contextKey is a type for context keys
+type contextKey string
+
+// userIDKey is the key for user ID in context
+const userIDKey = contextKey("user_id")
+
+// WithUserID adds a user ID to the context
+func WithUserID(ctx context.Context, userID string) context.Context {
+ return context.WithValue(ctx, userIDKey, userID)
+}
+
+// GetUserID gets the user ID from the context
+func GetUserID(ctx context.Context) (string, bool) {
+ userID, ok := ctx.Value(userIDKey).(string)
+ return userID, ok
+}
diff --git a/server/server.go b/server/server.go
new file mode 100644
index 0000000..a10e990
--- /dev/null
+++ b/server/server.go
@@ -0,0 +1,190 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+ "os/signal"
+ "runtime"
+ "syscall"
+ "time"
+
+ "otpm/config"
+ "otpm/middleware"
+
+ "github.com/julienschmidt/httprouter"
+)
+
+// Server represents the HTTP server
+type Server struct {
+ server *http.Server
+ router *httprouter.Router
+ config *config.Config
+}
+
+// New creates a new server
+func New(cfg *config.Config) *Server {
+ router := httprouter.New()
+
+ server := &http.Server{
+ Addr: fmt.Sprintf(":%d", cfg.Server.Port),
+ Handler: router,
+ ReadTimeout: cfg.Server.ReadTimeout,
+ WriteTimeout: cfg.Server.WriteTimeout,
+ IdleTimeout: 120 * time.Second,
+ }
+
+ return &Server{
+ server: server,
+ router: router,
+ config: cfg,
+ }
+}
+
+// Start starts the server
+func (s *Server) Start() error {
+ // Apply global middleware in correct order with enhanced error handling
+ var handler http.Handler = s.router
+
+ // Logger should be first to capture all request details
+ handler = middleware.Logger(handler)
+
+ // CORS next to handle pre-flight requests
+ handler = middleware.CORS(handler)
+
+ // Then Timeout to enforce request deadlines
+ handler = middleware.Timeout(s.config.Server.Timeout)(handler)
+
+ // Recover should be outermost to catch any panics
+ handler = middleware.Recover(handler)
+
+ s.server.Handler = handler
+
+ // Log server configuration at startup
+ log.Printf("Server configuration:\n"+
+ "Address: %s\n"+
+ "Read Timeout: %v\n"+
+ "Write Timeout: %v\n"+
+ "Idle Timeout: %v\n"+
+ "Request Timeout: %v",
+ s.server.Addr,
+ s.server.ReadTimeout,
+ s.server.WriteTimeout,
+ s.server.IdleTimeout,
+ s.config.Server.Timeout,
+ )
+
+ // Start server in a goroutine
+ serverErr := make(chan error, 1)
+ go func() {
+ log.Printf("Server starting on %s", s.server.Addr)
+ if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ serverErr <- fmt.Errorf("server error: %w", err)
+ }
+ }()
+
+ // Wait for interrupt signal or server error
+ quit := make(chan os.Signal, 1)
+ signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
+
+ select {
+ case err := <-serverErr:
+ return err
+ case <-quit:
+ return s.Shutdown()
+ }
+}
+
+// Shutdown gracefully stops the server
+func (s *Server) Shutdown() error {
+ log.Println("Shutting down server...")
+
+ ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout)
+ defer cancel()
+
+ if err := s.server.Shutdown(ctx); err != nil {
+ return fmt.Errorf("graceful shutdown failed: %w", err)
+ }
+
+ log.Println("Server stopped gracefully")
+ return nil
+}
+
+// Router returns the router
+func (s *Server) Router() *httprouter.Router {
+ return s.router
+}
+
+// RegisterRoutes registers all routes
+func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) {
+ for pattern, handler := range routes {
+ s.router.Handle("GET", pattern, handler)
+ s.router.Handle("POST", pattern, handler)
+ s.router.Handle("PUT", pattern, handler)
+ s.router.Handle("DELETE", pattern, handler)
+ }
+}
+
+// RegisterAuthRoutes registers routes that require authentication
+func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) {
+ for pattern, handler := range routes {
+ // Apply authentication middleware
+ authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+ // Convert httprouter.Handle to http.HandlerFunc for middleware
+ wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Store params in request context
+ ctx := context.WithValue(r.Context(), "params", ps)
+ handler(w, r.WithContext(ctx), ps)
+ })
+
+ // Apply auth middleware
+ middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r)
+ }
+
+ s.router.Handle("GET", pattern, authHandler)
+ s.router.Handle("POST", pattern, authHandler)
+ s.router.Handle("PUT", pattern, authHandler)
+ s.router.Handle("DELETE", pattern, authHandler)
+ }
+}
+
+// RegisterHealthCheck registers an enhanced health check endpoint
+func (s *Server) RegisterHealthCheck() {
+ s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ response := map[string]interface{}{
+ "status": "ok",
+ "timestamp": time.Now().Format(time.RFC3339),
+ "version": "1.0.0", // Hardcoded version instead of from config
+ "system": map[string]interface{}{
+ "goroutines": runtime.NumGoroutine(),
+ "memory": getMemoryUsage(),
+ },
+ }
+
+ // Add database status if configured
+ if s.config.Database.DSN != "" {
+ dbStatus := "ok"
+ response["database"] = dbStatus
+ }
+
+ middleware.SuccessResponse(w, response)
+ })
+}
+
+// getMemoryUsage returns current memory usage in MB
+func getMemoryUsage() map[string]interface{} {
+ var m runtime.MemStats
+ runtime.ReadMemStats(&m)
+ return map[string]interface{}{
+ "alloc_mb": bToMb(m.Alloc),
+ "total_alloc_mb": bToMb(m.TotalAlloc),
+ "sys_mb": bToMb(m.Sys),
+ "num_gc": m.NumGC,
+ }
+}
+
+func bToMb(b uint64) float64 {
+ return float64(b) / 1024 / 1024
+}
diff --git a/services/auth.go b/services/auth.go
new file mode 100644
index 0000000..9ad2a85
--- /dev/null
+++ b/services/auth.go
@@ -0,0 +1,230 @@
+package services
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "time"
+
+ "github.com/golang-jwt/jwt"
+ "github.com/google/uuid"
+
+ "otpm/config"
+ "otpm/models"
+)
+
+// WeChatCode2SessionResponse represents the response from WeChat code2session API
+type WeChatCode2SessionResponse struct {
+ OpenID string `json:"openid"`
+ SessionKey string `json:"session_key"`
+ UnionID string `json:"unionid"`
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+}
+
+// AuthService handles authentication related operations
+type AuthService struct {
+ config *config.Config
+ userRepo *models.UserRepository
+ httpClient *http.Client
+}
+
+// NewAuthService creates a new AuthService
+func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
+ return &AuthService{
+ config: cfg,
+ userRepo: userRepo,
+ httpClient: &http.Client{
+ Timeout: 10 * time.Second,
+ },
+ }
+}
+
+// LoginWithWeChatCode handles WeChat login
+func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
+ start := time.Now()
+
+ // Get OpenID and SessionKey from WeChat
+ sessionInfo, err := s.getWeChatSession(code)
+ if err != nil {
+ log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
+ return "", fmt.Errorf("failed to get WeChat session: %w", err)
+ }
+ log.Printf("WeChat session obtained for code %s (took %v)",
+ maskCode(code), time.Since(start))
+
+ // Find or create user
+ user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
+ if err != nil {
+ log.Printf("User lookup failed for OpenID %s: %v",
+ maskOpenID(sessionInfo.OpenID), err)
+ return "", fmt.Errorf("failed to find user: %w", err)
+ }
+
+ if user == nil {
+ // Create new user
+ user = &models.User{
+ ID: uuid.New().String(),
+ OpenID: sessionInfo.OpenID,
+ SessionKey: sessionInfo.SessionKey,
+ }
+ if err := s.userRepo.Create(ctx, user); err != nil {
+ log.Printf("User creation failed for OpenID %s: %v",
+ maskOpenID(sessionInfo.OpenID), err)
+ return "", fmt.Errorf("failed to create user: %w", err)
+ }
+ log.Printf("New user created with ID %s for OpenID %s",
+ user.ID, maskOpenID(sessionInfo.OpenID))
+ } else {
+ // Update session key
+ user.SessionKey = sessionInfo.SessionKey
+ if err := s.userRepo.Update(ctx, user); err != nil {
+ log.Printf("User update failed for ID %s: %v", user.ID, err)
+ return "", fmt.Errorf("failed to update user: %w", err)
+ }
+ log.Printf("User %s session key updated", user.ID)
+ }
+
+ // Generate JWT token
+ token, err := s.generateToken(user)
+ if err != nil {
+ log.Printf("Token generation failed for user %s: %v", user.ID, err)
+ return "", fmt.Errorf("failed to generate token: %w", err)
+ }
+
+ log.Printf("WeChat login completed for user %s (total time %v)",
+ user.ID, time.Since(start))
+ return token, nil
+}
+
+// maskCode masks sensitive parts of WeChat code for logging
+func maskCode(code string) string {
+ if len(code) < 8 {
+ return "****"
+ }
+ return code[:2] + "****" + code[len(code)-2:]
+}
+
+// maskOpenID masks sensitive parts of OpenID for logging
+func maskOpenID(openID string) string {
+ if len(openID) < 8 {
+ return "****"
+ }
+ return openID[:2] + "****" + openID[len(openID)-2:]
+}
+
+// getWeChatSession calls WeChat's code2session API
+func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
+ url := fmt.Sprintf(
+ "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
+ s.config.WeChat.AppID,
+ s.config.WeChat.AppSecret,
+ code,
+ )
+
+ resp, err := s.httpClient.Get(url)
+ if err != nil {
+ return nil, fmt.Errorf("failed to call WeChat API: %w", err)
+ }
+ defer resp.Body.Close()
+
+ var result WeChatCode2SessionResponse
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
+ }
+
+ if result.ErrCode != 0 {
+ return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
+ }
+
+ return &result, nil
+}
+
+// generateToken generates a JWT token for a user
+func (s *AuthService) generateToken(user *models.User) (string, error) {
+ now := time.Now()
+ claims := jwt.MapClaims{
+ "user_id": user.ID,
+ "exp": now.Add(s.config.JWT.ExpireDelta).Unix(),
+ "iat": now.Unix(),
+ "iss": s.config.JWT.Issuer,
+ "aud": s.config.JWT.Audience,
+ "token_id": uuid.New().String(), // Unique token ID for tracking
+ }
+
+ // Use stronger signing method
+ token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
+ signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
+ if err != nil {
+ return "", fmt.Errorf("failed to sign token: %w", err)
+ }
+
+ log.Printf("Token generated for user %s (expires at %v)",
+ user.ID, now.Add(s.config.JWT.ExpireDelta))
+ return signedToken, nil
+}
+
+// ValidateToken validates a JWT token with additional checks
+func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
+ token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
+ // Verify signing method
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+ }
+ return []byte(s.config.JWT.Secret), nil
+ })
+
+ if err != nil {
+ if ve, ok := err.(*jwt.ValidationError); ok {
+ switch {
+ case ve.Errors&jwt.ValidationErrorMalformed != 0:
+ return nil, fmt.Errorf("malformed token")
+ case ve.Errors&jwt.ValidationErrorExpired != 0:
+ return nil, fmt.Errorf("token expired")
+ case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
+ return nil, fmt.Errorf("token not active yet")
+ default:
+ return nil, fmt.Errorf("token validation error: %w", err)
+ }
+ }
+ return nil, fmt.Errorf("failed to parse token: %w", err)
+ }
+
+ // Additional claims validation
+ if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
+ // Check issuer
+ if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
+ return nil, fmt.Errorf("invalid token issuer")
+ }
+ // Check audience
+ if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
+ return nil, fmt.Errorf("invalid token audience")
+ }
+ } else {
+ return nil, fmt.Errorf("invalid token claims")
+ }
+
+ return token, nil
+}
+
+// GetUserFromToken gets user information from a JWT token
+func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ return nil, fmt.Errorf("invalid token claims")
+ }
+
+ userID, ok := claims["user_id"].(string)
+ if !ok {
+ return nil, fmt.Errorf("user_id not found in token")
+ }
+
+ user, err := s.userRepo.FindByID(ctx, userID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to find user: %w", err)
+ }
+
+ return user, nil
+}
diff --git a/services/otp.go b/services/otp.go
new file mode 100644
index 0000000..c345bea
--- /dev/null
+++ b/services/otp.go
@@ -0,0 +1,358 @@
+package services
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/base32"
+ "encoding/binary"
+ "fmt"
+ "hash"
+ "log"
+ "strings"
+ "time"
+
+ "otpm/models"
+
+ "github.com/google/uuid"
+)
+
+// OTPService handles OTP related operations
+type OTPService struct {
+ otpRepo *models.OTPRepository
+}
+
+// NewOTPService creates a new OTPService
+func NewOTPService(otpRepo *models.OTPRepository) *OTPService {
+ return &OTPService{
+ otpRepo: otpRepo,
+ }
+}
+
+// CreateOTP creates a new OTP with performance monitoring and logging
+func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) {
+ start := time.Now()
+
+ // Validate input
+ if err := s.validateOTPInput(input); err != nil {
+ log.Printf("OTP validation failed for user %s: %v", userID, err)
+ return nil, err
+ }
+
+ // Clean and standardize secret
+ secret := cleanSecret(input.Secret)
+
+ // Set defaults for optional fields
+ algorithm := strings.ToUpper(input.Algorithm)
+ if algorithm == "" {
+ algorithm = "SHA1"
+ }
+
+ digits := input.Digits
+ if digits == 0 {
+ digits = 6
+ }
+
+ period := input.Period
+ if period == 0 {
+ period = 30
+ }
+
+ // Create OTP
+ otp := &models.OTP{
+ ID: uuid.New().String(),
+ UserID: userID,
+ Name: input.Name,
+ Issuer: input.Issuer,
+ Secret: secret,
+ Algorithm: algorithm,
+ Digits: digits,
+ Period: period,
+ }
+
+ if err := s.otpRepo.Create(ctx, otp); err != nil {
+ log.Printf("Failed to create OTP for user %s: %v", userID, err)
+ return nil, fmt.Errorf("failed to create OTP: %w", err)
+ }
+
+ // Log successful creation (without exposing secret)
+ log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)",
+ otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period)
+
+ return otp, nil
+}
+
+// GetOTP gets an OTP by ID
+func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) {
+ otp, err := s.otpRepo.FindByID(ctx, id, userID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get OTP: %w", err)
+ }
+ return otp, nil
+}
+
+// ListOTPs lists all OTPs for a user
+func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) {
+ otps, err := s.otpRepo.FindAllByUserID(ctx, userID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to list OTPs: %w", err)
+ }
+ return otps, nil
+}
+
+// UpdateOTP updates an OTP
+func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) {
+ // Get existing OTP
+ otp, err := s.otpRepo.FindByID(ctx, id, userID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get OTP: %w", err)
+ }
+
+ // Update fields
+ if input.Name != "" {
+ otp.Name = input.Name
+ }
+ if input.Issuer != "" {
+ otp.Issuer = input.Issuer
+ }
+ if input.Algorithm != "" {
+ otp.Algorithm = strings.ToUpper(input.Algorithm)
+ }
+ if input.Digits > 0 {
+ otp.Digits = input.Digits
+ }
+ if input.Period > 0 {
+ otp.Period = input.Period
+ }
+
+ // Validate updated OTP
+ if err := s.validateOTPInput(models.OTPParams{
+ Name: otp.Name,
+ Issuer: otp.Issuer,
+ Secret: otp.Secret,
+ Algorithm: otp.Algorithm,
+ Digits: otp.Digits,
+ Period: otp.Period,
+ }); err != nil {
+ return nil, err
+ }
+
+ if err := s.otpRepo.Update(ctx, otp); err != nil {
+ return nil, fmt.Errorf("failed to update OTP: %w", err)
+ }
+
+ return otp, nil
+}
+
+// DeleteOTP deletes an OTP
+func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error {
+ if err := s.otpRepo.Delete(ctx, id, userID); err != nil {
+ return fmt.Errorf("failed to delete OTP: %w", err)
+ }
+ return nil
+}
+
+// GenerateCode generates a TOTP code with enhanced logging and error handling
+func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) {
+ start := time.Now()
+
+ otp, err := s.otpRepo.FindByID(ctx, id, userID)
+ if err != nil {
+ log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err)
+ return "", 0, fmt.Errorf("failed to get OTP: %w", err)
+ }
+
+ // Get current time step
+ now := time.Now().Unix()
+ timeStep := now / int64(otp.Period)
+
+ // Generate code
+ code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits)
+ if err != nil {
+ log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err)
+ return "", 0, fmt.Errorf("failed to generate code: %w", err)
+ }
+
+ // Calculate remaining seconds
+ remainingSeconds := otp.Period - int(now%int64(otp.Period))
+
+ // Log successful generation (without actual code)
+ log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)",
+ id, userID, time.Since(start), remainingSeconds)
+
+ return code, remainingSeconds, nil
+}
+
+// VerifyCode verifies a TOTP code with enhanced security and logging
+func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) {
+ start := time.Now()
+
+ // Basic input validation
+ if len(code) == 0 {
+ log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID)
+ return false, fmt.Errorf("code is required")
+ }
+
+ otp, err := s.otpRepo.FindByID(ctx, id, userID)
+ if err != nil {
+ log.Printf("Failed to find OTP %s for user %s during verification: %v",
+ id, userID, err)
+ return false, fmt.Errorf("failed to get OTP: %w", err)
+ }
+
+ // Get current and adjacent time steps
+ now := time.Now().Unix()
+ timeSteps := []int64{
+ (now - int64(otp.Period)) / int64(otp.Period),
+ now / int64(otp.Period),
+ (now + int64(otp.Period)) / int64(otp.Period),
+ }
+
+ // Check code against all time steps
+ for _, ts := range timeSteps {
+ expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits)
+ if err != nil {
+ log.Printf("Code generation failed for time step %d: %v", ts, err)
+ continue
+ }
+ if expectedCode == code {
+ // Log successful verification
+ log.Printf("Code verified successfully for OTP %s (user %s) in %v",
+ id, userID, time.Since(start))
+ return true, nil
+ }
+ }
+
+ // Log failed verification attempt
+ log.Printf("Invalid code provided for OTP %s (user %s) in %v",
+ id, userID, time.Since(start))
+
+ return false, nil
+}
+
+// validateOTPInput validates OTP input with detailed error messages
+func (s *OTPService) validateOTPInput(input models.OTPParams) error {
+ if input.Name == "" {
+ return fmt.Errorf("name is required")
+ }
+
+ if len(input.Name) > 100 {
+ return fmt.Errorf("name is too long (maximum 100 characters)")
+ }
+
+ if input.Secret == "" {
+ return fmt.Errorf("secret is required")
+ }
+
+ if !isValidBase32(input.Secret) {
+ return fmt.Errorf("invalid secret format: must be a valid base32 string")
+ }
+
+ // Secret length check (after base32 decoding)
+ secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "="))
+ if len(secretBytes) < 10 {
+ return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)")
+ }
+
+ if input.Algorithm != "" {
+ if !isValidAlgorithm(input.Algorithm) {
+ return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm)
+ }
+ }
+
+ if input.Digits != 0 {
+ if input.Digits < 6 || input.Digits > 8 {
+ return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits)
+ }
+ }
+
+ if input.Period != 0 {
+ if input.Period < 30 || input.Period > 60 {
+ return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period)
+ }
+ }
+
+ return nil
+}
+
+// Helper functions
+
+func cleanSecret(secret string) string {
+ // Remove spaces and convert to upper case
+ secret = strings.TrimSpace(strings.ToUpper(secret))
+ // Remove any padding characters
+ return strings.TrimRight(secret, "=")
+}
+
+func isValidBase32(s string) bool {
+ // Try to decode the secret
+ _, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "="))
+ return err == nil
+}
+
+func isValidAlgorithm(algorithm string) bool {
+ switch strings.ToUpper(algorithm) {
+ case "SHA1", "SHA256", "SHA512":
+ return true
+ default:
+ return false
+ }
+}
+
+func getHasher(algorithm string, key []byte) (hash.Hash, error) {
+ switch strings.ToUpper(algorithm) {
+ case "SHA1":
+ return hmac.New(sha1.New, key), nil
+ case "SHA256":
+ return hmac.New(sha256.New, key), nil
+ case "SHA512":
+ return hmac.New(sha512.New, key), nil
+ default:
+ return nil, fmt.Errorf("unsupported algorithm: %s", algorithm)
+ }
+}
+
+func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) {
+ // Decode secret
+ secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "="))
+ if err != nil {
+ return "", fmt.Errorf("invalid secret: %w", err)
+ }
+
+ // Get initialized HMAC hasher with secret
+ hasher, err := getHasher(algorithm, secretBytes)
+ if err != nil {
+ return "", err
+ }
+
+ // Convert time step to bytes
+ timeBytes := make([]byte, 8)
+ binary.BigEndian.PutUint64(timeBytes, uint64(timeStep))
+
+ // Calculate HMAC
+ hasher.Write(timeBytes)
+ hash := hasher.Sum(nil)
+
+ // Get offset
+ offset := hash[len(hash)-1] & 0xf
+
+ // Generate 4-byte code
+ code := binary.BigEndian.Uint32(hash[offset : offset+4])
+ code = code & 0x7fffffff
+
+ // Get the specified number of digits
+ code = code % uint32(pow10(digits))
+
+ // Format code with leading zeros
+ return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil
+}
+
+func pow10(n int) uint32 {
+ result := uint32(1)
+ for i := 0; i < n; i++ {
+ result *= 10
+ }
+ return result
+}
diff --git a/utils/utils.go b/utils/utils.go
index 113ac27..efd6713 100644
--- a/utils/utils.go
+++ b/utils/utils.go
@@ -1,20 +1,65 @@
package utils
import (
+ "context"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
+ "fmt"
"net/http"
+ "strings"
+ "github.com/golang-jwt/jwt"
"github.com/julienschmidt/httprouter"
+ "github.com/spf13/viper"
)
+// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
+ // 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ // 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递
h(w, r)
}
}
+func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ authHeader := r.Header.Get("Authorization")
+ if authHeader == "" {
+ http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
+ return
+ }
+
+ tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
+ secret := viper.GetString("auth.secret")
+
+ token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unexpected signing method")
+ }
+ return []byte(secret), nil
+ })
+
+ if err != nil || !token.Valid {
+ http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
+ return
+ }
+
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
+ return
+ }
+
+ type contextKey string
+ // 将 openid 存入上下文
+ ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
+ next.ServeHTTP(w, r.WithContext(ctx))
+ }
+}
+
+// AesDecrypt 函数用于AES解密
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
//Base64解码
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
@@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
return origData, nil
}
+// PKCS7UnPadding 函数用于去除PKCS7填充的密文
func PKCS7UnPadding(plantText []byte) []byte {
+ // 获取密文的长度
length := len(plantText)
+ // 如果密文长度大于0
if length > 0 {
+ // 获取最后一个字节的值,即填充的位数
unPadding := int(plantText[length-1])
+ // 返回去除填充后的密文
return plantText[:(length - unPadding)]
}
+ // 如果密文长度为0,则返回原密文
return plantText
}
diff --git a/validator/validator.go b/validator/validator.go
new file mode 100644
index 0000000..13cff58
--- /dev/null
+++ b/validator/validator.go
@@ -0,0 +1,316 @@
+package validator
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "reflect"
+ "regexp"
+ "strings"
+
+ "github.com/go-playground/validator/v10"
+)
+
+var (
+ validate *validator.Validate
+
+ // 自定义验证规则
+ customValidations = map[string]validator.Func{
+ "otpsecret": validateOTPSecret,
+ "password": validatePassword,
+ "issuer": validateIssuer,
+ "otpauth_uri": validateOTPAuthURI,
+ "no_xss": validateNoXSS,
+ }
+
+ // 常见的弱密码列表(实际使用时应该使用更完整的列表)
+ commonPasswords = map[string]bool{
+ "password123": true,
+ "12345678": true,
+ "qwerty123": true,
+ "admin123": true,
+ "letmein": true,
+ "welcome": true,
+ "password": true,
+ "admin": true,
+ }
+
+ // 预编译的XSS检测正则表达式
+ xssPatterns = []*regexp.Regexp{
+ regexp.MustCompile(`(?i)`),
+ regexp.MustCompile(`(?i)javascript:`),
+ regexp.MustCompile(`(?i)data:text/html`),
+ regexp.MustCompile(`(?i)on\w+\s*=`),
+ regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
+ regexp.MustCompile(`(?i)<\s*iframe`),
+ regexp.MustCompile(`(?i)<\s*object`),
+ regexp.MustCompile(`(?i)<\s*embed`),
+ regexp.MustCompile(`(?i)<\s*style`),
+ regexp.MustCompile(`(?i)<\s*form`),
+ regexp.MustCompile(`(?i)<\s*applet`),
+ regexp.MustCompile(`(?i)<\s*meta`),
+ regexp.MustCompile(`(?i)expression\s*\(`),
+ regexp.MustCompile(`(?i)url\s*\(`),
+ }
+
+ // 预编译的正则表达式
+ base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
+ issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
+ otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`)
+ upperRegex = regexp.MustCompile(`[A-Z]`)
+ lowerRegex = regexp.MustCompile(`[a-z]`)
+ numberRegex = regexp.MustCompile(`[0-9]`)
+ specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
+)
+
+func init() {
+ validate = validator.New()
+
+ // 注册自定义验证规则
+ for tag, fn := range customValidations {
+ if err := validate.RegisterValidation(tag, fn); err != nil {
+ panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
+ }
+ }
+
+ // 使用json tag作为字段名
+ validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
+ name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
+ if name == "-" {
+ return ""
+ }
+ return name
+ })
+}
+
+// ValidateRequest validates a request body against a struct
+func ValidateRequest(r *http.Request, v interface{}) error {
+ if err := json.NewDecoder(r.Body).Decode(v); err != nil {
+ return fmt.Errorf("invalid request body: %w", err)
+ }
+
+ if err := validate.Struct(v); err != nil {
+ if validationErrors, ok := err.(validator.ValidationErrors); ok {
+ return NewValidationError(validationErrors)
+ }
+ return fmt.Errorf("validation error: %w", err)
+ }
+
+ return nil
+}
+
+// ValidationError represents a validation error
+type ValidationError struct {
+ Fields map[string]string `json:"fields"`
+}
+
+// Error implements the error interface
+func (e *ValidationError) Error() string {
+ var errors []string
+ for field, msg := range e.Fields {
+ errors = append(errors, fmt.Sprintf("%s: %s", field, msg))
+ }
+ return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; "))
+}
+
+// NewValidationError creates a new ValidationError from validator.ValidationErrors
+func NewValidationError(errors validator.ValidationErrors) *ValidationError {
+ fields := make(map[string]string)
+ for _, err := range errors {
+ fields[err.Field()] = getErrorMessage(err)
+ }
+ return &ValidationError{Fields: fields}
+}
+
+// getErrorMessage returns a human-readable error message for a validation error
+func getErrorMessage(err validator.FieldError) string {
+ switch err.Tag() {
+ case "required":
+ return "此字段为必填项"
+ case "email":
+ return "请输入有效的电子邮件地址"
+ case "min":
+ if err.Type().Kind() == reflect.String {
+ return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param())
+ }
+ return fmt.Sprintf("必须大于或等于 %s", err.Param())
+ case "max":
+ if err.Type().Kind() == reflect.String {
+ return fmt.Sprintf("长度不能超过 %s 个字符", err.Param())
+ }
+ return fmt.Sprintf("必须小于或等于 %s", err.Param())
+ case "len":
+ return fmt.Sprintf("长度必须为 %s 个字符", err.Param())
+ case "oneof":
+ return fmt.Sprintf("必须是以下值之一: %s", err.Param())
+ case "otpsecret":
+ return "OTP密钥格式无效,必须是有效的Base32编码"
+ case "password":
+ return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符"
+ case "issuer":
+ return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
+ case "otpauth_uri":
+ return "OTP认证URI格式无效"
+ case "no_xss":
+ return "输入包含潜在的不安全内容"
+ case "numeric":
+ return "必须是数字"
+ default:
+ return fmt.Sprintf("验证失败: %s", err.Tag())
+ }
+}
+
+// Custom validation functions
+
+// validateOTPSecret validates an OTP secret
+func validateOTPSecret(fl validator.FieldLevel) bool {
+ secret := fl.Field().String()
+
+ if secret == "" {
+ return false
+ }
+
+ // OTP secret should be base32 encoded
+ if !base32Regex.MatchString(secret) {
+ return false
+ }
+
+ // Check length (typical OTP secrets are 16-64 characters)
+ validLength := len(secret) >= 16 && len(secret) <= 128
+
+ return validLength
+}
+
+// validatePassword validates a password
+func validatePassword(fl validator.FieldLevel) bool {
+ password := fl.Field().String()
+
+ // At least 10 characters long
+ if len(password) < 10 {
+ return false
+ }
+
+ // Check if it's a common password
+ if commonPasswords[strings.ToLower(password)] {
+ return false
+ }
+
+ // Check character types
+ hasUpper := upperRegex.MatchString(password)
+ hasLower := lowerRegex.MatchString(password)
+ hasNumber := numberRegex.MatchString(password)
+ hasSpecial := specialRegex.MatchString(password)
+
+ // Ensure password has enough complexity
+ complexity := 0
+ if hasUpper {
+ complexity++
+ }
+ if hasLower {
+ complexity++
+ }
+ if hasNumber {
+ complexity++
+ }
+ if hasSpecial {
+ complexity++
+ }
+
+ return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial)
+}
+
+// validateIssuer validates an issuer name
+func validateIssuer(fl validator.FieldLevel) bool {
+ issuer := fl.Field().String()
+
+ if issuer == "" {
+ return false
+ }
+
+ // Issuer should not contain special characters that could cause problems in URLs
+ if !issuerRegex.MatchString(issuer) {
+ return false
+ }
+
+ // Check length
+ validLength := len(issuer) >= 1 && len(issuer) <= 100
+
+ return validLength
+}
+
+// validateOTPAuthURI validates an otpauth URI
+func validateOTPAuthURI(fl validator.FieldLevel) bool {
+ uri := fl.Field().String()
+
+ if uri == "" {
+ return false
+ }
+
+ // Basic format check for otpauth URI
+ // Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
+ return otpauthRegex.MatchString(uri)
+}
+
+// validateNoXSS checks if a string contains potential XSS payloads
+func validateNoXSS(fl validator.FieldLevel) bool {
+ value := fl.Field().String()
+
+ // 检查基本的HTML编码
+ if strings.Contains(value, "") ||
+ strings.Contains(value, "<") ||
+ strings.Contains(value, ">") {
+ return false
+ }
+
+ // 检查十六进制编码
+ if strings.Contains(strings.ToLower(value), "\\x3c") || // <
+ strings.Contains(strings.ToLower(value), "\\x3e") { // >
+ return false
+ }
+
+ // 检查Unicode编码
+ if strings.Contains(strings.ToLower(value), "\\u003c") || // <
+ strings.Contains(strings.ToLower(value), "\\u003e") { // >
+ return false
+ }
+
+ // 使用预编译的正则表达式检查XSS模式
+ for _, pattern := range xssPatterns {
+ if pattern.MatchString(value) {
+ return false
+ }
+ }
+
+ return true
+}
+
+// Request validation structs
+
+// LoginRequest represents a login request
+type LoginRequest struct {
+ Code string `json:"code" validate:"required,len=6|len=8,numeric"`
+}
+
+// CreateOTPRequest represents a request to create an OTP
+type CreateOTPRequest struct {
+ Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
+ Issuer string `json:"issuer" validate:"required,issuer,no_xss"`
+ Secret string `json:"secret" validate:"required,otpsecret"`
+ Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
+ Digits int `json:"digits" validate:"required,oneof=6 8"`
+ Period int `json:"period" validate:"required,oneof=30 60"`
+}
+
+// UpdateOTPRequest represents a request to update an OTP
+type UpdateOTPRequest struct {
+ Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`
+ Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`
+ Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
+ Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
+ Period int `json:"period" validate:"omitempty,oneof=30 60"`
+}
+
+// VerifyOTPRequest represents a request to verify an OTP code
+type VerifyOTPRequest struct {
+ Code string `json:"code" validate:"required,len=6|len=8,numeric"`
+}