From bcd986e3f742899a581cc173652f663a0f62dd1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CxHuPo=E2=80=9D?= <7513325+vrocwang@users.noreply.github.com> Date: Fri, 23 May 2025 18:57:11 +0800 Subject: [PATCH] beta --- api/response.go | 149 ++++ cache/cache.go | 206 ++++++ cmd/root.go | 230 +++--- config.yaml | 30 +- config/config.go | 128 ++++ database/database.go | 64 -- database/db.go | 196 ++++++ database/migration.go | 160 +++++ docs/swagger.go | 663 ++++++++++++++++++ go.mod | 23 +- go.sum | 36 + handlers/auth_handler.go | 147 ++++ handlers/handler.go | 27 - handlers/login.go | 158 ----- handlers/otp.go | 70 -- handlers/otp_handler.go | 286 ++++++++ logger/logger.go | 204 ++++++ metrics/metrics.go | 193 +++++ middleware/middleware.go | 353 ++++++++++ miniprogram-example/app.js | 50 ++ miniprogram-example/app.json | 23 + miniprogram-example/app.wxss | 238 +++++++ miniprogram-example/pages/login/login.js | 48 ++ miniprogram-example/pages/login/login.json | 3 + miniprogram-example/pages/login/login.wxml | 30 + miniprogram-example/pages/login/login.wxss | 97 +++ miniprogram-example/pages/otp-add/index.js | 169 +++++ miniprogram-example/pages/otp-add/index.json | 3 + miniprogram-example/pages/otp-add/index.wxml | 119 ++++ miniprogram-example/pages/otp-add/index.wxss | 176 +++++ miniprogram-example/pages/otp-list/index.js | 213 ++++++ miniprogram-example/pages/otp-list/index.json | 3 + miniprogram-example/pages/otp-list/index.wxml | 59 ++ miniprogram-example/pages/otp-list/index.wxss | 201 ++++++ miniprogram-example/project.config.json | 47 ++ .../project.private.config.json | 23 + miniprogram-example/services/auth.js | 84 +++ miniprogram-example/services/otp.js | 87 +++ miniprogram-example/utils/request.js | 64 ++ models/otp.go | 195 ++++++ models/user.go | 114 +++ security/security.go | 332 +++++++++ server/server.go | 172 +++++ services/auth.go | 230 ++++++ services/otp.go | 358 ++++++++++ validator/validator.go | 159 +++++ 46 files changed, 6166 insertions(+), 454 deletions(-) create mode 100644 api/response.go create mode 100644 cache/cache.go create mode 100644 config/config.go delete mode 100644 database/database.go create mode 100644 database/db.go create mode 100644 database/migration.go create mode 100644 docs/swagger.go create mode 100644 handlers/auth_handler.go delete mode 100644 handlers/handler.go delete mode 100644 handlers/login.go delete mode 100644 handlers/otp.go create mode 100644 handlers/otp_handler.go create mode 100644 logger/logger.go create mode 100644 metrics/metrics.go create mode 100644 middleware/middleware.go create mode 100644 miniprogram-example/app.js create mode 100644 miniprogram-example/app.json create mode 100644 miniprogram-example/app.wxss create mode 100644 miniprogram-example/pages/login/login.js create mode 100644 miniprogram-example/pages/login/login.json create mode 100644 miniprogram-example/pages/login/login.wxml create mode 100644 miniprogram-example/pages/login/login.wxss create mode 100644 miniprogram-example/pages/otp-add/index.js create mode 100644 miniprogram-example/pages/otp-add/index.json create mode 100644 miniprogram-example/pages/otp-add/index.wxml create mode 100644 miniprogram-example/pages/otp-add/index.wxss create mode 100644 miniprogram-example/pages/otp-list/index.js create mode 100644 miniprogram-example/pages/otp-list/index.json create mode 100644 miniprogram-example/pages/otp-list/index.wxml create mode 100644 miniprogram-example/pages/otp-list/index.wxss create mode 100644 miniprogram-example/project.config.json create mode 100644 miniprogram-example/project.private.config.json create mode 100644 miniprogram-example/services/auth.js create mode 100644 miniprogram-example/services/otp.js create mode 100644 miniprogram-example/utils/request.js create mode 100644 models/otp.go create mode 100644 models/user.go create mode 100644 security/security.go create mode 100644 server/server.go create mode 100644 services/auth.go create mode 100644 services/otp.go create mode 100644 validator/validator.go diff --git a/api/response.go b/api/response.go new file mode 100644 index 0000000..8e24daf --- /dev/null +++ b/api/response.go @@ -0,0 +1,149 @@ +package api + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" +) + +// Common error codes +const ( + CodeSuccess = 0 + CodeInvalidParams = 400 + CodeUnauthorized = 401 + CodeForbidden = 403 + CodeNotFound = 404 + CodeInternalError = 500 + CodeServiceUnavail = 503 +) + +// Error represents an API error +type Error struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// Error implements the error interface +func (e *Error) Error() string { + return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message) +} + +// NewError creates a new API error +func NewError(code int, message string) *Error { + return &Error{ + Code: code, + Message: message, + } +} + +// Response represents a standard API response +type Response struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// ResponseWriter wraps common response writing functions +type ResponseWriter struct { + http.ResponseWriter +} + +// NewResponseWriter creates a new ResponseWriter +func NewResponseWriter(w http.ResponseWriter) *ResponseWriter { + return &ResponseWriter{w} +} + +// WriteJSON writes a JSON response +func (w *ResponseWriter) WriteJSON(code int, data interface{}) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + return json.NewEncoder(w).Encode(data) +} + +// WriteSuccess writes a success response +func (w *ResponseWriter) WriteSuccess(data interface{}) error { + return w.WriteJSON(http.StatusOK, Response{ + Code: CodeSuccess, + Message: "success", + Data: data, + }) +} + +// WriteError writes an error response +func (w *ResponseWriter) WriteError(err error) error { + var apiErr *Error + if errors.As(err, &apiErr) { + return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{ + Code: apiErr.Code, + Message: apiErr.Message, + }) + } + + // Handle unknown errors + return w.WriteJSON(http.StatusInternalServerError, Response{ + Code: CodeInternalError, + Message: "Internal Server Error", + }) +} + +// WriteErrorWithCode writes an error response with a specific code +func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error { + return w.WriteJSON(getHTTPStatus(code), Response{ + Code: code, + Message: message, + }) +} + +// getHTTPStatus maps API error codes to HTTP status codes +func getHTTPStatus(code int) int { + switch code { + case CodeSuccess: + return http.StatusOK + case CodeInvalidParams: + return http.StatusBadRequest + case CodeUnauthorized: + return http.StatusUnauthorized + case CodeForbidden: + return http.StatusForbidden + case CodeNotFound: + return http.StatusNotFound + case CodeServiceUnavail: + return http.StatusServiceUnavailable + default: + return http.StatusInternalServerError + } +} + +// Common errors +var ( + ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters") + ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized") + ErrForbidden = NewError(CodeForbidden, "Forbidden") + ErrNotFound = NewError(CodeNotFound, "Resource not found") + ErrInternalError = NewError(CodeInternalError, "Internal server error") + ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable") +) + +// ValidationError creates an error for invalid parameters +func ValidationError(message string) *Error { + return NewError(CodeInvalidParams, message) +} + +// NotFoundError creates an error for not found resources +func NotFoundError(resource string) *Error { + return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource)) +} + +// ForbiddenError creates an error for forbidden actions +func ForbiddenError(message string) *Error { + return NewError(CodeForbidden, message) +} + +// InternalError creates an error for internal server errors +func InternalError(err error) *Error { + if err == nil { + return ErrInternalError + } + return NewError(CodeInternalError, err.Error()) +} diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..08d3041 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,206 @@ +package cache + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" +) + +// Item represents a cache item +type Item struct { + Value []byte + Expiration int64 +} + +// Expired returns true if the item has expired +func (item Item) Expired() bool { + if item.Expiration == 0 { + return false + } + return time.Now().UnixNano() > item.Expiration +} + +// Cache represents an in-memory cache +type Cache struct { + items map[string]Item + mu sync.RWMutex + defaultExpiration time.Duration + cleanupInterval time.Duration + stopCleanup chan bool +} + +// New creates a new cache with the given default expiration and cleanup interval +func New(defaultExpiration, cleanupInterval time.Duration) *Cache { + cache := &Cache{ + items: make(map[string]Item), + defaultExpiration: defaultExpiration, + cleanupInterval: cleanupInterval, + stopCleanup: make(chan bool), + } + + // Start cleanup goroutine if cleanup interval > 0 + if cleanupInterval > 0 { + go cache.startCleanup() + } + + return cache +} + +// startCleanup starts the cleanup process +func (c *Cache) startCleanup() { + ticker := time.NewTicker(c.cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.DeleteExpired() + case <-c.stopCleanup: + return + } + } +} + +// Set adds an item to the cache with the given key and expiration +func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error { + // Convert value to bytes + var valueBytes []byte + var err error + + switch v := value.(type) { + case []byte: + valueBytes = v + case string: + valueBytes = []byte(v) + default: + valueBytes, err = json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to marshal value: %w", err) + } + } + + // Calculate expiration + var exp int64 + if expiration == 0 { + if c.defaultExpiration > 0 { + exp = time.Now().Add(c.defaultExpiration).UnixNano() + } + } else if expiration > 0 { + exp = time.Now().Add(expiration).UnixNano() + } + + c.mu.Lock() + c.items[key] = Item{ + Value: valueBytes, + Expiration: exp, + } + c.mu.Unlock() + + return nil +} + +// Get gets an item from the cache +func (c *Cache) Get(key string, value interface{}) (bool, error) { + c.mu.RLock() + item, found := c.items[key] + c.mu.RUnlock() + + if !found { + return false, nil + } + + // Check if item has expired + if item.Expired() { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() + return false, nil + } + + // Unmarshal value + switch v := value.(type) { + case *[]byte: + *v = item.Value + case *string: + *v = string(item.Value) + default: + if err := json.Unmarshal(item.Value, value); err != nil { + return true, fmt.Errorf("failed to unmarshal value: %w", err) + } + } + + return true, nil +} + +// Delete deletes an item from the cache +func (c *Cache) Delete(key string) { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() +} + +// DeleteExpired deletes all expired items from the cache +func (c *Cache) DeleteExpired() { + now := time.Now().UnixNano() + + c.mu.Lock() + for k, v := range c.items { + if v.Expiration > 0 && now > v.Expiration { + delete(c.items, k) + } + } + c.mu.Unlock() +} + +// Clear deletes all items from the cache +func (c *Cache) Clear() { + c.mu.Lock() + c.items = make(map[string]Item) + c.mu.Unlock() +} + +// Close stops the cleanup goroutine +func (c *Cache) Close() { + if c.cleanupInterval > 0 { + c.stopCleanup <- true + } +} + +// CacheService provides caching functionality +type CacheService struct { + cache *Cache +} + +// NewCacheService creates a new CacheService +func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService { + return &CacheService{ + cache: New(defaultExpiration, cleanupInterval), + } +} + +// Set adds an item to the cache +func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error { + return s.cache.Set(key, value, expiration) +} + +// Get gets an item from the cache +func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) { + return s.cache.Get(key, value) +} + +// Delete deletes an item from the cache +func (s *CacheService) Delete(ctx context.Context, key string) { + s.cache.Delete(key) +} + +// Clear deletes all items from the cache +func (s *CacheService) Clear(ctx context.Context) { + s.cache.Clear() +} + +// Close closes the cache +func (s *CacheService) Close() { + s.cache.Close() +} diff --git a/cmd/root.go b/cmd/root.go index 9279368..bdc3f84 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,151 +2,139 @@ package cmd import ( "context" - "errors" "fmt" "log" "net/http" "os" "os/signal" + "syscall" + + "github.com/spf13/viper" + + "otpm/config" "otpm/database" "otpm/handlers" - "otpm/utils" - "syscall" - "time" - - "github.com/jmoiron/sqlx" - "github.com/julienschmidt/httprouter" - "github.com/spf13/cobra" - "github.com/spf13/viper" + "otpm/models" + "otpm/server" + "otpm/services" ) -var rootCmd = &cobra.Command{ - Use: "otpm", - Short: "otp backend for microapp on wechat", - Run: func(cmd *cobra.Command, args []string) { - startApp() - }, -} - -func Execute() { - if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) - } -} - func init() { - cobra.OnInitialize(initConfig) + // Set config file with multi-environment support + viper.SetConfigName("config") + viper.SetConfigType("yaml") + viper.AddConfigPath(".") - rootCmd.PersistentFlags().StringP("config", "c", "", "config file (default is $HOME/config.yaml)") - rootCmd.PersistentFlags().StringP("driver", "d", "sqlite3", "database driver (sqlite3, postgres, mysql)") - rootCmd.PersistentFlags().StringP("dsn", "s", "", "database connection string") - rootCmd.PersistentFlags().StringP("port", "p", "8080", "port to listen on") - - viper.BindPFlag("database.driver", rootCmd.PersistentFlags().Lookup("driver")) - viper.BindPFlag("database.dsn", rootCmd.PersistentFlags().Lookup("dsn")) - viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) -} - -func initConfig() { - if cfgFile := viper.GetString("config"); cfgFile != "" { - viper.SetConfigFile(cfgFile) - } else { - viper.AddConfigPath(".") - viper.SetConfigName("config") - viper.SetConfigType("yaml") + // Set environment specific config (e.g. config.production.yaml) + env := os.Getenv("OTPM_ENV") + if env != "" { + viper.SetConfigName(fmt.Sprintf("config.%s", env)) } - if err := viper.ReadInConfig(); err != nil { - log.Fatalf("Error reading config file: %v", err) - } + // Set default values + viper.SetDefault("server.port", "8080") + viper.SetDefault("server.timeout.read", "15s") + viper.SetDefault("server.timeout.write", "15s") + viper.SetDefault("server.timeout.idle", "60s") + viper.SetDefault("database.max_open_conns", 25) + viper.SetDefault("database.max_idle_conns", 5) + viper.SetDefault("database.conn_max_lifetime", "5m") + + // Set environment variable prefix + viper.SetEnvPrefix("OTPM") + viper.AutomaticEnv() + + // Bind environment variables + viper.BindEnv("database.url", "OTPM_DB_URL") + viper.BindEnv("database.password", "OTPM_DB_PASSWORD") } -type App struct { - db *sqlx.DB - router http.Handler - port int -} - -func NewApp() (*App, error) { - db, err := connectDB() +// Execute is the entry point for the application +func Execute() error { + // Load configuration + cfg, err := config.LoadConfig() if err != nil { - return nil, err + return fmt.Errorf("failed to load config: %w", err) } - if err := runMigrations(db); err != nil { - return nil, err - } - - router := setupRouter(db) - - port := viper.GetInt("port") - - return &App{ - db: db, - router: router, - port: port, - }, nil -} - -func connectDB() (*sqlx.DB, error) { - db, err := database.InitDB() - if err != nil { - return nil, fmt.Errorf("failed to connect to the database: %v", err) - } - return db, nil -} - -func runMigrations(db *sqlx.DB) error { - if err := database.MigrateDB(db); err != nil { - log.Fatalf("Error migrating the database: %v", err) - return fmt.Errorf("error migrating the database: %w", err) - } - return nil -} - -func setupRouter(db *sqlx.DB) http.Handler { - handler := &handlers.Handler{DB: db} - - router := httprouter.New() - router.POST("/login", utils.AdaptHandler(handler.Login)) - router.POST("/refresh", utils.AdaptHandler(utils.AuthMiddleware(handler.RefreshToken))) - router.POST("/set", utils.AdaptHandler(utils.AuthMiddleware(handler.UpdateOrCreateOtp))) - router.GET("/get", utils.AdaptHandler(utils.AuthMiddleware(handler.GetOtp))) - - return router -} - -func (a *App) Start() error { - server := &http.Server{Addr: fmt.Sprintf(":%d", a.port), Handler: a.router} - go func() { - if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("Failed to start server: %v", err) - } - }() - - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit - log.Println("Shutting down server...") - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + // Create context with cancellation + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - return server.Shutdown(ctx) -} + // Setup signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-sigChan + log.Printf("Received signal: %v", sig) + cancel() + }() -func startApp() { - app, err := NewApp() + // Initialize database + db, err := database.New(&cfg.Database) if err != nil { - log.Fatalf("Failed to initialize application: %v", err) + return fmt.Errorf("failed to initialize database: %w", err) } - defer func() { - if err := app.db.Close(); err != nil { - log.Printf("Failed to close database connection: %v", err) + defer db.Close() + + // Run database migrations + if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil { + return fmt.Errorf("failed to run migrations: %w", err) + } + + // Initialize repositories + userRepo := models.NewUserRepository(db.DB) + otpRepo := models.NewOTPRepository(db.DB) + + // Initialize services + authService := services.NewAuthService(cfg, userRepo) + otpService := services.NewOTPService(otpRepo) + + // Initialize handlers + authHandler := handlers.NewAuthHandler(authService) + otpHandler := handlers.NewOTPHandler(otpService) + + // Create and configure server + srv := server.New(cfg) + + // Register health check endpoint + srv.RegisterHealthCheck() + + // Register public routes with type conversion + authRoutes := make(map[string]http.Handler) + for path, handler := range authHandler.Routes() { + authRoutes[path] = http.HandlerFunc(handler) + } + srv.RegisterRoutes(authRoutes) + + // Register authenticated routes with type conversion + otpRoutes := make(map[string]http.Handler) + for path, handler := range otpHandler.Routes() { + otpRoutes[path] = http.HandlerFunc(handler) + } + srv.RegisterAuthRoutes(otpRoutes) + + // Start server in goroutine + serverErr := make(chan error, 1) + go func() { + log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port) + if err := srv.Start(); err != nil { + serverErr <- fmt.Errorf("server error: %w", err) } }() - if err := app.Start(); err != nil { - log.Fatalf("Failed to start application: %v", err) + + // Wait for shutdown signal or server error + select { + case err := <-serverErr: + return err + case <-ctx.Done(): + // Graceful shutdown with timeout + log.Println("Shutting down server...") + if err := srv.Shutdown(); err != nil { + return fmt.Errorf("server shutdown error: %w", err) + } + log.Println("Server stopped gracefully") } + + return nil } diff --git a/config.yaml b/config.yaml index 7ab9dcb..da2013f 100644 --- a/config.yaml +++ b/config.yaml @@ -1,15 +1,23 @@ server: - name: "otpm" -database: - driver: sqlite - dsn: otpm.sqlite - skip_migration: false -port: 8080 + port: 8080 + read_timeout: 15s + write_timeout: 15s + shutdown_timeout: 5s -auth: - secret: "secret" - ttl: 3600 +database: + driver: sqlite3 + dsn: otpm.sqlite + max_open_conns: 25 + max_idle_conns: 25 + max_lifetime: 5m + skip_migration: false + +jwt: + secret: "your-jwt-secret-key-change-this-in-production" + expire_delta: 24h + refresh_delta: 168h + signing_method: HS256 wechat: - appid: "wx57d1033974eb5250" - secret: "be494c2a81df685a40b9a74e1736b15d" \ No newline at end of file + app_id: "your-wechat-app-id" + app_secret: "your-wechat-app-secret" \ No newline at end of file diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..7670caf --- /dev/null +++ b/config/config.go @@ -0,0 +1,128 @@ +package config + +import ( + "fmt" + "time" + + "github.com/spf13/viper" +) + +// Config holds all configuration for the application +type Config struct { + Server ServerConfig `mapstructure:"server"` + Database DatabaseConfig `mapstructure:"database"` + JWT JWTConfig `mapstructure:"jwt"` + WeChat WeChatConfig `mapstructure:"wechat"` +} + +// ServerConfig holds all server related configuration +type ServerConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + ReadTimeout time.Duration `mapstructure:"read_timeout"` + WriteTimeout time.Duration `mapstructure:"write_timeout"` + ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"` + Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout +} + +// DatabaseConfig holds all database related configuration +type DatabaseConfig struct { + Driver string `mapstructure:"driver"` + DSN string `mapstructure:"dsn"` + MaxOpenConns int `mapstructure:"max_open_conns"` + MaxIdleConns int `mapstructure:"max_idle_conns"` + MaxLifetime time.Duration `mapstructure:"max_lifetime"` + SkipMigration bool `mapstructure:"skip_migration"` +} + +// JWTConfig holds all JWT related configuration +type JWTConfig struct { + Secret string `mapstructure:"secret"` + ExpireDelta time.Duration `mapstructure:"expire_delta"` + RefreshDelta time.Duration `mapstructure:"refresh_delta"` + SigningMethod string `mapstructure:"signing_method"` + Issuer string `mapstructure:"issuer"` + Audience string `mapstructure:"audience"` +} + +// WeChatConfig holds all WeChat related configuration +type WeChatConfig struct { + AppID string `mapstructure:"app_id"` + AppSecret string `mapstructure:"app_secret"` +} + +// LoadConfig loads the configuration from file and environment variables +func LoadConfig() (*Config, error) { + // Set default values + setDefaults() + + // Read config file + if err := viper.ReadInConfig(); err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + var config Config + if err := viper.Unmarshal(&config); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %w", err) + } + + // Validate config + if err := validateConfig(&config); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + + return &config, nil +} + +// setDefaults sets default values for configuration +func setDefaults() { + // Server defaults + viper.SetDefault("server.port", 8080) + viper.SetDefault("server.read_timeout", "15s") + viper.SetDefault("server.write_timeout", "15s") + viper.SetDefault("server.shutdown_timeout", "5s") + viper.SetDefault("server.timeout", "30s") // Default request processing timeout + + // Database defaults + viper.SetDefault("database.driver", "sqlite3") + viper.SetDefault("database.max_open_conns", 25) + viper.SetDefault("database.max_idle_conns", 25) + viper.SetDefault("database.max_lifetime", "5m") + viper.SetDefault("database.skip_migration", false) + + // JWT defaults + viper.SetDefault("jwt.expire_delta", "24h") + viper.SetDefault("jwt.refresh_delta", "168h") // 7 days + viper.SetDefault("jwt.signing_method", "HS256") + viper.SetDefault("jwt.issuer", "otpm") + viper.SetDefault("jwt.audience", "otpm-client") +} + +// validateConfig validates the configuration +func validateConfig(config *Config) error { + if config.Server.Port < 1 || config.Server.Port > 65535 { + return fmt.Errorf("invalid port number: %d", config.Server.Port) + } + + if config.Database.Driver == "" { + return fmt.Errorf("database driver is required") + } + + if config.Database.DSN == "" { + return fmt.Errorf("database DSN is required") + } + + if config.JWT.Secret == "" { + return fmt.Errorf("JWT secret is required") + } + + if config.WeChat.AppID == "" { + return fmt.Errorf("WeChat AppID is required") + } + + if config.WeChat.AppSecret == "" { + return fmt.Errorf("WeChat AppSecret is required") + } + + return nil +} diff --git a/database/database.go b/database/database.go deleted file mode 100644 index bad6d8f..0000000 --- a/database/database.go +++ /dev/null @@ -1,64 +0,0 @@ -package database - -import ( - _ "embed" - "fmt" - "log" - - _ "github.com/go-sql-driver/mysql" - "github.com/jmoiron/sqlx" - _ "github.com/lib/pq" - "github.com/spf13/viper" - _ "modernc.org/sqlite" -) - -var ( - //go:embed init/users.sql - userTable string - //go:embed init/otp.sql - otpTable string -) - -func InitDB() (*sqlx.DB, error) { - driver := viper.GetString("database.driver") - dsn := viper.GetString("database.dsn") - - db, err := sqlx.Open(driver, dsn) - if err != nil { - return nil, fmt.Errorf("failed to connect to database: %w", err) - } - - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("failed to ping database: %w", err) - } - - log.Println("Connected to database!") - return db, nil -} - -func MigrateDB(db *sqlx.DB) error { - // 检查是否需要执行迁移 - skipMigration := viper.GetBool("database.skip_migration") - if skipMigration { - log.Println("Skipping database migration as configured") - return nil - } - - // 执行用户表迁移 - if _, err := db.Exec(userTable); err != nil { - log.Printf("Warning: failed to create user migration: %v", err) - // 继续执行,不返回错误 - } else { - log.Println("User table migration completed successfully") - } - - // 执行OTP表迁移 - if _, err := db.Exec(otpTable); err != nil { - log.Printf("Warning: failed to create otp migration: %v", err) - // 继续执行,不返回错误 - } else { - log.Println("OTP table migration completed successfully") - } - - return nil -} diff --git a/database/db.go b/database/db.go new file mode 100644 index 0000000..43089ce --- /dev/null +++ b/database/db.go @@ -0,0 +1,196 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "log" + "strings" + "time" + + "otpm/config" + + "github.com/jmoiron/sqlx" +) + +// DB wraps sqlx.DB to provide additional functionality +type DB struct { + *sqlx.DB +} + +// New creates a new database connection +func New(cfg *config.DatabaseConfig) (*DB, error) { + db, err := sqlx.Open(cfg.Driver, cfg.DSN) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + // Configure connection pool with optimized settings + db.SetMaxOpenConns(cfg.MaxOpenConns) + db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections + db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn + db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes + + // Verify connection with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + log.Println("Successfully connected to database") + return &DB{db}, nil +} + +// WithTx executes a function within a transaction with retry logic +func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error { + const maxRetries = 3 + var lastErr error + + // Default transaction options + opts := &sql.TxOptions{ + Isolation: sql.LevelReadCommitted, + } + + for attempt := 1; attempt <= maxRetries; attempt++ { + start := time.Now() + + tx, err := db.BeginTxx(ctx, opts) + if err != nil { + if isRetryableError(err) && attempt < maxRetries { + lastErr = err + time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff + continue + } + return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err) + } + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } + }() + + if err := fn(tx); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err) + } + + if isRetryableError(err) && attempt < maxRetries { + lastErr = err + time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) + continue + } + return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err) + } + + // Log long-running transactions + if elapsed := time.Since(start); elapsed > 500*time.Millisecond { + log.Printf("Transaction completed in %v", elapsed) + } + + if err := tx.Commit(); err != nil { + if isRetryableError(err) && attempt < maxRetries { + lastErr = err + time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) + continue + } + return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err) + } + + return nil + } + + return lastErr +} + +// isRetryableError checks if an error is likely to succeed on retry +func isRetryableError(err error) bool { + if err == nil { + return false + } + + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "deadlock") || + strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "try again") || + strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "busy") || + strings.Contains(errStr, "locked") +} + +// ExecContext executes a query with adaptive timeout +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error { + // Set timeout based on query complexity + timeout := 5 * time.Second + if strings.Contains(strings.ToLower(query), "insert") || + strings.Contains(strings.ToLower(query), "update") || + strings.Contains(strings.ToLower(query), "delete") { + timeout = 10 * time.Second + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + start := time.Now() + _, err := db.DB.ExecContext(ctx, query, args...) + elapsed := time.Since(start) + + // Log slow queries + if elapsed > timeout/2 { + log.Printf("Slow query execution detected: %s (took %v)", query, elapsed) + } + + if err != nil { + return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err) + } + + return nil +} + +// QueryRowContext executes a query that returns a single row with timeout +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + return db.DB.QueryRowxContext(ctx, query, args...) +} + +// QueryContext executes a query that returns multiple rows with adaptive timeout +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + // Set timeout based on query complexity + timeout := 5 * time.Second + if strings.Contains(strings.ToLower(query), "join") || + strings.Contains(strings.ToLower(query), "group by") || + strings.Contains(strings.ToLower(query), "order by") { + timeout = 15 * time.Second + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + start := time.Now() + rows, err := db.DB.QueryxContext(ctx, query, args...) + elapsed := time.Since(start) + + // Log slow queries + if elapsed > timeout/2 { + log.Printf("Slow query detected: %s (took %v)", query, elapsed) + } + + if err != nil { + return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err) + } + + return rows, nil +} + +// Close closes the database connection +func (db *DB) Close() error { + if err := db.DB.Close(); err != nil { + return fmt.Errorf("failed to close database connection: %w", err) + } + return nil +} diff --git a/database/migration.go b/database/migration.go new file mode 100644 index 0000000..df15a97 --- /dev/null +++ b/database/migration.go @@ -0,0 +1,160 @@ +package database + +import ( + "context" + "fmt" + "log" + "time" + + _ "embed" + + "github.com/jmoiron/sqlx" +) + +var ( + //go:embed init/users.sql + userTable string + //go:embed init/otp.sql + otpTable string +) + +// Migration represents a database migration +type Migration struct { + Name string + SQL string + Version int +} + +// Migrations is a list of all migrations +var Migrations = []Migration{ + { + Name: "Create users table", + SQL: userTable, + Version: 1, + }, + { + Name: "Create OTP table", + SQL: otpTable, + Version: 2, + }, +} + +// MigrationRecord represents a record in the migrations table +type MigrationRecord struct { + ID int `db:"id"` + Version int `db:"version"` + Name string `db:"name"` + AppliedAt time.Time `db:"applied_at"` +} + +// ensureMigrationsTable ensures that the migrations table exists +func ensureMigrationsTable(db *sqlx.DB) error { + query := ` + CREATE TABLE IF NOT EXISTS migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version INTEGER NOT NULL, + name TEXT NOT NULL, + applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + ` + + _, err := db.Exec(query) + if err != nil { + return fmt.Errorf("failed to create migrations table: %w", err) + } + + return nil +} + +// getAppliedMigrations gets all applied migrations +func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) { + var records []MigrationRecord + if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil { + return nil, fmt.Errorf("failed to get applied migrations: %w", err) + } + + result := make(map[int]MigrationRecord) + for _, record := range records { + result[record.Version] = record + } + + return result, nil +} + +// Migrate runs all pending migrations +func Migrate(db *sqlx.DB, skipMigration bool) error { + if skipMigration { + log.Println("Skipping database migration as configured") + return nil + } + + // Ensure migrations table exists + if err := ensureMigrationsTable(db); err != nil { + return err + } + + // Get applied migrations + appliedMigrations, err := getAppliedMigrations(db) + if err != nil { + return err + } + + // Apply pending migrations + for _, migration := range Migrations { + if _, ok := appliedMigrations[migration.Version]; ok { + log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name) + continue + } + + log.Printf("Applying migration %d: %s", migration.Version, migration.Name) + + // Start a transaction for this migration + tx, err := db.Beginx() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + // Execute migration + if _, err := tx.Exec(migration.SQL); err != nil { + tx.Rollback() + return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err) + } + + // Record migration + if _, err := tx.Exec( + "INSERT INTO migrations (version, name) VALUES (?, ?)", + migration.Version, migration.Name, + ); err != nil { + tx.Rollback() + return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err) + } + + // Commit transaction + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err) + } + + log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name) + } + + return nil +} + +// MigrateWithContext runs all pending migrations with context +func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error { + // Create a channel to signal completion + done := make(chan error, 1) + + // Run migration in a goroutine + go func() { + done <- Migrate(db, skipMigration) + }() + + // Wait for migration to complete or context to be canceled + select { + case err := <-done: + return err + case <-ctx.Done(): + return fmt.Errorf("migration canceled: %w", ctx.Err()) + } +} diff --git a/docs/swagger.go b/docs/swagger.go new file mode 100644 index 0000000..ccf05ac --- /dev/null +++ b/docs/swagger.go @@ -0,0 +1,663 @@ +// Package docs provides API documentation using Swagger/OpenAPI +package docs + +import ( + "encoding/json" + "net/http" +) + +// SwaggerInfo holds the API information +var SwaggerInfo = struct { + Title string + Description string + Version string + Host string + BasePath string + Schemes []string +}{ + Title: "OTPM API", + Description: "API for One-Time Password Manager", + Version: "1.0.0", + Host: "localhost:8080", + BasePath: "/", + Schemes: []string{"http", "https"}, +} + +// SwaggerJSON returns the Swagger JSON +func SwaggerJSON() []byte { + swagger := map[string]interface{}{ + "swagger": "2.0", + "info": map[string]interface{}{ + "title": SwaggerInfo.Title, + "description": SwaggerInfo.Description, + "version": SwaggerInfo.Version, + }, + "host": SwaggerInfo.Host, + "basePath": SwaggerInfo.BasePath, + "schemes": SwaggerInfo.Schemes, + "paths": getPaths(), + "definitions": map[string]interface{}{ + "LoginRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "string", + "description": "WeChat authorization code", + }, + }, + "required": []string{"code"}, + }, + "LoginResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "token": map[string]interface{}{ + "type": "string", + "description": "JWT token", + }, + "openid": map[string]interface{}{ + "type": "string", + "description": "WeChat OpenID", + }, + }, + }, + "CreateOTPRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "OTP name", + }, + "issuer": map[string]interface{}{ + "type": "string", + "description": "OTP issuer", + }, + "secret": map[string]interface{}{ + "type": "string", + "description": "OTP secret", + }, + "algorithm": map[string]interface{}{ + "type": "string", + "description": "OTP algorithm", + "enum": []string{"SHA1", "SHA256", "SHA512"}, + }, + "digits": map[string]interface{}{ + "type": "integer", + "description": "OTP digits", + "enum": []int{6, 8}, + }, + "period": map[string]interface{}{ + "type": "integer", + "description": "OTP period in seconds", + "enum": []int{30, 60}, + }, + }, + "required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"}, + }, + "OTP": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "OTP ID", + }, + "user_id": map[string]interface{}{ + "type": "string", + "description": "User ID", + }, + "name": map[string]interface{}{ + "type": "string", + "description": "OTP name", + }, + "issuer": map[string]interface{}{ + "type": "string", + "description": "OTP issuer", + }, + "algorithm": map[string]interface{}{ + "type": "string", + "description": "OTP algorithm", + }, + "digits": map[string]interface{}{ + "type": "integer", + "description": "OTP digits", + }, + "period": map[string]interface{}{ + "type": "integer", + "description": "OTP period in seconds", + }, + "created_at": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "Creation time", + }, + "updated_at": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "Last update time", + }, + }, + }, + "OTPCodeResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "string", + "description": "OTP code", + }, + "expires_in": map[string]interface{}{ + "type": "integer", + "description": "Seconds until expiration", + }, + }, + }, + "VerifyOTPRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "string", + "description": "OTP code to verify", + }, + }, + "required": []string{"code"}, + }, + "VerifyOTPResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "valid": map[string]interface{}{ + "type": "boolean", + "description": "Whether the code is valid", + }, + }, + }, + "UpdateOTPRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "OTP name", + }, + "issuer": map[string]interface{}{ + "type": "string", + "description": "OTP issuer", + }, + "algorithm": map[string]interface{}{ + "type": "string", + "description": "OTP algorithm", + "enum": []string{"SHA1", "SHA256", "SHA512"}, + }, + "digits": map[string]interface{}{ + "type": "integer", + "description": "OTP digits", + "enum": []int{6, 8}, + }, + "period": map[string]interface{}{ + "type": "integer", + "description": "OTP period in seconds", + "enum": []int{30, 60}, + }, + }, + }, + "ErrorResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "integer", + "description": "Error code", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "Error message", + }, + }, + }, + }, + "securityDefinitions": map[string]interface{}{ + "Bearer": map[string]interface{}{ + "type": "apiKey", + "name": "Authorization", + "in": "header", + "description": "JWT token with Bearer prefix", + }, + }, + } + + data, _ := json.MarshalIndent(swagger, "", " ") + return data +} + +// getPaths returns the API paths +func getPaths() map[string]interface{} { + return map[string]interface{}{ + "/login": map[string]interface{}{ + "post": map[string]interface{}{ + "summary": "Login with WeChat", + "description": "Login with WeChat authorization code", + "tags": []string{"auth"}, + "parameters": []map[string]interface{}{ + { + "name": "body", + "in": "body", + "description": "Login request", + "required": true, + "schema": map[string]interface{}{ + "$ref": "#/definitions/LoginRequest", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Successful login", + "schema": map[string]interface{}{ + "$ref": "#/definitions/LoginResponse", + }, + }, + "400": map[string]interface{}{ + "description": "Invalid request", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/verify-token": map[string]interface{}{ + "post": map[string]interface{}{ + "summary": "Verify token", + "description": "Verify JWT token", + "tags": []string{"auth"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Token is valid", + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "valid": map[string]interface{}{ + "type": "boolean", + "description": "Whether the token is valid", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/otp": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "List OTPs", + "description": "List all OTPs for the authenticated user", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "List of OTPs", + "schema": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "$ref": "#/definitions/OTP", + }, + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + "post": map[string]interface{}{ + "summary": "Create OTP", + "description": "Create a new OTP", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "body", + "in": "body", + "description": "OTP creation request", + "required": true, + "schema": map[string]interface{}{ + "$ref": "#/definitions/CreateOTPRequest", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP created", + "schema": map[string]interface{}{ + "$ref": "#/definitions/OTP", + }, + }, + "400": map[string]interface{}{ + "description": "Invalid request", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/otp/{id}": map[string]interface{}{ + "put": map[string]interface{}{ + "summary": "Update OTP", + "description": "Update an existing OTP", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "description": "OTP ID", + "required": true, + "type": "string", + }, + { + "name": "body", + "in": "body", + "description": "OTP update request", + "required": true, + "schema": map[string]interface{}{ + "$ref": "#/definitions/UpdateOTPRequest", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP updated", + "schema": map[string]interface{}{ + "$ref": "#/definitions/OTP", + }, + }, + "400": map[string]interface{}{ + "description": "Invalid request", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "404": map[string]interface{}{ + "description": "OTP not found", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + "delete": map[string]interface{}{ + "summary": "Delete OTP", + "description": "Delete an existing OTP", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "description": "OTP ID", + "required": true, + "type": "string", + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP deleted", + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "Success message", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "404": map[string]interface{}{ + "description": "OTP not found", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/otp/{id}/code": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Get OTP code", + "description": "Get the current OTP code", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "description": "OTP ID", + "required": true, + "type": "string", + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP code", + "schema": map[string]interface{}{ + "$ref": "#/definitions/OTPCodeResponse", + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "404": map[string]interface{}{ + "description": "OTP not found", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/otp/{id}/verify": map[string]interface{}{ + "post": map[string]interface{}{ + "summary": "Verify OTP code", + "description": "Verify an OTP code", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "description": "OTP ID", + "required": true, + "type": "string", + }, + { + "name": "body", + "in": "body", + "description": "OTP verification request", + "required": true, + "schema": map[string]interface{}{ + "$ref": "#/definitions/VerifyOTPRequest", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP verification result", + "schema": map[string]interface{}{ + "$ref": "#/definitions/VerifyOTPResponse", + }, + }, + "400": map[string]interface{}{ + "description": "Invalid request", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "404": map[string]interface{}{ + "description": "OTP not found", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/health": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Health check", + "description": "Check if the API is healthy", + "tags": []string{"system"}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "API is healthy", + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "status": map[string]interface{}{ + "type": "string", + "description": "Health status", + }, + "time": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "Current time", + }, + }, + }, + }, + }, + }, + }, + "/metrics": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Metrics", + "description": "Get application metrics", + "tags": []string{"system"}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Application metrics", + }, + }, + }, + }, + "/swagger.json": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Swagger JSON", + "description": "Get Swagger JSON", + "tags": []string{"system"}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Swagger JSON", + }, + }, + }, + }, + } +} + +// Handler returns an HTTP handler for Swagger JSON +func Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(SwaggerJSON()) + } +} diff --git a/go.mod b/go.mod index b0da37e..356e98a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module otpm -go 1.21.1 +go 1.23.0 + +toolchain go1.23.9 require ( github.com/go-sql-driver/mysql v1.8.1 @@ -14,18 +16,30 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.8 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.26.0 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/prometheus/client_golang v1.22.0 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect @@ -36,9 +50,12 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect + golang.org/x/crypto v0.38.0 // indirect golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/net v0.34.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect + google.golang.org/protobuf v1.36.5 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect diff --git a/go.sum b/go.sum index 1c01a15..7c2fd63 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -11,12 +15,21 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= +github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= +github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -35,6 +48,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= @@ -45,6 +60,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= @@ -52,6 +69,14 @@ github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= @@ -89,17 +114,28 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w= golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= +google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/handlers/auth_handler.go b/handlers/auth_handler.go new file mode 100644 index 0000000..0c7a286 --- /dev/null +++ b/handlers/auth_handler.go @@ -0,0 +1,147 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "time" + + "otpm/api" + "otpm/services" + + "github.com/golang-jwt/jwt" +) + +// AuthHandler handles authentication related requests +type AuthHandler struct { + authService *services.AuthService +} + +// NewAuthHandler creates a new AuthHandler +func NewAuthHandler(authService *services.AuthService) *AuthHandler { + return &AuthHandler{ + authService: authService, + } +} + +// LoginRequest represents a login request +type LoginRequest struct { + Code string `json:"code"` +} + +// LoginResponse represents a login response +type LoginResponse struct { + Token string `json:"token"` + OpenID string `json:"openid"` +} + +// Login handles WeChat login +func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Limit request body size to prevent DOS + r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request + + // Parse request + var req LoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + fmt.Sprintf("Invalid request body: %v", err)) + log.Printf("Login request parse error: %v", err) + return + } + + // Validate request + if req.Code == "" { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Code is required") + log.Printf("Login request validation failed: empty code") + return + } + + // Login with WeChat code + token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code) + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + log.Printf("Login failed for code %s: %v", req.Code, err) + return + } + + // Log successful login + log.Printf("Login successful for code %s (took %v)", + req.Code, time.Since(start)) + + // Return token + api.NewResponseWriter(w).WriteSuccess(LoginResponse{ + Token: token, + }) +} + +// VerifyToken handles token verification +func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Get token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Authorization header is required") + log.Printf("Token verification failed: missing Authorization header") + return + } + + // Validate token format + if len(authHeader) < 7 || authHeader[:7] != "Bearer " { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Invalid token format. Expected 'Bearer '") + log.Printf("Token verification failed: invalid token format") + return + } + + token := authHeader[7:] + if len(token) < 32 { // Basic length check + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Invalid token length") + log.Printf("Token verification failed: token too short") + return + } + + // Validate token + claims, err := h.authService.ValidateToken(token) + if err != nil { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + log.Printf("Token verification failed for token %s: %v", + maskToken(token), err) // Mask token in logs + return + } + + // Log successful verification + userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string) + if !ok { + log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start)) + } else { + log.Printf("Token verified for user %s (took %v)", userID, time.Since(start)) + } + + // Token is valid + api.NewResponseWriter(w).WriteSuccess(map[string]bool{ + "valid": true, + }) +} + +// maskToken masks sensitive parts of token for logging +func maskToken(token string) string { + if len(token) < 8 { + return "****" + } + return token[:4] + "****" + token[len(token)-4:] +} + +// Routes returns all routes for the auth handler +func (h *AuthHandler) Routes() map[string]http.HandlerFunc { + return map[string]http.HandlerFunc{ + "/login": h.Login, + "/verify-token": h.VerifyToken, + } +} diff --git a/handlers/handler.go b/handlers/handler.go deleted file mode 100644 index 96e97c0..0000000 --- a/handlers/handler.go +++ /dev/null @@ -1,27 +0,0 @@ -package handlers - -import ( - "encoding/json" - "net/http" - - "github.com/jmoiron/sqlx" -) - -type Handler struct { - DB *sqlx.DB -} - -type Response struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` -} - -func WriteJSON(w http.ResponseWriter, data interface{}, code int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - json.NewEncoder(w).Encode(data) -} -func WriteError(w http.ResponseWriter, message string, code int) { - WriteJSON(w, Response{Code: code, Message: message}, code) -} diff --git a/handlers/login.go b/handlers/login.go deleted file mode 100644 index f31c1c9..0000000 --- a/handlers/login.go +++ /dev/null @@ -1,158 +0,0 @@ -package handlers - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "time" - - "github.com/golang-jwt/jwt" - "github.com/spf13/viper" -) - -type LoginRequest struct { - Code string `json:"code"` -} - -// 封装code2session接口返回数据 -type LoginResponse struct { - OpenId string `json:"openid"` - SessionKey string `json:"session_key"` - UnionId string `json:"unionid"` - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` -} - -var wxClient = &http.Client{ - Timeout: 10 * time.Second, - Transport: &http.Transport{ - MaxIdleConnsPerHost: 10, - }, -} - -func getLoginResponse(code string) (*LoginResponse, error) { - appid := viper.GetString("wechat.appid") - secret := viper.GetString("wechat.secret") - url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", appid, secret, code) - resp, err := wxClient.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var loginResponse LoginResponse - if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil { - return nil, err - } - - if loginResponse.ErrCode != 0 { - switch loginResponse.ErrCode { - case 40029: - return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg) - case 45011: - return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg) - default: - return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg) - } - } - return &loginResponse, nil -} - -func generateJWT(openid string) (string, error) { - tokenTTL := viper.GetDuration("auth.ttl") - if tokenTTL <= 0 { - tokenTTL = 24 * time.Hour - } - - secret := viper.GetString("auth.secret") - if secret == "" { - secret = "default_auth_secret_otpm" - } - - claims := jwt.MapClaims{ - "openid": openid, - "exp": time.Now().Add(tokenTTL).Unix(), - "iat": time.Now().Unix(), - "iss": viper.GetString("server.name"), - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - signedToken, err := token.SignedString([]byte(secret)) - if err != nil { - return "", err - } - return signedToken, nil -} - -func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { - var req LoginRequest - body, err := io.ReadAll(r.Body) - if err != nil { - WriteError(w, "Failed to read request body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - if err := json.Unmarshal(body, &req); err != nil { - WriteError(w, "Failed to parse request body", http.StatusBadRequest) - return - } - - loginResponse, err := getLoginResponse(req.Code) - if err != nil { - switch { - case err.Error() == "invalid code": - WriteError(w, "Invalid code", http.StatusUnauthorized) - case err.Error() == "api limit exceeded": - WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests) - default: - WriteError(w, "Failed to get login response", http.StatusInternalServerError) - } - return - } - - // 插入或更新用户的openid和session_key - query := ` - INSERT INTO users (openid, session_key) - VALUES ($1, $2) - ON CONFLICT (openid) DO UPDATE SET session_key = $2 - RETURNING id; - ` - - var ID int - if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil { - WriteError(w, "Failed to log in user", http.StatusInternalServerError) - return - } - - token, err := generateJWT(loginResponse.OpenId) - if err != nil { - WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError) - return - } - - data := map[string]interface{}{ - "t": token, - "openid": loginResponse.OpenId, - } - - WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK) -} - -func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) { - userid := r.Context().Value("openid").(string) - - token, err := generateJWT(userid) - if err != nil { - WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError) - return - } - WriteJSON(w, Response{ - Code: 0, - Message: "Token refreshed successfully", - Data: map[string]string{ - "token": token, - }, - }, http.StatusOK) -} diff --git a/handlers/otp.go b/handlers/otp.go deleted file mode 100644 index c8a81cd..0000000 --- a/handlers/otp.go +++ /dev/null @@ -1,70 +0,0 @@ -package handlers - -import ( - "encoding/json" - "log" - "net/http" -) - -type OtpRequest struct { - OpenID string `json:"openid"` - Token *[]OTP `json:"token"` -} -type OTP struct { - Issuer string `json:"issuer"` - Remark string `json:"remark"` - Secret string `json:"secret"` -} - -func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) { - var req OtpRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - WriteError(w, "Failed to parse request body", http.StatusBadRequest) - return - } - - if req.OpenID == "" { - WriteError(w, "OpenID is required", http.StatusBadRequest) - return - } - - if req.Token == nil || len(*req.Token) == 0 { - WriteError(w, "Token is required", http.StatusBadRequest) - return - } - - log.Printf("Saving OTP for user: %s token count:: %d", req.OpenID, len(*req.Token)) - - // 插入或更新 OTP 记录 - query := ` - INSERT INTO otp (openid, token) - VALUES ($1, $2) - ON CONFLICT (openid) DO UPDATE SET token = EXCLUDED.token - ` - - _, err := h.DB.Exec(query, req.OpenID, req.Token) - if err != nil { - WriteError(w, "Failed to update or create OTP", http.StatusInternalServerError) - return - } - - WriteJSON(w, Response{Code: 0, Message: "Success"}, http.StatusOK) -} - -func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) { - openid := r.URL.Query().Get("openid") - if openid == "" { - WriteError(w, "OpenID is required", http.StatusBadRequest) - return - } - - var otp OtpRequest - - err := h.DB.Get(&otp, "SELECT token FROM otp WHERE openid=$1", openid) - if err != nil { - WriteError(w, "Failed to get OTP", http.StatusInternalServerError) - return - } - - WriteJSON(w, Response{Code: 0, Message: "Success", Data: otp}, http.StatusOK) -} diff --git a/handlers/otp_handler.go b/handlers/otp_handler.go new file mode 100644 index 0000000..4361d50 --- /dev/null +++ b/handlers/otp_handler.go @@ -0,0 +1,286 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + "time" + + "otpm/api" + "otpm/middleware" + "otpm/models" + "otpm/services" +) + +// OTPHandler handles OTP related requests +type OTPHandler struct { + otpService *services.OTPService +} + +// NewOTPHandler creates a new OTPHandler +func NewOTPHandler(otpService *services.OTPService) *OTPHandler { + return &OTPHandler{ + otpService: otpService, + } +} + +// CreateOTPRequest represents a request to create an OTP +type CreateOTPRequest struct { + Name string `json:"name"` + Issuer string `json:"issuer"` + Secret string `json:"secret"` + Algorithm string `json:"algorithm"` + Digits int `json:"digits"` + Period int `json:"period"` +} + +// CreateOTP handles OTP creation +func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Limit request body size + r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation + + // Get user ID from context + userID, err := middleware.GetUserID(r) + if err != nil { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + log.Printf("CreateOTP unauthorized attempt") + return + } + + // Parse request + var req CreateOTPRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + fmt.Sprintf("Invalid request body: %v", err)) + log.Printf("CreateOTP request parse error for user %s: %v", userID, err) + return + } + + // Validate OTP parameters + if req.Secret == "" { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Secret is required") + log.Printf("CreateOTP validation failed for user %s: empty secret", userID) + return + } + + // Validate algorithm + supportedAlgos := map[string]bool{ + "SHA1": true, + "SHA256": true, + "SHA512": true, + } + if !supportedAlgos[strings.ToUpper(req.Algorithm)] { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Unsupported algorithm. Supported: SHA1, SHA256, SHA512") + log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s", + userID, req.Algorithm) + return + } + + // Create OTP + otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{ + Name: req.Name, + Issuer: req.Issuer, + Secret: req.Secret, + Algorithm: req.Algorithm, + Digits: req.Digits, + Period: req.Period, + }) + + if err != nil { + api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error())) + log.Printf("CreateOTP failed for user %s: %v", userID, err) + return + } + + // Log successful creation (mask secret in logs) + log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d", + userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period) + + api.NewResponseWriter(w).WriteSuccess(otp) +} + +// ListOTPs handles listing all OTPs for a user +func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) { + // Get user ID from context + userID, err := middleware.GetUserID(r) + if err != nil { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + return + } + + // Get OTPs + otps, err := h.otpService.ListOTPs(r.Context(), userID) + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + return + } + + api.NewResponseWriter(w).WriteSuccess(otps) +} + +// GetOTPCode handles generating OTP code +func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Get user ID from context + userID, err := middleware.GetUserID(r) + if err != nil { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr) + return + } + + // Get OTP ID from URL + otpID := strings.TrimPrefix(r.URL.Path, "/otp/") + otpID = strings.TrimSuffix(otpID, "/code") + + // Validate OTP ID format + if len(otpID) != 36 { // Assuming UUID format + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Invalid OTP ID format") + log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID) + return + } + + // Rate limiting check could be added here + // (would require redis or similar rate limiter) + + // Generate code + code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID) + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err) + return + } + + // Log successful generation (without actual code) + log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)", + userID, otpID, time.Since(start), expiresIn) + + api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{ + "code": code, + "expires_in": expiresIn, + }) +} + +// VerifyOTPRequest represents a request to verify an OTP code +type VerifyOTPRequest struct { + Code string `json:"code"` +} + +// VerifyOTP handles OTP code verification +func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) { + // Get user ID from context + userID, err := middleware.GetUserID(r) + if err != nil { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + return + } + + // Get OTP ID from URL + otpID := strings.TrimPrefix(r.URL.Path, "/otp/") + otpID = strings.TrimSuffix(otpID, "/verify") + + // Parse request + var req VerifyOTPRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body") + return + } + + // Verify code + valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code) + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + return + } + + api.NewResponseWriter(w).WriteSuccess(map[string]bool{ + "valid": valid, + }) +} + +// UpdateOTPRequest represents a request to update an OTP +type UpdateOTPRequest struct { + Name string `json:"name"` + Issuer string `json:"issuer"` + Algorithm string `json:"algorithm"` + Digits int `json:"digits"` + Period int `json:"period"` +} + +// UpdateOTP handles OTP update +func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) { + // Get user ID from context + userID, err := middleware.GetUserID(r) + if err != nil { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + return + } + + // Get OTP ID from URL + otpID := strings.TrimPrefix(r.URL.Path, "/otp/") + + // Parse request + var req UpdateOTPRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body") + return + } + + // Update OTP + otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{ + Name: req.Name, + Issuer: req.Issuer, + Algorithm: req.Algorithm, + Digits: req.Digits, + Period: req.Period, + }) + + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + return + } + + api.NewResponseWriter(w).WriteSuccess(otp) +} + +// DeleteOTP handles OTP deletion +func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) { + // Get user ID from context + userID, err := middleware.GetUserID(r) + if err != nil { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + return + } + + // Get OTP ID from URL + otpID := strings.TrimPrefix(r.URL.Path, "/otp/") + + // Delete OTP + if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + return + } + + api.NewResponseWriter(w).WriteSuccess(map[string]string{ + "message": "OTP deleted successfully", + }) +} + +// Routes returns all routes for the OTP handler +func (h *OTPHandler) Routes() map[string]http.HandlerFunc { + return map[string]http.HandlerFunc{ + "/otp": h.CreateOTP, + "/otp/": h.ListOTPs, + "/otp/{id}": h.UpdateOTP, + "/otp/{id}/code": h.GetOTPCode, + "/otp/{id}/verify": h.VerifyOTP, + } +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..d7d76e4 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,204 @@ +package logger + +import ( + "context" + "fmt" + "io" + "os" + "runtime" + "strings" + "time" + + "github.com/google/uuid" +) + +// Level represents a log level +type Level int + +const ( + // DEBUG level + DEBUG Level = iota + // INFO level + INFO + // WARN level + WARN + // ERROR level + ERROR + // FATAL level + FATAL +) + +// String returns the string representation of the log level +func (l Level) String() string { + switch l { + case DEBUG: + return "DEBUG" + case INFO: + return "INFO" + case WARN: + return "WARN" + case ERROR: + return "ERROR" + case FATAL: + return "FATAL" + default: + return "UNKNOWN" + } +} + +// Logger represents a logger +type Logger struct { + level Level + output io.Writer +} + +// contextKey is a type for context keys +type contextKey string + +// requestIDKey is the key for request ID in context +const requestIDKey = contextKey("request_id") + +// New creates a new logger +func New(level Level, output io.Writer) *Logger { + if output == nil { + output = os.Stdout + } + return &Logger{ + level: level, + output: output, + } +} + +// WithLevel creates a new logger with the specified level +func (l *Logger) WithLevel(level Level) *Logger { + return &Logger{ + level: level, + output: l.output, + } +} + +// WithOutput creates a new logger with the specified output +func (l *Logger) WithOutput(output io.Writer) *Logger { + return &Logger{ + level: l.level, + output: output, + } +} + +// log logs a message with the specified level +func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) { + if level < l.level { + return + } + + // Get request ID from context + requestID := getRequestID(ctx) + + // Get caller information + _, file, line, ok := runtime.Caller(2) + if !ok { + file = "unknown" + line = 0 + } + // Extract just the filename + if idx := strings.LastIndex(file, "/"); idx >= 0 { + file = file[idx+1:] + } + + // Format message + message := fmt.Sprintf(format, args...) + + // Format log entry + timestamp := time.Now().Format(time.RFC3339) + logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n", + timestamp, level.String(), file, line, requestID, message) + + // Write log entry + _, _ = l.output.Write([]byte(logEntry)) + + // Exit if fatal + if level == FATAL { + os.Exit(1) + } +} + +// Debug logs a debug message +func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, DEBUG, format, args...) +} + +// Info logs an info message +func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, INFO, format, args...) +} + +// Warn logs a warning message +func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, WARN, format, args...) +} + +// Error logs an error message +func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, ERROR, format, args...) +} + +// Fatal logs a fatal message and exits +func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, FATAL, format, args...) +} + +// WithRequestID adds a request ID to the context +func WithRequestID(ctx context.Context) context.Context { + requestID := uuid.New().String() + return context.WithValue(ctx, requestIDKey, requestID) +} + +// GetRequestID gets the request ID from the context +func GetRequestID(ctx context.Context) string { + return getRequestID(ctx) +} + +// getRequestID gets the request ID from the context +func getRequestID(ctx context.Context) string { + if ctx == nil { + return "-" + } + requestID, ok := ctx.Value(requestIDKey).(string) + if !ok { + return "-" + } + return requestID +} + +// Default logger +var defaultLogger = New(INFO, os.Stdout) + +// SetDefaultLogger sets the default logger +func SetDefaultLogger(logger *Logger) { + defaultLogger = logger +} + +// Debug logs a debug message using the default logger +func Debug(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Debug(ctx, format, args...) +} + +// Info logs an info message using the default logger +func Info(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Info(ctx, format, args...) +} + +// Warn logs a warning message using the default logger +func Warn(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Warn(ctx, format, args...) +} + +// Error logs an error message using the default logger +func Error(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Error(ctx, format, args...) +} + +// Fatal logs a fatal message and exits using the default logger +func Fatal(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Fatal(ctx, format, args...) +} diff --git a/metrics/metrics.go b/metrics/metrics.go new file mode 100644 index 0000000..8e1d18a --- /dev/null +++ b/metrics/metrics.go @@ -0,0 +1,193 @@ +package metrics + +import ( + "fmt" + "net/http" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +var ( + // Default metrics + requestDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_request_duration_seconds", + Help: "Duration of HTTP requests in seconds", + Buckets: prometheus.DefBuckets, + }, + []string{"method", "path", "status"}, + ) + + requestTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "http_requests_total", + Help: "Total number of HTTP requests", + }, + []string{"method", "path", "status"}, + ) + + otpGenerationTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "otp_generation_total", + Help: "Total number of OTP generations", + }, + []string{"user_id", "otp_id"}, + ) + + otpVerificationTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "otp_verification_total", + Help: "Total number of OTP verifications", + }, + []string{"user_id", "otp_id", "success"}, + ) + + activeUsers = prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "active_users", + Help: "Number of active users", + }, + ) + + cacheHits = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cache_hits_total", + Help: "Total number of cache hits", + }, + []string{"cache"}, + ) + + cacheMisses = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cache_misses_total", + Help: "Total number of cache misses", + }, + []string{"cache"}, + ) +) + +func init() { + // Register metrics with prometheus + prometheus.MustRegister( + requestDuration, + requestTotal, + otpGenerationTotal, + otpVerificationTotal, + activeUsers, + cacheHits, + cacheMisses, + ) +} + +// MetricsService provides metrics functionality +type MetricsService struct { + activeUsersMutex sync.RWMutex + activeUserIDs map[string]bool +} + +// NewMetricsService creates a new MetricsService +func NewMetricsService() *MetricsService { + return &MetricsService{ + activeUserIDs: make(map[string]bool), + } +} + +// Handler returns an HTTP handler for metrics +func (s *MetricsService) Handler() http.Handler { + return promhttp.Handler() +} + +// RecordRequest records metrics for an HTTP request +func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) { + labels := prometheus.Labels{ + "method": method, + "path": path, + "status": fmt.Sprintf("%d", status), + } + + requestDuration.With(labels).Observe(duration.Seconds()) + requestTotal.With(labels).Inc() +} + +// RecordOTPGeneration records metrics for OTP generation +func (s *MetricsService) RecordOTPGeneration(userID, otpID string) { + otpGenerationTotal.With(prometheus.Labels{ + "user_id": userID, + "otp_id": otpID, + }).Inc() +} + +// RecordOTPVerification records metrics for OTP verification +func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) { + otpVerificationTotal.With(prometheus.Labels{ + "user_id": userID, + "otp_id": otpID, + "success": fmt.Sprintf("%t", success), + }).Inc() +} + +// RecordUserActivity records user activity +func (s *MetricsService) RecordUserActivity(userID string) { + s.activeUsersMutex.Lock() + defer s.activeUsersMutex.Unlock() + + if !s.activeUserIDs[userID] { + s.activeUserIDs[userID] = true + activeUsers.Inc() + } +} + +// RecordUserInactivity records user inactivity +func (s *MetricsService) RecordUserInactivity(userID string) { + s.activeUsersMutex.Lock() + defer s.activeUsersMutex.Unlock() + + if s.activeUserIDs[userID] { + delete(s.activeUserIDs, userID) + activeUsers.Dec() + } +} + +// RecordCacheHit records a cache hit +func (s *MetricsService) RecordCacheHit(cache string) { + cacheHits.With(prometheus.Labels{ + "cache": cache, + }).Inc() +} + +// RecordCacheMiss records a cache miss +func (s *MetricsService) RecordCacheMiss(cache string) { + cacheMisses.With(prometheus.Labels{ + "cache": cache, + }).Inc() +} + +// Middleware creates a middleware that records request metrics +func (s *MetricsService) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Create response writer that captures status code + rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} + + // Call next handler + next.ServeHTTP(rw, r) + + // Record metrics + s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start)) + }) +} + +// responseWriter wraps http.ResponseWriter to capture status code +type responseWriter struct { + http.ResponseWriter + status int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.status = code + rw.ResponseWriter.WriteHeader(code) +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..6afce13 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,353 @@ +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "log" + "math/rand" + "net/http" + "runtime/debug" + "strings" + "time" + + "github.com/golang-jwt/jwt" +) + +// Response represents a standard API response +type Response struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// ErrorResponse sends a JSON error response +func ErrorResponse(w http.ResponseWriter, code int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(Response{ + Code: code, + Message: message, + }) +} + +// SuccessResponse sends a JSON success response +func SuccessResponse(w http.ResponseWriter, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(Response{ + Code: http.StatusOK, + Message: "success", + Data: data, + }) +} + +// Logger is a middleware that logs request details with structured format +func Logger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + requestID := r.Header.Get("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + r.Header.Set("X-Request-ID", requestID) + } + + // Create a custom response writer to capture status code + rw := &responseWriter{ + ResponseWriter: w, + status: http.StatusOK, + } + + // Process request + next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID))) + + // Log structured request details + log.Printf( + "method=%s path=%s status=%d duration=%s ip=%s request_id=%s", + r.Method, + r.URL.Path, + rw.status, + time.Since(start).String(), + r.RemoteAddr, + requestID, + ) + }) +} + +// generateRequestID creates a unique request identifier +func generateRequestID() string { + return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8)) +} + +// randomString generates a random string of given length +func randomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, length) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] + } + return string(b) +} + +// Recover is a middleware that recovers from panics with detailed logging +func Recover(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + // Get request ID from context + requestID := "" + if ctx := r.Context(); ctx != nil { + if id, ok := ctx.Value("request_id").(string); ok { + requestID = id + } + } + + // Log error with stack trace and request context + log.Printf( + "panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s", + err, + requestID, + r.Method, + r.URL.Path, + r.RemoteAddr, + debug.Stack(), + ) + + // Determine error type + var message string + var status int + + switch e := err.(type) { + case error: + message = e.Error() + if isClientError(e) { + status = http.StatusBadRequest + } else { + status = http.StatusInternalServerError + } + case string: + message = e + status = http.StatusInternalServerError + default: + message = "Internal Server Error" + status = http.StatusInternalServerError + } + + ErrorResponse(w, status, message) + } + }() + next.ServeHTTP(w, r) + }) +} + +// isClientError checks if error should be treated as client error +func isClientError(err error) bool { + // Add more client error types as needed + return strings.Contains(err.Error(), "validation") || + strings.Contains(err.Error(), "invalid") || + strings.Contains(err.Error(), "missing") +} + +// CORS is a middleware that handles CORS +func CORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +} + +// Timeout is a middleware that safely handles request timeouts +func Timeout(duration time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), duration) + defer cancel() + + // Use buffered channels to prevent goroutine leaks + done := make(chan struct{}, 1) + panicChan := make(chan interface{}, 1) + + // Track request processing in goroutine + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + next.ServeHTTP(w, r.WithContext(ctx)) + done <- struct{}{} + }() + + // Wait for completion, timeout or panic + select { + case <-done: + return + case p := <-panicChan: + panic(p) // Re-throw panic to be caught by Recover middleware + case <-ctx.Done(): + // Get request context for logging + requestID := "" + if ctx := r.Context(); ctx != nil { + if id, ok := ctx.Value("request_id").(string); ok { + requestID = id + } + } + + // Log timeout details + log.Printf( + "request_timeout: request_id=%s method=%s path=%s timeout=%s", + requestID, + r.Method, + r.URL.Path, + duration.String(), + ) + + // Send timeout response + ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf( + "Request timed out after %s", duration.String(), + )) + } + }) + } +} + +// Auth is a middleware that validates JWT tokens with enhanced security +func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get request ID for logging + requestID := "" + if ctx := r.Context(); ctx != nil { + if id, ok := ctx.Value("request_id").(string); ok { + requestID = id + } + } + + // Get token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required") + return + } + + // Validate header format + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer '") + return + } + + tokenString := parts[1] + + // Parse and validate token + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Validate signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(jwtSecret), nil + }) + + if err != nil { + log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err) + ErrorResponse(w, http.StatusUnauthorized, "Invalid token") + return + } + + if !token.Valid { + log.Printf("auth_failed: request_id=%s error=invalid_token", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Invalid token") + return + } + + // Validate claims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims") + return + } + + // Check required claims + userID, ok := claims["user_id"].(string) + if !ok || userID == "" { + log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token") + return + } + + // Check token expiration + if exp, ok := claims["exp"].(float64); ok { + if time.Now().Unix() > int64(exp) { + log.Printf("auth_failed: request_id=%s error=token_expired", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Token has expired") + return + } + } + + // Check required roles if specified + if len(requiredRoles) > 0 { + roles, ok := claims["roles"].([]interface{}) + if !ok { + log.Printf("auth_failed: request_id=%s error=missing_roles", requestID) + ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles") + return + } + + hasRequiredRole := false + for _, requiredRole := range requiredRoles { + for _, role := range roles { + if r, ok := role.(string); ok && r == requiredRole { + hasRequiredRole = true + break + } + } + } + + if !hasRequiredRole { + log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID) + ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions") + return + } + } + + // Add claims to context + ctx := r.Context() + ctx = context.WithValue(ctx, "user_id", userID) + ctx = context.WithValue(ctx, "claims", claims) + + log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// responseWriter is a custom response writer that captures the status code +type responseWriter struct { + http.ResponseWriter + status int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.status = code + rw.ResponseWriter.WriteHeader(code) +} + +// GetUserID gets the user ID from the request context +func GetUserID(r *http.Request) (string, error) { + userID, ok := r.Context().Value("user_id").(string) + if !ok { + return "", fmt.Errorf("user ID not found in context") + } + return userID, nil +} diff --git a/miniprogram-example/app.js b/miniprogram-example/app.js new file mode 100644 index 0000000..34eba08 --- /dev/null +++ b/miniprogram-example/app.js @@ -0,0 +1,50 @@ +// app.js +App({ + onLaunch() { + // 检查更新 + if (wx.canIUse('getUpdateManager')) { + const updateManager = wx.getUpdateManager(); + updateManager.onCheckForUpdate(function (res) { + if (res.hasUpdate) { + updateManager.onUpdateReady(function () { + wx.showModal({ + title: '更新提示', + content: '新版本已经准备好,是否重启应用?', + success: function (res) { + if (res.confirm) { + updateManager.applyUpdate(); + } + } + }); + }); + + updateManager.onUpdateFailed(function () { + wx.showModal({ + title: '更新提示', + content: '新版本下载失败,请检查网络后重试', + showCancel: false + }); + }); + } + }); + } + + // 获取系统信息 + try { + const systemInfo = wx.getSystemInfoSync(); + this.globalData.systemInfo = systemInfo; + + // 计算安全区域 + const { screenHeight, safeArea } = systemInfo; + this.globalData.safeAreaBottom = screenHeight - safeArea.bottom; + } catch (e) { + console.error('获取系统信息失败', e); + } + }, + + globalData: { + userInfo: null, + systemInfo: {}, + safeAreaBottom: 0 + } +}); \ No newline at end of file diff --git a/miniprogram-example/app.json b/miniprogram-example/app.json new file mode 100644 index 0000000..89f79ae --- /dev/null +++ b/miniprogram-example/app.json @@ -0,0 +1,23 @@ +{ + "pages": [ + "pages/login/login", + "pages/otp-list/index", + "pages/otp-add/index" + ], + "window": { + "backgroundTextStyle": "light", + "navigationBarBackgroundColor": "#fff", + "navigationBarTitleText": "OTPM", + "navigationBarTextStyle": "black", + "backgroundColor": "#F8F8F8" + }, + "permission": { + "scope.camera": { + "desc": "需要使用相机扫描二维码" + } + }, + "usingComponents": {}, + "style": "v2", + "sitemapLocation": "sitemap.json", + "lazyCodeLoading": "requiredComponents" +} \ No newline at end of file diff --git a/miniprogram-example/app.wxss b/miniprogram-example/app.wxss new file mode 100644 index 0000000..fb73b4a --- /dev/null +++ b/miniprogram-example/app.wxss @@ -0,0 +1,238 @@ +/**app.wxss**/ +page { + --primary-color: #1890ff; + --danger-color: #ff4d4f; + --success-color: #52c41a; + --warning-color: #faad14; + --text-color: #333333; + --text-color-secondary: #666666; + --text-color-light: #999999; + --border-color: #e8e8e8; + --background-color: #f8f8f8; + --border-radius: 8rpx; + --safe-area-bottom: env(safe-area-inset-bottom); + + font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica, + Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei', + sans-serif; + font-size: 28rpx; + line-height: 1.5; + color: var(--text-color); + background-color: var(--background-color); +} + +/* 清除默认样式 */ +button { + padding: 0; + margin: 0; + background: none; + border: none; + text-align: left; + line-height: inherit; + overflow: visible; +} + +button::after { + border: none; +} + +/* 通用样式类 */ +.container { + min-height: 100vh; + box-sizing: border-box; +} + +.safe-area-bottom { + padding-bottom: var(--safe-area-bottom); +} + +.flex-center { + display: flex; + align-items: center; + justify-content: center; +} + +.flex-between { + display: flex; + align-items: center; + justify-content: space-between; +} + +.flex-column { + display: flex; + flex-direction: column; +} + +.text-primary { + color: var(--primary-color); +} + +.text-danger { + color: var(--danger-color); +} + +.text-success { + color: var(--success-color); +} + +.text-warning { + color: var(--warning-color); +} + +.text-secondary { + color: var(--text-color-secondary); +} + +.text-light { + color: var(--text-color-light); +} + +.text-center { + text-align: center; +} + +.text-left { + text-align: left; +} + +.text-right { + text-align: right; +} + +.text-bold { + font-weight: bold; +} + +.text-ellipsis { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.bg-white { + background-color: #ffffff; +} + +.bg-primary { + background-color: var(--primary-color); +} + +.bg-danger { + background-color: var(--danger-color); +} + +.bg-success { + background-color: var(--success-color); +} + +.bg-warning { + background-color: var(--warning-color); +} + +.shadow { + box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); +} + +.rounded { + border-radius: var(--border-radius); +} + +.border { + border: 2rpx solid var(--border-color); +} + +.border-top { + border-top: 2rpx solid var(--border-color); +} + +.border-bottom { + border-bottom: 2rpx solid var(--border-color); +} + +/* 动画类 */ +.fade-in { + animation: fadeIn 0.3s ease-in-out; +} + +.fade-out { + animation: fadeOut 0.3s ease-in-out; +} + +@keyframes fadeIn { + from { + opacity: 0; + } + to { + opacity: 1; + } +} + +@keyframes fadeOut { + from { + opacity: 1; + } + to { + opacity: 0; + } +} + +/* 间距类 */ +.m-0 { margin: 0; } +.m-1 { margin: 10rpx; } +.m-2 { margin: 20rpx; } +.m-3 { margin: 30rpx; } +.m-4 { margin: 40rpx; } + +.mt-0 { margin-top: 0; } +.mt-1 { margin-top: 10rpx; } +.mt-2 { margin-top: 20rpx; } +.mt-3 { margin-top: 30rpx; } +.mt-4 { margin-top: 40rpx; } + +.mb-0 { margin-bottom: 0; } +.mb-1 { margin-bottom: 10rpx; } +.mb-2 { margin-bottom: 20rpx; } +.mb-3 { margin-bottom: 30rpx; } +.mb-4 { margin-bottom: 40rpx; } + +.ml-0 { margin-left: 0; } +.ml-1 { margin-left: 10rpx; } +.ml-2 { margin-left: 20rpx; } +.ml-3 { margin-left: 30rpx; } +.ml-4 { margin-left: 40rpx; } + +.mr-0 { margin-right: 0; } +.mr-1 { margin-right: 10rpx; } +.mr-2 { margin-right: 20rpx; } +.mr-3 { margin-right: 30rpx; } +.mr-4 { margin-right: 40rpx; } + +.p-0 { padding: 0; } +.p-1 { padding: 10rpx; } +.p-2 { padding: 20rpx; } +.p-3 { padding: 30rpx; } +.p-4 { padding: 40rpx; } + +.pt-0 { padding-top: 0; } +.pt-1 { padding-top: 10rpx; } +.pt-2 { padding-top: 20rpx; } +.pt-3 { padding-top: 30rpx; } +.pt-4 { padding-top: 40rpx; } + +.pb-0 { padding-bottom: 0; } +.pb-1 { padding-bottom: 10rpx; } +.pb-2 { padding-bottom: 20rpx; } +.pb-3 { padding-bottom: 30rpx; } +.pb-4 { padding-bottom: 40rpx; } + +.pl-0 { padding-left: 0; } +.pl-1 { padding-left: 10rpx; } +.pl-2 { padding-left: 20rpx; } +.pl-3 { padding-left: 30rpx; } +.pl-4 { padding-left: 40rpx; } + +.pr-0 { padding-right: 0; } +.pr-1 { padding-right: 10rpx; } +.pr-2 { padding-right: 20rpx; } +.pr-3 { padding-right: 30rpx; } +.pr-4 { padding-right: 40rpx; } \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.js b/miniprogram-example/pages/login/login.js new file mode 100644 index 0000000..e4ad173 --- /dev/null +++ b/miniprogram-example/pages/login/login.js @@ -0,0 +1,48 @@ +// login.js +import { wxLogin } from '../../services/auth'; + +Page({ + data: { + loading: false + }, + + onLoad() { + // 页面加载时检查是否已经登录 + const token = wx.getStorageSync('token'); + if (token) { + this.redirectToHome(); + } + }, + + // 处理登录按钮点击 + handleLogin() { + if (this.data.loading) return; + + this.setData({ loading: true }); + + wxLogin() + .then(() => { + wx.showToast({ + title: '登录成功', + icon: 'success' + }); + this.redirectToHome(); + }) + .catch(err => { + wx.showToast({ + title: err.message || '登录失败', + icon: 'none' + }); + }) + .finally(() => { + this.setData({ loading: false }); + }); + }, + + // 跳转到首页 + redirectToHome() { + wx.reLaunch({ + url: '/pages/otp-list/index' + }); + } +}); \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.json b/miniprogram-example/pages/login/login.json new file mode 100644 index 0000000..8835af0 --- /dev/null +++ b/miniprogram-example/pages/login/login.json @@ -0,0 +1,3 @@ +{ + "usingComponents": {} +} \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.wxml b/miniprogram-example/pages/login/login.wxml new file mode 100644 index 0000000..d73a711 --- /dev/null +++ b/miniprogram-example/pages/login/login.wxml @@ -0,0 +1,30 @@ + + + + + OTPM 小程序 + + + + \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.wxss b/miniprogram-example/pages/login/login.wxss new file mode 100644 index 0000000..26cece6 --- /dev/null +++ b/miniprogram-example/pages/login/login.wxss @@ -0,0 +1,97 @@ +/* login.wxss */ +.container { + display: flex; + flex-direction: column; + align-items: center; + justify-content: space-between; + height: 100vh; + padding: 60rpx 40rpx; + box-sizing: border-box; + background-color: #f8f8f8; +} + +.logo-container { + display: flex; + flex-direction: column; + align-items: center; + margin-top: 80rpx; +} + +.logo { + width: 180rpx; + height: 180rpx; + margin-bottom: 20rpx; +} + +.app-name { + font-size: 36rpx; + font-weight: bold; + color: #333; +} + +.login-container { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; + margin-bottom: 100rpx; +} + +.login-title { + font-size: 48rpx; + font-weight: bold; + color: #333; + margin-bottom: 20rpx; +} + +.login-subtitle { + font-size: 28rpx; + color: #666; + margin-bottom: 80rpx; +} + +.login-button { + width: 80%; + height: 88rpx; + border-radius: 44rpx; + font-size: 32rpx; + display: flex; + align-items: center; + justify-content: center; +} + +.login-button.loading { + background-color: #8cc4ff; +} + +.loading-container { + display: flex; + align-items: center; + justify-content: center; +} + +.loading-icon { + width: 36rpx; + height: 36rpx; + margin-right: 10rpx; + border: 4rpx solid #ffffff; + border-radius: 50%; + border-top-color: transparent; + animation: spin 1s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +.privacy-policy { + margin-top: 40rpx; + font-size: 24rpx; + color: #999; +} + +.policy-link { + color: #1890ff; + display: inline; +} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.js b/miniprogram-example/pages/otp-add/index.js new file mode 100644 index 0000000..e89c623 --- /dev/null +++ b/miniprogram-example/pages/otp-add/index.js @@ -0,0 +1,169 @@ +// otp-add/index.js +import { createOTP } from '../../services/otp'; + +Page({ + data: { + form: { + name: '', + issuer: '', + secret: '', + algorithm: 'SHA1', + digits: 6, + period: 30 + }, + algorithms: ['SHA1', 'SHA256', 'SHA512'], + digitOptions: [6, 8], + periodOptions: [30, 60], + submitting: false, + scanMode: false + }, + + // 处理输入变化 + handleInputChange(e) { + const { field } = e.currentTarget.dataset; + const { value } = e.detail; + + this.setData({ + [`form.${field}`]: value + }); + }, + + // 处理选择器变化 + handlePickerChange(e) { + const { field } = e.currentTarget.dataset; + const { value } = e.detail; + + const options = this.data[`${field}Options`] || this.data[field]; + const selectedValue = options[value]; + + this.setData({ + [`form.${field}`]: selectedValue + }); + }, + + // 扫描二维码 + handleScanQRCode() { + this.setData({ scanMode: true }); + + wx.scanCode({ + scanType: ['qrCode'], + success: (res) => { + try { + // 解析otpauth://协议的URL + const url = res.result; + if (url.startsWith('otpauth://totp/')) { + const parsedUrl = new URL(url); + const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠 + + // 解析路径中的issuer和name + let issuer = ''; + let name = path; + + if (path.includes(':')) { + const parts = path.split(':'); + issuer = parts[0]; + name = parts[1]; + } + + // 从查询参数中获取其他信息 + const secret = parsedUrl.searchParams.get('secret') || ''; + const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1'; + const digits = parseInt(parsedUrl.searchParams.get('digits') || '6'); + const period = parseInt(parsedUrl.searchParams.get('period') || '30'); + + // 如果查询参数中有issuer,优先使用 + if (parsedUrl.searchParams.get('issuer')) { + issuer = parsedUrl.searchParams.get('issuer'); + } + + this.setData({ + form: { + name, + issuer, + secret, + algorithm, + digits, + period + } + }); + + wx.showToast({ + title: '二维码解析成功', + icon: 'success' + }); + } else { + wx.showToast({ + title: '不支持的二维码格式', + icon: 'none' + }); + } + } catch (err) { + wx.showToast({ + title: '二维码解析失败', + icon: 'none' + }); + } + }, + fail: () => { + wx.showToast({ + title: '扫描取消', + icon: 'none' + }); + }, + complete: () => { + this.setData({ scanMode: false }); + } + }); + }, + + // 提交表单 + handleSubmit() { + const { form } = this.data; + + // 表单验证 + if (!form.name) { + wx.showToast({ + title: '请输入名称', + icon: 'none' + }); + return; + } + + if (!form.secret) { + wx.showToast({ + title: '请输入密钥', + icon: 'none' + }); + return; + } + + this.setData({ submitting: true }); + + createOTP(form) + .then(() => { + wx.showToast({ + title: '添加成功', + icon: 'success' + }); + + // 返回上一页 + setTimeout(() => { + wx.navigateBack(); + }, 1500); + }) + .catch(err => { + wx.showToast({ + title: err.message || '添加失败', + icon: 'none' + }); + }) + .finally(() => { + this.setData({ submitting: false }); + }); + }, + + // 取消 + handleCancel() { + wx.navigateBack(); + } +}); \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.json b/miniprogram-example/pages/otp-add/index.json new file mode 100644 index 0000000..8835af0 --- /dev/null +++ b/miniprogram-example/pages/otp-add/index.json @@ -0,0 +1,3 @@ +{ + "usingComponents": {} +} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.wxml b/miniprogram-example/pages/otp-add/index.wxml new file mode 100644 index 0000000..e5a2b01 --- /dev/null +++ b/miniprogram-example/pages/otp-add/index.wxml @@ -0,0 +1,119 @@ + + + + 添加OTP + + + + + 名称 * + + + + + 发行方 + + + + + 密钥 * + + + + 🔍 + + + + + + + + + 算法 + + + {{form.algorithm}} + + + + + + + + 位数 + + + {{form.digits}} + + + + + + + 周期(秒) + + + {{form.period}} + + + + + + + + + + + + + \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.wxss b/miniprogram-example/pages/otp-add/index.wxss new file mode 100644 index 0000000..3906a97 --- /dev/null +++ b/miniprogram-example/pages/otp-add/index.wxss @@ -0,0 +1,176 @@ +/* otp-add/index.wxss */ +.container { + min-height: 100vh; + background-color: #f8f8f8; + padding: 0 0 40rpx 0; +} + +.header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 40rpx 32rpx; + background-color: #ffffff; + position: sticky; + top: 0; + z-index: 100; + box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05); +} + +.title { + font-size: 36rpx; + font-weight: bold; + color: #333333; +} + +.form-container { + background-color: #ffffff; + padding: 32rpx; + margin: 32rpx; + border-radius: 16rpx; + box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); +} + +.form-group { + margin-bottom: 32rpx; +} + +.form-row { + display: flex; + justify-content: space-between; +} + +.form-group.half { + width: 48%; +} + +.form-label { + display: block; + font-size: 28rpx; + color: #666666; + margin-bottom: 12rpx; +} + +.required { + color: #ff4d4f; +} + +.form-input { + width: 100%; + height: 80rpx; + background-color: #f5f5f5; + border-radius: 8rpx; + padding: 0 24rpx; + font-size: 28rpx; + color: #333333; + box-sizing: border-box; +} + +.secret-input-container { + position: relative; +} + +.scan-button { + position: absolute; + right: 20rpx; + top: 50%; + transform: translateY(-50%); + width: 60rpx; + height: 60rpx; + display: flex; + align-items: center; + justify-content: center; +} + +.scan-icon { + font-size: 40rpx; + color: #1890ff; +} + +.scanning-indicator { + position: absolute; + right: 20rpx; + top: 50%; + transform: translateY(-50%); + width: 40rpx; + height: 40rpx; +} + +.scanning-spinner { + width: 40rpx; + height: 40rpx; + border: 4rpx solid #f3f3f3; + border-top: 4rpx solid #1890ff; + border-radius: 50%; + animation: spin 1s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +.picker-view { + width: 100%; + height: 80rpx; + background-color: #f5f5f5; + border-radius: 8rpx; + padding: 0 24rpx; + font-size: 28rpx; + color: #333333; + display: flex; + align-items: center; + justify-content: space-between; + box-sizing: border-box; +} + +.picker-arrow { + font-size: 24rpx; + color: #999999; +} + +.button-group { + display: flex; + justify-content: space-between; + padding: 32rpx; +} + +.cancel-button, .submit-button { + width: 48%; + height: 88rpx; + border-radius: 44rpx; + font-size: 32rpx; + display: flex; + align-items: center; + justify-content: center; +} + +.cancel-button { + background-color: #f5f5f5; + color: #666666; +} + +.submit-button { + background-color: #1890ff; + color: #ffffff; +} + +.submit-button.loading { + background-color: #8cc4ff; +} + +.loading-container { + display: flex; + align-items: center; + justify-content: center; +} + +.loading-icon { + width: 36rpx; + height: 36rpx; + margin-right: 10rpx; + border: 4rpx solid #ffffff; + border-radius: 50%; + border-top-color: transparent; + animation: spin 1s linear infinite; +} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.js b/miniprogram-example/pages/otp-list/index.js new file mode 100644 index 0000000..9ed6fa1 --- /dev/null +++ b/miniprogram-example/pages/otp-list/index.js @@ -0,0 +1,213 @@ +// otp-list/index.js +import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp'; +import { checkLoginStatus } from '../../services/auth'; + +Page({ + data: { + otpList: [], + loading: true, + refreshing: false + }, + + onLoad() { + this.checkLogin(); + }, + + onShow() { + // 每次页面显示时刷新OTP列表 + if (!this.data.loading) { + this.fetchOTPList(); + } + }, + + // 下拉刷新 + onPullDownRefresh() { + this.setData({ refreshing: true }); + this.fetchOTPList().finally(() => { + wx.stopPullDownRefresh(); + this.setData({ refreshing: false }); + }); + }, + + // 检查登录状态 + checkLogin() { + checkLoginStatus().then(isLoggedIn => { + if (isLoggedIn) { + this.fetchOTPList(); + } else { + wx.redirectTo({ + url: '/pages/login/login' + }); + } + }); + }, + + // 获取OTP列表 + fetchOTPList() { + this.setData({ loading: true }); + return getOTPList() + .then(res => { + if (res.data && Array.isArray(res.data)) { + this.setData({ + otpList: res.data, + loading: false + }); + + // 获取每个OTP的当前验证码 + this.refreshOTPCodes(); + } + }) + .catch(err => { + wx.showToast({ + title: '获取OTP列表失败', + icon: 'none' + }); + this.setData({ loading: false }); + }); + }, + + // 刷新所有OTP的验证码 + refreshOTPCodes() { + const { otpList } = this.data; + + // 为每个OTP获取当前验证码 + const promises = otpList.map(otp => { + return getOTPCode(otp.id) + .then(res => { + if (res.data && res.data.code) { + return { + id: otp.id, + code: res.data.code, + expiresIn: res.data.expires_in || 30 + }; + } + return null; + }) + .catch(() => null); + }); + + Promise.all(promises).then(results => { + const updatedList = [...this.data.otpList]; + + results.forEach(result => { + if (result) { + const index = updatedList.findIndex(otp => otp.id === result.id); + if (index !== -1) { + updatedList[index] = { + ...updatedList[index], + currentCode: result.code, + expiresIn: result.expiresIn + }; + } + } + }); + + this.setData({ otpList: updatedList }); + + // 设置定时器,每秒更新倒计时 + this.startCountdown(); + }); + }, + + // 开始倒计时 + startCountdown() { + // 清除之前的定时器 + if (this.countdownTimer) { + clearInterval(this.countdownTimer); + } + + // 创建新的定时器,每秒更新一次 + this.countdownTimer = setInterval(() => { + const { otpList } = this.data; + let needRefresh = false; + + const updatedList = otpList.map(otp => { + if (!otp.countdown) { + otp.countdown = otp.expiresIn || 30; + } + + otp.countdown -= 1; + + // 如果倒计时结束,标记需要刷新 + if (otp.countdown <= 0) { + needRefresh = true; + } + + return otp; + }); + + this.setData({ otpList: updatedList }); + + // 如果有OTP需要刷新,重新获取验证码 + if (needRefresh) { + this.refreshOTPCodes(); + } + }, 1000); + }, + + // 添加新的OTP + handleAddOTP() { + wx.navigateTo({ + url: '/pages/otp-add/index' + }); + }, + + // 编辑OTP + handleEditOTP(e) { + const { id } = e.currentTarget.dataset; + wx.navigateTo({ + url: `/pages/otp-edit/index?id=${id}` + }); + }, + + // 删除OTP + handleDeleteOTP(e) { + const { id, name } = e.currentTarget.dataset; + + wx.showModal({ + title: '确认删除', + content: `确定要删除 ${name} 吗?`, + confirmColor: '#ff4d4f', + success: (res) => { + if (res.confirm) { + deleteOTP(id) + .then(() => { + wx.showToast({ + title: '删除成功', + icon: 'success' + }); + this.fetchOTPList(); + }) + .catch(err => { + wx.showToast({ + title: '删除失败', + icon: 'none' + }); + }); + } + } + }); + }, + + // 复制验证码 + handleCopyCode(e) { + const { code } = e.currentTarget.dataset; + + wx.setClipboardData({ + data: code, + success: () => { + wx.showToast({ + title: '验证码已复制', + icon: 'success' + }); + } + }); + }, + + onUnload() { + // 页面卸载时清除定时器 + if (this.countdownTimer) { + clearInterval(this.countdownTimer); + } + } +}); \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.json b/miniprogram-example/pages/otp-list/index.json new file mode 100644 index 0000000..8835af0 --- /dev/null +++ b/miniprogram-example/pages/otp-list/index.json @@ -0,0 +1,3 @@ +{ + "usingComponents": {} +} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.wxml b/miniprogram-example/pages/otp-list/index.wxml new file mode 100644 index 0000000..27cb5ea --- /dev/null +++ b/miniprogram-example/pages/otp-list/index.wxml @@ -0,0 +1,59 @@ + + + + 我的OTP列表 + + + + + + + + + + 加载中... + + + + + + + + + {{item.name}} + {{item.issuer}} + + + + {{item.currentCode || '******'}} + 点击复制 + + + + + {{item.countdown || 0}}s + + + + + + + + + + + + + + + + + + 暂无OTP,点击右上角添加 + + + \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.wxss b/miniprogram-example/pages/otp-list/index.wxss new file mode 100644 index 0000000..40436d5 --- /dev/null +++ b/miniprogram-example/pages/otp-list/index.wxss @@ -0,0 +1,201 @@ +/* otp-list/index.wxss */ +.container { + min-height: 100vh; + background-color: #f8f8f8; + padding: 0 0 40rpx 0; +} + +.header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 40rpx 32rpx; + background-color: #ffffff; + position: sticky; + top: 0; + z-index: 100; + box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05); +} + +.title { + font-size: 36rpx; + font-weight: bold; + color: #333333; +} + +.add-button { + width: 64rpx; + height: 64rpx; + border-radius: 32rpx; + background-color: #1890ff; + display: flex; + align-items: center; + justify-content: center; + box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3); +} + +.add-icon { + color: #ffffff; + font-size: 40rpx; + line-height: 1; +} + +/* 加载状态 */ +.loading-container { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + padding: 120rpx 0; +} + +.loading-spinner { + width: 64rpx; + height: 64rpx; + border: 6rpx solid #f3f3f3; + border-top: 6rpx solid #1890ff; + border-radius: 50%; + animation: spin 1s linear infinite; +} + +.loading-text { + margin-top: 20rpx; + font-size: 28rpx; + color: #999999; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +/* OTP列表 */ +.otp-list { + padding: 20rpx 32rpx; +} + +.otp-item { + background-color: #ffffff; + border-radius: 16rpx; + padding: 32rpx; + margin-bottom: 20rpx; + display: flex; + justify-content: space-between; + box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); +} + +.otp-info { + flex: 1; + margin-right: 20rpx; +} + +.otp-name-row { + display: flex; + align-items: center; + margin-bottom: 16rpx; +} + +.otp-name { + font-size: 32rpx; + font-weight: bold; + color: #333333; + margin-right: 16rpx; +} + +.otp-issuer { + font-size: 24rpx; + color: #666666; + background-color: #f5f5f5; + padding: 4rpx 12rpx; + border-radius: 8rpx; +} + +.otp-code-row { + display: flex; + align-items: center; + margin-bottom: 20rpx; +} + +.otp-code { + font-size: 44rpx; + font-family: monospace; + font-weight: bold; + color: #1890ff; + letter-spacing: 4rpx; + margin-right: 16rpx; +} + +.copy-hint { + font-size: 24rpx; + color: #999999; +} + +.otp-countdown { + position: relative; + width: 100%; +} + +.countdown-text { + position: absolute; + right: 0; + top: -30rpx; + font-size: 24rpx; + color: #999999; +} + +/* OTP操作按钮 */ +.otp-actions { + display: flex; + flex-direction: column; + justify-content: space-between; +} + +.action-button { + width: 56rpx; + height: 56rpx; + border-radius: 28rpx; + display: flex; + align-items: center; + justify-content: center; + margin-bottom: 16rpx; +} + +.action-button.edit { + background-color: #f0f7ff; +} + +.action-button.delete { + background-color: #fff1f0; +} + +.action-icon { + font-size: 32rpx; +} + +.edit .action-icon { + color: #1890ff; +} + +.delete .action-icon { + color: #ff4d4f; +} + +/* 空状态 */ +.empty-state { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + padding: 120rpx 0; +} + +.empty-image { + width: 240rpx; + height: 240rpx; + margin-bottom: 32rpx; +} + +.empty-text { + font-size: 28rpx; + color: #999999; +} \ No newline at end of file diff --git a/miniprogram-example/project.config.json b/miniprogram-example/project.config.json new file mode 100644 index 0000000..3fb79ad --- /dev/null +++ b/miniprogram-example/project.config.json @@ -0,0 +1,47 @@ +{ + "description": "项目配置文件", + "packOptions": { + "ignore": [], + "include": [] + }, + "miniprogramRoot": "", + "compileType": "miniprogram", + "projectname": "OTPM", + "setting": { + "useCompilerPlugins": [ + "sass" + ], + "babelSetting": { + "ignore": [], + "disablePlugins": [], + "outputPath": "" + }, + "es6": true, + "enhance": true, + "minified": true, + "postcss": true, + "minifyWXSS": true, + "minifyWXML": true, + "uglifyFileName": true, + "packNpmManually": false, + "packNpmRelationList": [], + "ignoreUploadUnusedFiles": true, + "compileWorklet": false, + "uploadWithSourceMap": true, + "localPlugins": false, + "disableUseStrict": false, + "condition": false, + "swc": false, + "disableSWC": true + }, + "simulatorType": "wechat", + "simulatorPluginLibVersion": {}, + "condition": {}, + "srcMiniprogramRoot": "", + "appid": "wxb6599459668b6b55", + "libVersion": "2.30.2", + "editorSetting": { + "tabIndent": "insertSpaces", + "tabSize": 2 + } +} \ No newline at end of file diff --git a/miniprogram-example/project.private.config.json b/miniprogram-example/project.private.config.json new file mode 100644 index 0000000..6b2a738 --- /dev/null +++ b/miniprogram-example/project.private.config.json @@ -0,0 +1,23 @@ +{ + "libVersion": "3.8.5", + "projectname": "OTPM", + "condition": {}, + "setting": { + "urlCheck": true, + "coverView": false, + "lazyloadPlaceholderEnable": false, + "skylineRenderEnable": false, + "preloadBackgroundData": false, + "autoAudits": false, + "useApiHook": true, + "useApiHostProcess": true, + "showShadowRootInWxmlPanel": false, + "useStaticServer": false, + "useLanDebug": false, + "showES6CompileOption": false, + "compileHotReLoad": true, + "checkInvalidKey": true, + "ignoreDevUnusedFiles": true, + "bigPackageSizeSupport": false + } +} \ No newline at end of file diff --git a/miniprogram-example/services/auth.js b/miniprogram-example/services/auth.js new file mode 100644 index 0000000..4c95fa4 --- /dev/null +++ b/miniprogram-example/services/auth.js @@ -0,0 +1,84 @@ +// auth.js - 认证相关服务 + +import request from '../utils/request'; + +/** + * 微信登录 + * 1. 调用wx.login获取code + * 2. 发送code到服务端换取token和openid + * 3. 保存token和openid到本地存储 + */ +export const wxLogin = () => { + return new Promise((resolve, reject) => { + wx.login({ + success: (res) => { + if (res.code) { + // 发送code到服务端 + request({ + url: '/login', + method: 'POST', + data: { + code: res.code + } + }).then(response => { + // 保存token和openid + if (response.data && response.data.token && response.data.openid) { + wx.setStorageSync('token', response.data.token); + wx.setStorageSync('openid', response.data.openid); + resolve(response.data); + } else { + reject(new Error('登录失败,服务器返回数据格式错误')); + } + }).catch(err => { + reject(err); + }); + } else { + reject(new Error('登录失败,获取code失败: ' + res.errMsg)); + } + }, + fail: (err) => { + reject(new Error('微信登录失败: ' + err.errMsg)); + } + }); + }); +}; + +/** + * 检查登录状态 + * 1. 检查本地是否有token和openid + * 2. 如果有,验证token是否有效 + * 3. 如果无效,清除本地存储并返回false + */ +export const checkLoginStatus = () => { + return new Promise((resolve, reject) => { + const token = wx.getStorageSync('token'); + const openid = wx.getStorageSync('openid'); + + if (!token || !openid) { + resolve(false); + return; + } + + // 验证token有效性 + request({ + url: '/verify-token', + method: 'POST' + }).then(() => { + resolve(true); + }).catch(() => { + // token无效,清除本地存储 + wx.removeStorageSync('token'); + wx.removeStorageSync('openid'); + resolve(false); + }); + }); +}; + +/** + * 退出登录 + */ +export const logout = () => { + wx.removeStorageSync('token'); + wx.removeStorageSync('openid'); + return Promise.resolve(); +}; \ No newline at end of file diff --git a/miniprogram-example/services/otp.js b/miniprogram-example/services/otp.js new file mode 100644 index 0000000..ab12da1 --- /dev/null +++ b/miniprogram-example/services/otp.js @@ -0,0 +1,87 @@ +// otp.js - OTP相关服务 + +import request from '../utils/request'; + +/** + * 创建新的OTP + * @param {Object} params - 创建OTP的参数 + * @param {string} params.name - OTP名称 + * @param {string} params.issuer - 发行方 + * @param {string} params.secret - 密钥 + * @param {string} params.algorithm - 算法,默认为SHA1 + * @param {number} params.digits - 位数,默认为6 + * @param {number} params.period - 周期,默认为30秒 + * @returns {Promise} - 返回创建结果 + */ +export const createOTP = (params) => { + return request({ + url: '/otp', + method: 'POST', + data: params + }); +}; + +/** + * 获取用户所有OTP列表 + * @returns {Promise} - 返回OTP列表 + */ +export const getOTPList = () => { + return request({ + url: '/otp', + method: 'GET' + }); +}; + +/** + * 获取指定OTP的当前验证码 + * @param {string} id - OTP的ID + * @returns {Promise} - 返回当前验证码 + */ +export const getOTPCode = (id) => { + return request({ + url: `/otp/${id}/code`, + method: 'GET' + }); +}; + +/** + * 验证OTP + * @param {string} id - OTP的ID + * @param {string} code - 用户输入的验证码 + * @returns {Promise} - 返回验证结果 + */ +export const verifyOTP = (id, code) => { + return request({ + url: `/otp/${id}/verify`, + method: 'POST', + data: { + code: code + } + }); +}; + +/** + * 更新OTP信息 + * @param {string} id - OTP的ID + * @param {Object} params - 更新的参数 + * @returns {Promise} - 返回更新结果 + */ +export const updateOTP = (id, params) => { + return request({ + url: `/otp/${id}`, + method: 'PUT', + data: params + }); +}; + +/** + * 删除OTP + * @param {string} id - OTP的ID + * @returns {Promise} - 返回删除结果 + */ +export const deleteOTP = (id) => { + return request({ + url: `/otp/${id}`, + method: 'DELETE' + }); +}; \ No newline at end of file diff --git a/miniprogram-example/utils/request.js b/miniprogram-example/utils/request.js new file mode 100644 index 0000000..b739a88 --- /dev/null +++ b/miniprogram-example/utils/request.js @@ -0,0 +1,64 @@ +// request.js - 网络请求工具类 + +const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名 + +// 请求拦截器 +const request = (options) => { + return new Promise((resolve, reject) => { + const token = wx.getStorageSync('token'); + const header = { + 'Content-Type': 'application/json', + ...options.header + }; + + // 如果有token,添加到请求头 + if (token) { + header['Authorization'] = `Bearer ${token}`; + } + + wx.request({ + url: `${BASE_URL}${options.url}`, + method: options.method || 'GET', + data: options.data, + header: header, + success: (res) => { + // 处理业务错误 + if (res.data.code !== 0) { + // token过期,尝试刷新 + if (res.statusCode === 401) { + refreshToken().then(() => { + // 刷新token后重试请求 + request(options).then(resolve).catch(reject); + }).catch((err) => { + // 刷新失败,需要重新登录 + wx.removeStorageSync('token'); + wx.removeStorageSync('openid'); + reject(err); + }); + return; + } + reject(new Error(res.data.message || '请求失败')); + return; + } + resolve(res.data); + }, + fail: reject + }); + }); +}; + +// 刷新token +const refreshToken = () => { + return request({ + url: '/refresh-token', + method: 'POST' + }).then(res => { + if (res.data && res.data.token) { + wx.setStorageSync('token', res.data.token); + return res.data.token; + } + throw new Error('Failed to refresh token'); + }); +}; + +export default request; \ No newline at end of file diff --git a/models/otp.go b/models/otp.go new file mode 100644 index 0000000..a8e48bc --- /dev/null +++ b/models/otp.go @@ -0,0 +1,195 @@ +package models + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/jmoiron/sqlx" +) + +// OTP represents a TOTP configuration +type OTP struct { + ID string `db:"id" json:"id"` + UserID string `db:"user_id" json:"user_id"` + Name string `db:"name" json:"name"` + Issuer string `db:"issuer" json:"issuer"` + Secret string `db:"secret" json:"-"` // Never expose secret in JSON + Algorithm string `db:"algorithm" json:"algorithm"` + Digits int `db:"digits" json:"digits"` + Period int `db:"period" json:"period"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +// OTPParams represents common OTP parameters used in creation and update +type OTPParams struct { + Name string + Issuer string + Secret string + Algorithm string + Digits int + Period int +} + +// OTPRepository handles OTP data operations +type OTPRepository struct { + db *sqlx.DB +} + +// NewOTPRepository creates a new OTPRepository +func NewOTPRepository(db *sqlx.DB) *OTPRepository { + return &OTPRepository{db: db} +} + +// FindByID finds an OTP by ID and user ID +func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) { + var otp OTP + query := `SELECT * FROM otps WHERE id = ? AND user_id = ?` + err := r.db.GetContext(ctx, &otp, query, id, userID) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("otp not found: %w", err) + } + return nil, fmt.Errorf("failed to find otp: %w", err) + } + return &otp, nil +} + +// FindAllByUserID finds all OTPs for a user +func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) { + var otps []*OTP + query := `SELECT * FROM otps WHERE user_id = ? ORDER BY created_at DESC` + err := r.db.SelectContext(ctx, &otps, query, userID) + if err != nil { + return nil, fmt.Errorf("failed to find otps: %w", err) + } + return otps, nil +} + +// Create creates a new OTP +func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error { + query := ` + INSERT INTO otps (id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + now := time.Now() + otp.CreatedAt = now + otp.UpdatedAt = now + + _, err := r.db.ExecContext( + ctx, + query, + otp.ID, + otp.UserID, + otp.Name, + otp.Issuer, + otp.Secret, + otp.Algorithm, + otp.Digits, + otp.Period, + otp.CreatedAt, + otp.UpdatedAt, + ) + if err != nil { + return fmt.Errorf("failed to create otp: %w", err) + } + return nil +} + +// Update updates an existing OTP +func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error { + query := ` + UPDATE otps + SET name = ?, issuer = ?, algorithm = ?, digits = ?, period = ?, updated_at = ? + WHERE id = ? AND user_id = ? + ` + otp.UpdatedAt = time.Now() + + result, err := r.db.ExecContext( + ctx, + query, + otp.Name, + otp.Issuer, + otp.Algorithm, + otp.Digits, + otp.Period, + otp.UpdatedAt, + otp.ID, + otp.UserID, + ) + if err != nil { + return fmt.Errorf("failed to update otp: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get affected rows: %w", err) + } + + if rows == 0 { + return fmt.Errorf("otp not found or not owned by user") + } + + return nil +} + +// Delete deletes an OTP +func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error { + query := `DELETE FROM otps WHERE id = ? AND user_id = ?` + result, err := r.db.ExecContext(ctx, query, id, userID) + if err != nil { + return fmt.Errorf("failed to delete otp: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get affected rows: %w", err) + } + + if rows == 0 { + return fmt.Errorf("otp not found or not owned by user") + } + + return nil +} + +// CountByUserID counts the number of OTPs for a user +func (r *OTPRepository) CountByUserID(ctx context.Context, userID string) (int, error) { + var count int + query := `SELECT COUNT(*) FROM otps WHERE user_id = ?` + err := r.db.GetContext(ctx, &count, query, userID) + if err != nil { + return 0, fmt.Errorf("failed to count otps: %w", err) + } + return count, nil +} + +// Transaction executes a function within a transaction +func (r *OTPRepository) Transaction(ctx context.Context, fn func(*sqlx.Tx) error) error { + tx, err := r.db.BeginTxx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } + }() + + if err := fn(tx); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("tx failed: %v, rollback failed: %v", err, rbErr) + } + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} diff --git a/models/user.go b/models/user.go new file mode 100644 index 0000000..f12399b --- /dev/null +++ b/models/user.go @@ -0,0 +1,114 @@ +package models + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/jmoiron/sqlx" +) + +// User represents a user in the system +type User struct { + ID string `db:"id" json:"id"` + OpenID string `db:"openid" json:"openid"` + SessionKey string `db:"session_key" json:"-"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +// UserRepository handles user data operations +type UserRepository struct { + db *sqlx.DB +} + +// NewUserRepository creates a new UserRepository +func NewUserRepository(db *sqlx.DB) *UserRepository { + return &UserRepository{db: db} +} + +// FindByID finds a user by ID +func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) { + var user User + query := `SELECT * FROM users WHERE id = ?` + err := r.db.GetContext(ctx, &user, query, id) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("user not found: %w", err) + } + return nil, fmt.Errorf("failed to find user: %w", err) + } + return &user, nil +} + +// FindByOpenID finds a user by OpenID +func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) { + var user User + query := `SELECT * FROM users WHERE openid = ?` + err := r.db.GetContext(ctx, &user, query, openID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil // User not found, but not an error + } + return nil, fmt.Errorf("failed to find user: %w", err) + } + return &user, nil +} + +// Create creates a new user +func (r *UserRepository) Create(ctx context.Context, user *User) error { + query := ` + INSERT INTO users (id, openid, session_key, created_at, updated_at) + VALUES (?, ?, ?, ?, ?) + ` + now := time.Now() + user.CreatedAt = now + user.UpdatedAt = now + + _, err := r.db.ExecContext( + ctx, + query, + user.ID, + user.OpenID, + user.SessionKey, + user.CreatedAt, + user.UpdatedAt, + ) + if err != nil { + return fmt.Errorf("failed to create user: %w", err) + } + return nil +} + +// Update updates an existing user +func (r *UserRepository) Update(ctx context.Context, user *User) error { + query := ` + UPDATE users + SET session_key = ?, updated_at = ? + WHERE id = ? + ` + user.UpdatedAt = time.Now() + + _, err := r.db.ExecContext( + ctx, + query, + user.SessionKey, + user.UpdatedAt, + user.ID, + ) + if err != nil { + return fmt.Errorf("failed to update user: %w", err) + } + return nil +} + +// Delete deletes a user +func (r *UserRepository) Delete(ctx context.Context, id string) error { + query := `DELETE FROM users WHERE id = ?` + _, err := r.db.ExecContext(ctx, query, id) + if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } + return nil +} diff --git a/security/security.go b/security/security.go new file mode 100644 index 0000000..7b69b81 --- /dev/null +++ b/security/security.go @@ -0,0 +1,332 @@ +package security + +import ( + "context" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + "net/http" + "strings" + "time" + + "golang.org/x/crypto/argon2" +) + +// SecurityService provides security functionality +type SecurityService struct { + config *Config +} + +// Config represents security configuration +type Config struct { + // CSRF protection + CSRFTokenLength int + CSRFTokenExpiry time.Duration + CSRFCookieName string + CSRFHeaderName string + CSRFCookieSecure bool + CSRFCookieHTTPOnly bool + CSRFCookieSameSite http.SameSite + + // Rate limiting + RateLimitRequests int + RateLimitWindow time.Duration + + // Password hashing + Argon2Time uint32 + Argon2Memory uint32 + Argon2Threads uint8 + Argon2KeyLen uint32 + Argon2SaltLen uint32 +} + +// DefaultConfig returns the default security configuration +func DefaultConfig() *Config { + return &Config{ + // CSRF protection + CSRFTokenLength: 32, + CSRFTokenExpiry: 24 * time.Hour, + CSRFCookieName: "csrf_token", + CSRFHeaderName: "X-CSRF-Token", + CSRFCookieSecure: true, + CSRFCookieHTTPOnly: true, + CSRFCookieSameSite: http.SameSiteStrictMode, + + // Rate limiting + RateLimitRequests: 100, + RateLimitWindow: time.Minute, + + // Password hashing + Argon2Time: 1, + Argon2Memory: 64 * 1024, + Argon2Threads: 4, + Argon2KeyLen: 32, + Argon2SaltLen: 16, + } +} + +// NewSecurityService creates a new SecurityService +func NewSecurityService(config *Config) *SecurityService { + if config == nil { + config = DefaultConfig() + } + return &SecurityService{ + config: config, + } +} + +// GenerateCSRFToken generates a CSRF token +func (s *SecurityService) GenerateCSRFToken() (string, error) { + // Generate random bytes + bytes := make([]byte, s.config.CSRFTokenLength) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + + // Encode as base64 + token := base64.StdEncoding.EncodeToString(bytes) + return token, nil +} + +// SetCSRFCookie sets a CSRF cookie +func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) { + http.SetCookie(w, &http.Cookie{ + Name: s.config.CSRFCookieName, + Value: token, + Path: "/", + Expires: time.Now().Add(s.config.CSRFTokenExpiry), + Secure: s.config.CSRFCookieSecure, + HttpOnly: s.config.CSRFCookieHTTPOnly, + SameSite: s.config.CSRFCookieSameSite, + }) +} + +// ValidateCSRFToken validates a CSRF token +func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool { + // Get token from cookie + cookie, err := r.Cookie(s.config.CSRFCookieName) + if err != nil { + return false + } + cookieToken := cookie.Value + + // Get token from header + headerToken := r.Header.Get(s.config.CSRFHeaderName) + + // Compare tokens + return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1 +} + +// CSRFMiddleware creates a middleware that validates CSRF tokens +func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip for GET, HEAD, OPTIONS, TRACE + if r.Method == http.MethodGet || + r.Method == http.MethodHead || + r.Method == http.MethodOptions || + r.Method == http.MethodTrace { + next.ServeHTTP(w, r) + return + } + + // Validate CSRF token + if !s.ValidateCSRFToken(r) { + http.Error(w, "Invalid CSRF token", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} + +// RateLimiter represents a rate limiter +type RateLimiter struct { + requests map[string][]time.Time + config *Config +} + +// NewRateLimiter creates a new RateLimiter +func NewRateLimiter(config *Config) *RateLimiter { + return &RateLimiter{ + requests: make(map[string][]time.Time), + config: config, + } +} + +// Allow checks if a request is allowed +func (r *RateLimiter) Allow(key string) bool { + now := time.Now() + windowStart := now.Add(-r.config.RateLimitWindow) + + // Get requests for key + requests := r.requests[key] + + // Filter out old requests + var newRequests []time.Time + for _, t := range requests { + if t.After(windowStart) { + newRequests = append(newRequests, t) + } + } + + // Check if rate limit is exceeded + if len(newRequests) >= r.config.RateLimitRequests { + return false + } + + // Add current request + newRequests = append(newRequests, now) + r.requests[key] = newRequests + + return true +} + +// RateLimitMiddleware creates a middleware that limits request rate +func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get client IP + ip := getClientIP(r) + + // Check if request is allowed + if !limiter.Allow(ip) { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// getClientIP gets the client IP address +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header + xForwardedFor := r.Header.Get("X-Forwarded-For") + if xForwardedFor != "" { + // X-Forwarded-For can contain multiple IPs, use the first one + ips := strings.Split(xForwardedFor, ",") + return strings.TrimSpace(ips[0]) + } + + // Check X-Real-IP header + xRealIP := r.Header.Get("X-Real-IP") + if xRealIP != "" { + return xRealIP + } + + // Use RemoteAddr + return r.RemoteAddr +} + +// HashPassword hashes a password using Argon2 +func (s *SecurityService) HashPassword(password string) (string, error) { + // Generate salt + salt := make([]byte, s.config.Argon2SaltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("failed to generate salt: %w", err) + } + + // Hash password + hash := argon2.IDKey( + []byte(password), + salt, + s.config.Argon2Time, + s.config.Argon2Memory, + s.config.Argon2Threads, + s.config.Argon2KeyLen, + ) + + // Encode as base64 + saltBase64 := base64.StdEncoding.EncodeToString(salt) + hashBase64 := base64.StdEncoding.EncodeToString(hash) + + // Format as $argon2id$v=19$m=65536,t=1,p=4$$ + return fmt.Sprintf( + "$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s", + s.config.Argon2Memory, + s.config.Argon2Time, + s.config.Argon2Threads, + saltBase64, + hashBase64, + ), nil +} + +// VerifyPassword verifies a password against a hash +func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) { + // Parse encoded hash + parts := strings.Split(encodedHash, "$") + if len(parts) != 6 { + return false, fmt.Errorf("invalid hash format") + } + + // Extract parameters + if parts[1] != "argon2id" { + return false, fmt.Errorf("unsupported hash algorithm") + } + + var memory uint32 + var time uint32 + var threads uint8 + _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) + if err != nil { + return false, fmt.Errorf("failed to parse hash parameters: %w", err) + } + + // Decode salt and hash + salt, err := base64.StdEncoding.DecodeString(parts[4]) + if err != nil { + return false, fmt.Errorf("failed to decode salt: %w", err) + } + + hash, err := base64.StdEncoding.DecodeString(parts[5]) + if err != nil { + return false, fmt.Errorf("failed to decode hash: %w", err) + } + + // Hash password with same parameters + newHash := argon2.IDKey( + []byte(password), + salt, + time, + memory, + threads, + uint32(len(hash)), + ) + + // Compare hashes + return subtle.ConstantTimeCompare(hash, newHash) == 1, nil +} + +// SecureHeadersMiddleware adds security headers to responses +func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add security headers + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-XSS-Protection", "1; mode=block") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + w.Header().Set("Content-Security-Policy", "default-src 'self'") + w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") + + next.ServeHTTP(w, r) + }) +} + +// contextKey is a type for context keys +type contextKey string + +// userIDKey is the key for user ID in context +const userIDKey = contextKey("user_id") + +// WithUserID adds a user ID to the context +func WithUserID(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, userIDKey, userID) +} + +// GetUserID gets the user ID from the context +func GetUserID(ctx context.Context) (string, bool) { + userID, ok := ctx.Value(userIDKey).(string) + return userID, ok +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..b1cadf4 --- /dev/null +++ b/server/server.go @@ -0,0 +1,172 @@ +package server + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "runtime" + "syscall" + "time" + + "otpm/config" + "otpm/middleware" +) + +// Server represents the HTTP server +type Server struct { + server *http.Server + router *http.ServeMux + config *config.Config +} + +// New creates a new server +func New(cfg *config.Config) *Server { + router := http.NewServeMux() + + server := &http.Server{ + Addr: fmt.Sprintf(":%d", cfg.Server.Port), + Handler: router, + ReadTimeout: cfg.Server.ReadTimeout, + WriteTimeout: cfg.Server.WriteTimeout, + IdleTimeout: 120 * time.Second, + } + + return &Server{ + server: server, + router: router, + config: cfg, + } +} + +// Start starts the server +func (s *Server) Start() error { + // Apply global middleware in correct order with enhanced error handling + var handler http.Handler = s.router + + // Logger should be first to capture all request details + handler = middleware.Logger(handler) + + // CORS next to handle pre-flight requests + handler = middleware.CORS(handler) + + // Then Timeout to enforce request deadlines + handler = middleware.Timeout(s.config.Server.Timeout)(handler) + + // Recover should be outermost to catch any panics + handler = middleware.Recover(handler) + + s.server.Handler = handler + + // Log server configuration at startup + log.Printf("Server configuration:\n"+ + "Address: %s\n"+ + "Read Timeout: %v\n"+ + "Write Timeout: %v\n"+ + "Idle Timeout: %v\n"+ + "Request Timeout: %v", + s.server.Addr, + s.server.ReadTimeout, + s.server.WriteTimeout, + s.server.IdleTimeout, + s.config.Server.Timeout, + ) + + // Start server in a goroutine + serverErr := make(chan error, 1) + go func() { + log.Printf("Server starting on %s", s.server.Addr) + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + serverErr <- fmt.Errorf("server error: %w", err) + } + }() + + // Wait for interrupt signal or server error + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + select { + case err := <-serverErr: + return err + case <-quit: + return s.Shutdown() + } +} + +// Shutdown gracefully stops the server +func (s *Server) Shutdown() error { + log.Println("Shutting down server...") + + ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout) + defer cancel() + + if err := s.server.Shutdown(ctx); err != nil { + return fmt.Errorf("graceful shutdown failed: %w", err) + } + + log.Println("Server stopped gracefully") + return nil +} + +// Router returns the router +func (s *Server) Router() *http.ServeMux { + return s.router +} + +// RegisterRoutes registers all routes +func (s *Server) RegisterRoutes(routes map[string]http.Handler) { + for pattern, handler := range routes { + s.router.Handle(pattern, handler) + } +} + +// RegisterAuthRoutes registers routes that require authentication +func (s *Server) RegisterAuthRoutes(routes map[string]http.Handler) { + for pattern, handler := range routes { + // Apply authentication middleware + authHandler := middleware.Auth(s.config.JWT.Secret)(handler) + s.router.Handle(pattern, authHandler) + } +} + +// RegisterHealthCheck registers an enhanced health check endpoint +func (s *Server) RegisterHealthCheck() { + s.router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "status": "ok", + "timestamp": time.Now().Format(time.RFC3339), + "version": "1.0.0", // Hardcoded version instead of from config + "system": map[string]interface{}{ + "goroutines": runtime.NumGoroutine(), + "memory": getMemoryUsage(), + }, + } + + // Add database status if configured + if s.config.Database.DSN != "" { // Changed from URL to DSN to match config + dbStatus := "ok" + // Removed DB ping check since we don't have DB instance in config + response["database"] = dbStatus + } + + middleware.SuccessResponse(w, response) + }) +} + +// getMemoryUsage returns current memory usage in MB +func getMemoryUsage() map[string]interface{} { + var m runtime.MemStats + runtime.ReadMemStats(&m) + return map[string]interface{}{ + "alloc_mb": bToMb(m.Alloc), + "total_alloc_mb": bToMb(m.TotalAlloc), + "sys_mb": bToMb(m.Sys), + "num_gc": m.NumGC, + } +} + +func bToMb(b uint64) float64 { + return float64(b) / 1024 / 1024 +} diff --git a/services/auth.go b/services/auth.go new file mode 100644 index 0000000..9ad2a85 --- /dev/null +++ b/services/auth.go @@ -0,0 +1,230 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "time" + + "github.com/golang-jwt/jwt" + "github.com/google/uuid" + + "otpm/config" + "otpm/models" +) + +// WeChatCode2SessionResponse represents the response from WeChat code2session API +type WeChatCode2SessionResponse struct { + OpenID string `json:"openid"` + SessionKey string `json:"session_key"` + UnionID string `json:"unionid"` + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +// AuthService handles authentication related operations +type AuthService struct { + config *config.Config + userRepo *models.UserRepository + httpClient *http.Client +} + +// NewAuthService creates a new AuthService +func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService { + return &AuthService{ + config: cfg, + userRepo: userRepo, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// LoginWithWeChatCode handles WeChat login +func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) { + start := time.Now() + + // Get OpenID and SessionKey from WeChat + sessionInfo, err := s.getWeChatSession(code) + if err != nil { + log.Printf("WeChat login failed for code %s: %v", maskCode(code), err) + return "", fmt.Errorf("failed to get WeChat session: %w", err) + } + log.Printf("WeChat session obtained for code %s (took %v)", + maskCode(code), time.Since(start)) + + // Find or create user + user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID) + if err != nil { + log.Printf("User lookup failed for OpenID %s: %v", + maskOpenID(sessionInfo.OpenID), err) + return "", fmt.Errorf("failed to find user: %w", err) + } + + if user == nil { + // Create new user + user = &models.User{ + ID: uuid.New().String(), + OpenID: sessionInfo.OpenID, + SessionKey: sessionInfo.SessionKey, + } + if err := s.userRepo.Create(ctx, user); err != nil { + log.Printf("User creation failed for OpenID %s: %v", + maskOpenID(sessionInfo.OpenID), err) + return "", fmt.Errorf("failed to create user: %w", err) + } + log.Printf("New user created with ID %s for OpenID %s", + user.ID, maskOpenID(sessionInfo.OpenID)) + } else { + // Update session key + user.SessionKey = sessionInfo.SessionKey + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("User update failed for ID %s: %v", user.ID, err) + return "", fmt.Errorf("failed to update user: %w", err) + } + log.Printf("User %s session key updated", user.ID) + } + + // Generate JWT token + token, err := s.generateToken(user) + if err != nil { + log.Printf("Token generation failed for user %s: %v", user.ID, err) + return "", fmt.Errorf("failed to generate token: %w", err) + } + + log.Printf("WeChat login completed for user %s (total time %v)", + user.ID, time.Since(start)) + return token, nil +} + +// maskCode masks sensitive parts of WeChat code for logging +func maskCode(code string) string { + if len(code) < 8 { + return "****" + } + return code[:2] + "****" + code[len(code)-2:] +} + +// maskOpenID masks sensitive parts of OpenID for logging +func maskOpenID(openID string) string { + if len(openID) < 8 { + return "****" + } + return openID[:2] + "****" + openID[len(openID)-2:] +} + +// getWeChatSession calls WeChat's code2session API +func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) { + url := fmt.Sprintf( + "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", + s.config.WeChat.AppID, + s.config.WeChat.AppSecret, + code, + ) + + resp, err := s.httpClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to call WeChat API: %w", err) + } + defer resp.Body.Close() + + var result WeChatCode2SessionResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode WeChat response: %w", err) + } + + if result.ErrCode != 0 { + return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg) + } + + return &result, nil +} + +// generateToken generates a JWT token for a user +func (s *AuthService) generateToken(user *models.User) (string, error) { + now := time.Now() + claims := jwt.MapClaims{ + "user_id": user.ID, + "exp": now.Add(s.config.JWT.ExpireDelta).Unix(), + "iat": now.Unix(), + "iss": s.config.JWT.Issuer, + "aud": s.config.JWT.Audience, + "token_id": uuid.New().String(), // Unique token ID for tracking + } + + // Use stronger signing method + token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) + signedToken, err := token.SignedString([]byte(s.config.JWT.Secret)) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + log.Printf("Token generated for user %s (expires at %v)", + user.ID, now.Add(s.config.JWT.ExpireDelta)) + return signedToken, nil +} + +// ValidateToken validates a JWT token with additional checks +func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Verify signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(s.config.JWT.Secret), nil + }) + + if err != nil { + if ve, ok := err.(*jwt.ValidationError); ok { + switch { + case ve.Errors&jwt.ValidationErrorMalformed != 0: + return nil, fmt.Errorf("malformed token") + case ve.Errors&jwt.ValidationErrorExpired != 0: + return nil, fmt.Errorf("token expired") + case ve.Errors&jwt.ValidationErrorNotValidYet != 0: + return nil, fmt.Errorf("token not active yet") + default: + return nil, fmt.Errorf("token validation error: %w", err) + } + } + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + // Additional claims validation + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + // Check issuer + if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer { + return nil, fmt.Errorf("invalid token issuer") + } + // Check audience + if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience { + return nil, fmt.Errorf("invalid token audience") + } + } else { + return nil, fmt.Errorf("invalid token claims") + } + + return token, nil +} + +// GetUserFromToken gets user information from a JWT token +func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) { + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("invalid token claims") + } + + userID, ok := claims["user_id"].(string) + if !ok { + return nil, fmt.Errorf("user_id not found in token") + } + + user, err := s.userRepo.FindByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to find user: %w", err) + } + + return user, nil +} diff --git a/services/otp.go b/services/otp.go new file mode 100644 index 0000000..c345bea --- /dev/null +++ b/services/otp.go @@ -0,0 +1,358 @@ +package services + +import ( + "context" + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/base32" + "encoding/binary" + "fmt" + "hash" + "log" + "strings" + "time" + + "otpm/models" + + "github.com/google/uuid" +) + +// OTPService handles OTP related operations +type OTPService struct { + otpRepo *models.OTPRepository +} + +// NewOTPService creates a new OTPService +func NewOTPService(otpRepo *models.OTPRepository) *OTPService { + return &OTPService{ + otpRepo: otpRepo, + } +} + +// CreateOTP creates a new OTP with performance monitoring and logging +func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) { + start := time.Now() + + // Validate input + if err := s.validateOTPInput(input); err != nil { + log.Printf("OTP validation failed for user %s: %v", userID, err) + return nil, err + } + + // Clean and standardize secret + secret := cleanSecret(input.Secret) + + // Set defaults for optional fields + algorithm := strings.ToUpper(input.Algorithm) + if algorithm == "" { + algorithm = "SHA1" + } + + digits := input.Digits + if digits == 0 { + digits = 6 + } + + period := input.Period + if period == 0 { + period = 30 + } + + // Create OTP + otp := &models.OTP{ + ID: uuid.New().String(), + UserID: userID, + Name: input.Name, + Issuer: input.Issuer, + Secret: secret, + Algorithm: algorithm, + Digits: digits, + Period: period, + } + + if err := s.otpRepo.Create(ctx, otp); err != nil { + log.Printf("Failed to create OTP for user %s: %v", userID, err) + return nil, fmt.Errorf("failed to create OTP: %w", err) + } + + // Log successful creation (without exposing secret) + log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)", + otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period) + + return otp, nil +} + +// GetOTP gets an OTP by ID +func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) { + otp, err := s.otpRepo.FindByID(ctx, id, userID) + if err != nil { + return nil, fmt.Errorf("failed to get OTP: %w", err) + } + return otp, nil +} + +// ListOTPs lists all OTPs for a user +func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) { + otps, err := s.otpRepo.FindAllByUserID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to list OTPs: %w", err) + } + return otps, nil +} + +// UpdateOTP updates an OTP +func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) { + // Get existing OTP + otp, err := s.otpRepo.FindByID(ctx, id, userID) + if err != nil { + return nil, fmt.Errorf("failed to get OTP: %w", err) + } + + // Update fields + if input.Name != "" { + otp.Name = input.Name + } + if input.Issuer != "" { + otp.Issuer = input.Issuer + } + if input.Algorithm != "" { + otp.Algorithm = strings.ToUpper(input.Algorithm) + } + if input.Digits > 0 { + otp.Digits = input.Digits + } + if input.Period > 0 { + otp.Period = input.Period + } + + // Validate updated OTP + if err := s.validateOTPInput(models.OTPParams{ + Name: otp.Name, + Issuer: otp.Issuer, + Secret: otp.Secret, + Algorithm: otp.Algorithm, + Digits: otp.Digits, + Period: otp.Period, + }); err != nil { + return nil, err + } + + if err := s.otpRepo.Update(ctx, otp); err != nil { + return nil, fmt.Errorf("failed to update OTP: %w", err) + } + + return otp, nil +} + +// DeleteOTP deletes an OTP +func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error { + if err := s.otpRepo.Delete(ctx, id, userID); err != nil { + return fmt.Errorf("failed to delete OTP: %w", err) + } + return nil +} + +// GenerateCode generates a TOTP code with enhanced logging and error handling +func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) { + start := time.Now() + + otp, err := s.otpRepo.FindByID(ctx, id, userID) + if err != nil { + log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err) + return "", 0, fmt.Errorf("failed to get OTP: %w", err) + } + + // Get current time step + now := time.Now().Unix() + timeStep := now / int64(otp.Period) + + // Generate code + code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits) + if err != nil { + log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err) + return "", 0, fmt.Errorf("failed to generate code: %w", err) + } + + // Calculate remaining seconds + remainingSeconds := otp.Period - int(now%int64(otp.Period)) + + // Log successful generation (without actual code) + log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)", + id, userID, time.Since(start), remainingSeconds) + + return code, remainingSeconds, nil +} + +// VerifyCode verifies a TOTP code with enhanced security and logging +func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) { + start := time.Now() + + // Basic input validation + if len(code) == 0 { + log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID) + return false, fmt.Errorf("code is required") + } + + otp, err := s.otpRepo.FindByID(ctx, id, userID) + if err != nil { + log.Printf("Failed to find OTP %s for user %s during verification: %v", + id, userID, err) + return false, fmt.Errorf("failed to get OTP: %w", err) + } + + // Get current and adjacent time steps + now := time.Now().Unix() + timeSteps := []int64{ + (now - int64(otp.Period)) / int64(otp.Period), + now / int64(otp.Period), + (now + int64(otp.Period)) / int64(otp.Period), + } + + // Check code against all time steps + for _, ts := range timeSteps { + expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits) + if err != nil { + log.Printf("Code generation failed for time step %d: %v", ts, err) + continue + } + if expectedCode == code { + // Log successful verification + log.Printf("Code verified successfully for OTP %s (user %s) in %v", + id, userID, time.Since(start)) + return true, nil + } + } + + // Log failed verification attempt + log.Printf("Invalid code provided for OTP %s (user %s) in %v", + id, userID, time.Since(start)) + + return false, nil +} + +// validateOTPInput validates OTP input with detailed error messages +func (s *OTPService) validateOTPInput(input models.OTPParams) error { + if input.Name == "" { + return fmt.Errorf("name is required") + } + + if len(input.Name) > 100 { + return fmt.Errorf("name is too long (maximum 100 characters)") + } + + if input.Secret == "" { + return fmt.Errorf("secret is required") + } + + if !isValidBase32(input.Secret) { + return fmt.Errorf("invalid secret format: must be a valid base32 string") + } + + // Secret length check (after base32 decoding) + secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "=")) + if len(secretBytes) < 10 { + return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)") + } + + if input.Algorithm != "" { + if !isValidAlgorithm(input.Algorithm) { + return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm) + } + } + + if input.Digits != 0 { + if input.Digits < 6 || input.Digits > 8 { + return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits) + } + } + + if input.Period != 0 { + if input.Period < 30 || input.Period > 60 { + return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period) + } + } + + return nil +} + +// Helper functions + +func cleanSecret(secret string) string { + // Remove spaces and convert to upper case + secret = strings.TrimSpace(strings.ToUpper(secret)) + // Remove any padding characters + return strings.TrimRight(secret, "=") +} + +func isValidBase32(s string) bool { + // Try to decode the secret + _, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "=")) + return err == nil +} + +func isValidAlgorithm(algorithm string) bool { + switch strings.ToUpper(algorithm) { + case "SHA1", "SHA256", "SHA512": + return true + default: + return false + } +} + +func getHasher(algorithm string, key []byte) (hash.Hash, error) { + switch strings.ToUpper(algorithm) { + case "SHA1": + return hmac.New(sha1.New, key), nil + case "SHA256": + return hmac.New(sha256.New, key), nil + case "SHA512": + return hmac.New(sha512.New, key), nil + default: + return nil, fmt.Errorf("unsupported algorithm: %s", algorithm) + } +} + +func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) { + // Decode secret + secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "=")) + if err != nil { + return "", fmt.Errorf("invalid secret: %w", err) + } + + // Get initialized HMAC hasher with secret + hasher, err := getHasher(algorithm, secretBytes) + if err != nil { + return "", err + } + + // Convert time step to bytes + timeBytes := make([]byte, 8) + binary.BigEndian.PutUint64(timeBytes, uint64(timeStep)) + + // Calculate HMAC + hasher.Write(timeBytes) + hash := hasher.Sum(nil) + + // Get offset + offset := hash[len(hash)-1] & 0xf + + // Generate 4-byte code + code := binary.BigEndian.Uint32(hash[offset : offset+4]) + code = code & 0x7fffffff + + // Get the specified number of digits + code = code % uint32(pow10(digits)) + + // Format code with leading zeros + return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil +} + +func pow10(n int) uint32 { + result := uint32(1) + for i := 0; i < n; i++ { + result *= 10 + } + return result +} diff --git a/validator/validator.go b/validator/validator.go new file mode 100644 index 0000000..6fa8f1f --- /dev/null +++ b/validator/validator.go @@ -0,0 +1,159 @@ +package validator + +import ( + "encoding/json" + "fmt" + "net/http" + "reflect" + "regexp" + "strings" + + "github.com/go-playground/validator/v10" +) + +var ( + validate *validator.Validate + // 自定义验证规则 + customValidations = map[string]validator.Func{ + "otpsecret": validateOTPSecret, + "password": validatePassword, + } +) + +func init() { + validate = validator.New() + + // 注册自定义验证规则 + for tag, fn := range customValidations { + if err := validate.RegisterValidation(tag, fn); err != nil { + panic(fmt.Sprintf("failed to register validation %s: %v", tag, err)) + } + } + + // 使用json tag作为字段名 + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] + if name == "-" { + return "" + } + return name + }) +} + +// ValidateRequest validates a request body against a struct +func ValidateRequest(r *http.Request, v interface{}) error { + if err := json.NewDecoder(r.Body).Decode(v); err != nil { + return fmt.Errorf("invalid request body: %w", err) + } + + if err := validate.Struct(v); err != nil { + if validationErrors, ok := err.(validator.ValidationErrors); ok { + return NewValidationError(validationErrors) + } + return fmt.Errorf("validation error: %w", err) + } + + return nil +} + +// ValidationError represents a validation error +type ValidationError struct { + Fields map[string]string `json:"fields"` +} + +// Error implements the error interface +func (e *ValidationError) Error() string { + var errors []string + for field, msg := range e.Fields { + errors = append(errors, fmt.Sprintf("%s: %s", field, msg)) + } + return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; ")) +} + +// NewValidationError creates a new ValidationError from validator.ValidationErrors +func NewValidationError(errors validator.ValidationErrors) *ValidationError { + fields := make(map[string]string) + for _, err := range errors { + fields[err.Field()] = getErrorMessage(err) + } + return &ValidationError{Fields: fields} +} + +// getErrorMessage returns a human-readable error message for a validation error +func getErrorMessage(err validator.FieldError) string { + switch err.Tag() { + case "required": + return "This field is required" + case "email": + return "Invalid email address" + case "min": + return fmt.Sprintf("Must be at least %s characters long", err.Param()) + case "max": + return fmt.Sprintf("Must be at most %s characters long", err.Param()) + case "otpsecret": + return "Invalid OTP secret format" + case "password": + return "Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character" + default: + return fmt.Sprintf("Failed validation on tag: %s", err.Tag()) + } +} + +// Custom validation functions + +// validateOTPSecret validates an OTP secret +func validateOTPSecret(fl validator.FieldLevel) bool { + secret := fl.Field().String() + // OTP secret should be base32 encoded + matched, _ := regexp.MatchString(`^[A-Z2-7]+=*$`, secret) + return matched +} + +// validatePassword validates a password +func validatePassword(fl validator.FieldLevel) bool { + password := fl.Field().String() + // At least 8 characters long + if len(password) < 8 { + return false + } + + var ( + hasUpper = regexp.MustCompile(`[A-Z]`).MatchString(password) + hasLower = regexp.MustCompile(`[a-z]`).MatchString(password) + hasNumber = regexp.MustCompile(`[0-9]`).MatchString(password) + hasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password) + ) + + return hasUpper && hasLower && hasNumber && hasSpecial +} + +// Request validation structs + +// LoginRequest represents a login request +type LoginRequest struct { + Code string `json:"code" validate:"required"` +} + +// CreateOTPRequest represents a request to create an OTP +type CreateOTPRequest struct { + Name string `json:"name" validate:"required,min=1,max=100"` + Issuer string `json:"issuer" validate:"required,min=1,max=100"` + Secret string `json:"secret" validate:"required,otpsecret"` + Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"` + Digits int `json:"digits" validate:"required,oneof=6 8"` + Period int `json:"period" validate:"required,oneof=30 60"` +} + +// UpdateOTPRequest represents a request to update an OTP +type UpdateOTPRequest struct { + Name string `json:"name" validate:"omitempty,min=1,max=100"` + Issuer string `json:"issuer" validate:"omitempty,min=1,max=100"` + Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"` + Digits int `json:"digits" validate:"omitempty,oneof=6 8"` + Period int `json:"period" validate:"omitempty,oneof=30 60"` +} + +// VerifyOTPRequest represents a request to verify an OTP code +type VerifyOTPRequest struct { + Code string `json:"code" validate:"required,len=6|len=8,numeric"` +}