diff --git a/api/response.go b/api/response.go new file mode 100644 index 0000000..8e24daf --- /dev/null +++ b/api/response.go @@ -0,0 +1,149 @@ +package api + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" +) + +// Common error codes +const ( + CodeSuccess = 0 + CodeInvalidParams = 400 + CodeUnauthorized = 401 + CodeForbidden = 403 + CodeNotFound = 404 + CodeInternalError = 500 + CodeServiceUnavail = 503 +) + +// Error represents an API error +type Error struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// Error implements the error interface +func (e *Error) Error() string { + return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message) +} + +// NewError creates a new API error +func NewError(code int, message string) *Error { + return &Error{ + Code: code, + Message: message, + } +} + +// Response represents a standard API response +type Response struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// ResponseWriter wraps common response writing functions +type ResponseWriter struct { + http.ResponseWriter +} + +// NewResponseWriter creates a new ResponseWriter +func NewResponseWriter(w http.ResponseWriter) *ResponseWriter { + return &ResponseWriter{w} +} + +// WriteJSON writes a JSON response +func (w *ResponseWriter) WriteJSON(code int, data interface{}) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + return json.NewEncoder(w).Encode(data) +} + +// WriteSuccess writes a success response +func (w *ResponseWriter) WriteSuccess(data interface{}) error { + return w.WriteJSON(http.StatusOK, Response{ + Code: CodeSuccess, + Message: "success", + Data: data, + }) +} + +// WriteError writes an error response +func (w *ResponseWriter) WriteError(err error) error { + var apiErr *Error + if errors.As(err, &apiErr) { + return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{ + Code: apiErr.Code, + Message: apiErr.Message, + }) + } + + // Handle unknown errors + return w.WriteJSON(http.StatusInternalServerError, Response{ + Code: CodeInternalError, + Message: "Internal Server Error", + }) +} + +// WriteErrorWithCode writes an error response with a specific code +func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error { + return w.WriteJSON(getHTTPStatus(code), Response{ + Code: code, + Message: message, + }) +} + +// getHTTPStatus maps API error codes to HTTP status codes +func getHTTPStatus(code int) int { + switch code { + case CodeSuccess: + return http.StatusOK + case CodeInvalidParams: + return http.StatusBadRequest + case CodeUnauthorized: + return http.StatusUnauthorized + case CodeForbidden: + return http.StatusForbidden + case CodeNotFound: + return http.StatusNotFound + case CodeServiceUnavail: + return http.StatusServiceUnavailable + default: + return http.StatusInternalServerError + } +} + +// Common errors +var ( + ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters") + ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized") + ErrForbidden = NewError(CodeForbidden, "Forbidden") + ErrNotFound = NewError(CodeNotFound, "Resource not found") + ErrInternalError = NewError(CodeInternalError, "Internal server error") + ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable") +) + +// ValidationError creates an error for invalid parameters +func ValidationError(message string) *Error { + return NewError(CodeInvalidParams, message) +} + +// NotFoundError creates an error for not found resources +func NotFoundError(resource string) *Error { + return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource)) +} + +// ForbiddenError creates an error for forbidden actions +func ForbiddenError(message string) *Error { + return NewError(CodeForbidden, message) +} + +// InternalError creates an error for internal server errors +func InternalError(err error) *Error { + if err == nil { + return ErrInternalError + } + return NewError(CodeInternalError, err.Error()) +} diff --git a/api/validator.go b/api/validator.go new file mode 100644 index 0000000..a18a552 --- /dev/null +++ b/api/validator.go @@ -0,0 +1,152 @@ +package api + +import ( + "regexp" + "strings" + + "github.com/go-playground/validator/v10" +) + +// Validate is a global validator instance +var Validate = validator.New() + +// RegisterCustomValidations registers custom validation functions +func RegisterCustomValidations() { + // Register custom validation for issuer + Validate.RegisterValidation("issuer", validateIssuer) + + // Register custom validation for XSS prevention + Validate.RegisterValidation("no_xss", validateNoXSS) + + // Register custom validation for OTP secret + Validate.RegisterValidation("otpsecret", validateOTPSecret) +} + +// validateOTPSecret validates that the OTP secret is in valid base32 format +func validateOTPSecret(fl validator.FieldLevel) bool { + secret := fl.Field().String() + + // Check if the secret is not empty + if secret == "" { + return false + } + + // Check if the secret is in base32 format (A-Z, 2-7) + base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`) + if !base32Regex.MatchString(secret) { + return false + } + + // Check if the length is valid (must be at least 16 characters) + if len(secret) < 16 || len(secret) > 128 { + return false + } + + return true +} + +// validateIssuer validates that the issuer field contains only allowed characters +func validateIssuer(fl validator.FieldLevel) bool { + issuer := fl.Field().String() + + // Empty issuer is valid (since it's optional) + if issuer == "" { + return true + } + + // Allow alphanumeric characters, spaces, and common punctuation + issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api + +import ( + "regexp" + "strings" + + "github.com/go-playground/validator/v10" +) + +// Validate is a global validator instance +var Validate = validator.New() + +// RegisterCustomValidations registers custom validation functions +func RegisterCustomValidations() { + // Register custom validation for issuer + Validate.RegisterValidation("issuer", validateIssuer) + + // Register custom validation for XSS prevention + Validate.RegisterValidation("no_xss", validateNoXSS) + + // Register custom validation for OTP secret + Validate.RegisterValidation("otpsecret", validateOTPSecret) +} + +// validateOTPSecret validates that the OTP secret is in valid base32 format +func validateOTPSecret(fl validator.FieldLevel) bool { + secret := fl.Field().String() + + // Check if the secret is not empty + if secret == "" { + return false + } + + // Check if the secret is in base32 format (A-Z, 2-7) + base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`) + if !base32Regex.MatchString(secret) { + return false + } + + // Check if the length is valid (must be at least 16 characters) + if len(secret) < 16 || len(secret) > 128 { + return false + } + + return true +} + +) + if !issuerRegex.MatchString(issuer) { + return false + } + + // Check length + if len(issuer) > 100 { + return false + } + + return true +} + +// validateNoXSS validates that the field doesn't contain potential XSS payloads +func validateNoXSS(fl validator.FieldLevel) bool { + value := fl.Field().String() + + // Check for HTML encoding + if strings.Contains(value, "&#") || + strings.Contains(value, "<") || + strings.Contains(value, ">") { + return false + } + + // Check for common XSS patterns + suspiciousPatterns := []*regexp.Regexp{ + regexp.MustCompile(`(?i)]*>.*?`), + regexp.MustCompile(`(?i)javascript:`), + regexp.MustCompile(`(?i)data:text/html`), + regexp.MustCompile(`(?i)on\w+\s*=`), + regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`), + regexp.MustCompile(`(?i)<\s*iframe`), + regexp.MustCompile(`(?i)<\s*object`), + regexp.MustCompile(`(?i)<\s*embed`), + regexp.MustCompile(`(?i)<\s*style`), + regexp.MustCompile(`(?i)<\s*form`), + regexp.MustCompile(`(?i)<\s*applet`), + regexp.MustCompile(`(?i)<\s*meta`), + } + + for _, pattern := range suspiciousPatterns { + if pattern.MatchString(value) { + return false + } + } + + return true +} \ No newline at end of file diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..08d3041 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,206 @@ +package cache + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" +) + +// Item represents a cache item +type Item struct { + Value []byte + Expiration int64 +} + +// Expired returns true if the item has expired +func (item Item) Expired() bool { + if item.Expiration == 0 { + return false + } + return time.Now().UnixNano() > item.Expiration +} + +// Cache represents an in-memory cache +type Cache struct { + items map[string]Item + mu sync.RWMutex + defaultExpiration time.Duration + cleanupInterval time.Duration + stopCleanup chan bool +} + +// New creates a new cache with the given default expiration and cleanup interval +func New(defaultExpiration, cleanupInterval time.Duration) *Cache { + cache := &Cache{ + items: make(map[string]Item), + defaultExpiration: defaultExpiration, + cleanupInterval: cleanupInterval, + stopCleanup: make(chan bool), + } + + // Start cleanup goroutine if cleanup interval > 0 + if cleanupInterval > 0 { + go cache.startCleanup() + } + + return cache +} + +// startCleanup starts the cleanup process +func (c *Cache) startCleanup() { + ticker := time.NewTicker(c.cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.DeleteExpired() + case <-c.stopCleanup: + return + } + } +} + +// Set adds an item to the cache with the given key and expiration +func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error { + // Convert value to bytes + var valueBytes []byte + var err error + + switch v := value.(type) { + case []byte: + valueBytes = v + case string: + valueBytes = []byte(v) + default: + valueBytes, err = json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to marshal value: %w", err) + } + } + + // Calculate expiration + var exp int64 + if expiration == 0 { + if c.defaultExpiration > 0 { + exp = time.Now().Add(c.defaultExpiration).UnixNano() + } + } else if expiration > 0 { + exp = time.Now().Add(expiration).UnixNano() + } + + c.mu.Lock() + c.items[key] = Item{ + Value: valueBytes, + Expiration: exp, + } + c.mu.Unlock() + + return nil +} + +// Get gets an item from the cache +func (c *Cache) Get(key string, value interface{}) (bool, error) { + c.mu.RLock() + item, found := c.items[key] + c.mu.RUnlock() + + if !found { + return false, nil + } + + // Check if item has expired + if item.Expired() { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() + return false, nil + } + + // Unmarshal value + switch v := value.(type) { + case *[]byte: + *v = item.Value + case *string: + *v = string(item.Value) + default: + if err := json.Unmarshal(item.Value, value); err != nil { + return true, fmt.Errorf("failed to unmarshal value: %w", err) + } + } + + return true, nil +} + +// Delete deletes an item from the cache +func (c *Cache) Delete(key string) { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() +} + +// DeleteExpired deletes all expired items from the cache +func (c *Cache) DeleteExpired() { + now := time.Now().UnixNano() + + c.mu.Lock() + for k, v := range c.items { + if v.Expiration > 0 && now > v.Expiration { + delete(c.items, k) + } + } + c.mu.Unlock() +} + +// Clear deletes all items from the cache +func (c *Cache) Clear() { + c.mu.Lock() + c.items = make(map[string]Item) + c.mu.Unlock() +} + +// Close stops the cleanup goroutine +func (c *Cache) Close() { + if c.cleanupInterval > 0 { + c.stopCleanup <- true + } +} + +// CacheService provides caching functionality +type CacheService struct { + cache *Cache +} + +// NewCacheService creates a new CacheService +func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService { + return &CacheService{ + cache: New(defaultExpiration, cleanupInterval), + } +} + +// Set adds an item to the cache +func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error { + return s.cache.Set(key, value, expiration) +} + +// Get gets an item from the cache +func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) { + return s.cache.Get(key, value) +} + +// Delete deletes an item from the cache +func (s *CacheService) Delete(ctx context.Context, key string) { + s.cache.Delete(key) +} + +// Clear deletes all items from the cache +func (s *CacheService) Clear(ctx context.Context) { + s.cache.Clear() +} + +// Close closes the cache +func (s *CacheService) Close() { + s.cache.Close() +} diff --git a/cmd/root.go b/cmd/root.go index 3752544..8821b91 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,83 +1,144 @@ package cmd import ( + "context" "fmt" "log" "net/http" "os" + "os/signal" + "syscall" + + "github.com/spf13/viper" + + "otpm/api" + "otpm/config" "otpm/database" "otpm/handlers" - "otpm/utils" - - "github.com/jmoiron/sqlx" - "github.com/julienschmidt/httprouter" - "github.com/spf13/cobra" - "github.com/spf13/viper" + "otpm/models" + "otpm/server" + "otpm/services" ) -var rootCmd = &cobra.Command{ - Use: "otpm", - Short: "otp backend for microapp on wechat", - Run: func(cmd *cobra.Command, args []string) { - startApp() - }, -} - -func Execute() { - if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) - } -} - func init() { - cobra.OnInitialize(initConfig) + // Set config file with multi-environment support + viper.SetConfigName("config") + viper.SetConfigType("yaml") + viper.AddConfigPath(".") - rootCmd.PersistentFlags().StringP("config", "c", "", "config file (default is $HOME/config.yaml)") - rootCmd.PersistentFlags().StringP("driver", "d", "sqlite3", "database driver (sqlite3, postgres, mysql)") - rootCmd.PersistentFlags().StringP("dsn", "s", "", "database connection string") - rootCmd.PersistentFlags().StringP("port", "p", "8080", "port to listen on") - - viper.BindPFlag("database.driver", rootCmd.PersistentFlags().Lookup("driver")) - viper.BindPFlag("database.dsn", rootCmd.PersistentFlags().Lookup("dsn")) - viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) -} - -func initConfig() { - if cfgFile := viper.GetString("config"); cfgFile != "" { - viper.SetConfigFile(cfgFile) - } else { - viper.AddConfigPath(".") - viper.SetConfigName("config") - viper.SetConfigType("yaml") + // Set environment specific config (e.g. config.production.yaml) + env := os.Getenv("OTPM_ENV") + if env != "" { + viper.SetConfigName(fmt.Sprintf("config.%s", env)) } - if err := viper.ReadInConfig(); err != nil { - log.Fatalf("Error reading config file: %v", err) - } + // Set default values + viper.SetDefault("server.port", "8080") + viper.SetDefault("server.timeout.read", "15s") + viper.SetDefault("server.timeout.write", "15s") + viper.SetDefault("server.timeout.idle", "60s") + viper.SetDefault("database.max_open_conns", 25) + viper.SetDefault("database.max_idle_conns", 5) + viper.SetDefault("database.conn_max_lifetime", "5m") + + // Set environment variable prefix + viper.SetEnvPrefix("OTPM") + viper.AutomaticEnv() + + // Bind environment variables + viper.BindEnv("database.url", "OTPM_DB_URL") + viper.BindEnv("database.password", "OTPM_DB_PASSWORD") } -func initApp(db *sqlx.DB) { - if err := database.MigrateDB(db); err != nil { - log.Fatalf("Error migrating the database: %v", err) - } -} - -func startApp() { - port := viper.GetInt("port") - db, err := database.InitDB() +// Execute is the entry point for the application +func Execute() error { + // Load configuration + cfg, err := config.LoadConfig() if err != nil { - log.Fatalf("Error connecting to the database: %v", err) + return fmt.Errorf("failed to load config: %w", err) + } + + // Create context with cancellation + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-sigChan + log.Printf("Received signal: %v", sig) + cancel() + }() + + // Initialize database + db, err := database.New(&cfg.Database) + if err != nil { + return fmt.Errorf("failed to initialize database: %w", err) } defer db.Close() - initApp(db) - handler := &handlers.Handler{DB: db} - router := httprouter.New() - router.POST("/login", utils.AdaptHandler(handler.Login)) - router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp)) - router.GET("/get", utils.AdaptHandler(handler.GetOtp)) + // Run database migrations + if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil { + return fmt.Errorf("failed to run migrations: %w", err) + } - log.Println("Starting server on :8080") - log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), router)) + // Initialize repositories + userRepo := models.NewUserRepository(db.DB) + otpRepo := models.NewOTPRepository(db.DB) + + // Initialize services + authService := services.NewAuthService(cfg, userRepo) + otpService := services.NewOTPService(otpRepo) + + // Register custom validations + api.RegisterCustomValidations() + + // Initialize handlers + authHandler := handlers.NewAuthHandler(authService) + otpHandler := handlers.NewOTPHandler(otpService) + + // Create and configure server + srv := server.New(cfg) + + // Register health check endpoint + srv.RegisterHealthCheck() + + // Register public routes with type conversion + authRoutes := make(map[string]http.Handler) + for path, handler := range authHandler.Routes() { + authRoutes[path] = http.HandlerFunc(handler) + } + srv.RegisterRoutes(authRoutes) + + // Register authenticated routes with type conversion + otpRoutes := make(map[string]http.Handler) + for path, handler := range otpHandler.Routes() { + otpRoutes[path] = http.HandlerFunc(handler) + } + srv.RegisterAuthRoutes(otpRoutes) + + // Start server in goroutine + serverErr := make(chan error, 1) + go func() { + log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port) + if err := srv.Start(); err != nil { + serverErr <- fmt.Errorf("server error: %w", err) + } + }() + + // Wait for shutdown signal or server error + select { + case err := <-serverErr: + return err + case <-ctx.Done(): + // Graceful shutdown with timeout + log.Println("Shutting down server...") + if err := srv.Shutdown(); err != nil { + return fmt.Errorf("server shutdown error: %w", err) + } + log.Println("Server stopped gracefully") + } + + return nil } diff --git a/config.yaml b/config.yaml index f23f8fb..da2013f 100644 --- a/config.yaml +++ b/config.yaml @@ -1,8 +1,23 @@ +server: + port: 8080 + read_timeout: 15s + write_timeout: 15s + shutdown_timeout: 5s + database: - driver: sqlite + driver: sqlite3 dsn: otpm.sqlite -port: 8080 + max_open_conns: 25 + max_idle_conns: 25 + max_lifetime: 5m + skip_migration: false + +jwt: + secret: "your-jwt-secret-key-change-this-in-production" + expire_delta: 24h + refresh_delta: 168h + signing_method: HS256 wechat: - appid: "wx57d1033974eb5250" - secret: "be494c2a81df685a40b9a74e1736b15d" + app_id: "your-wechat-app-id" + app_secret: "your-wechat-app-secret" \ No newline at end of file diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..4cad24d --- /dev/null +++ b/config/config.go @@ -0,0 +1,128 @@ +package config + +import ( + "fmt" + "time" + + "github.com/spf13/viper" +) + +// Config holds all configuration for the application +type Config struct { + Server ServerConfig `mapstructure:"server"` + Database DatabaseConfig `mapstructure:"database"` + JWT JWTConfig `mapstructure:"jwt"` + WeChat WeChatConfig `mapstructure:"wechat"` +} + +// ServerConfig holds all server related configuration +type ServerConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + ReadTimeout time.Duration `mapstructure:"read_timeout"` + WriteTimeout time.Duration `mapstructure:"write_timeout"` + ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"` + Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout +} + +// DatabaseConfig holds all database related configuration +type DatabaseConfig struct { + Driver string `mapstructure:"driver"` + DSN string `mapstructure:"dsn"` + MaxOpenConns int `mapstructure:"max_open_conns"` + MaxIdleConns int `mapstructure:"max_idle_conns"` + MaxLifetime time.Duration `mapstructure:"max_lifetime"` + SkipMigration bool `mapstructure:"skip_migration"` +} + +// JWTConfig holds all JWT related configuration +type JWTConfig struct { + Secret string `mapstructure:"secret"` + ExpireDelta time.Duration `mapstructure:"expire_delta"` + RefreshDelta time.Duration `mapstructure:"refresh_delta"` + SigningMethod string `mapstructure:"signing_method"` + Issuer string `mapstructure:"issuer"` + Audience string `mapstructure:"audience"` +} + +// WeChatConfig holds all WeChat related configuration +type WeChatConfig struct { + AppID string `mapstructure:"app_id"` + AppSecret string `mapstructure:"app_secret"` +} + +// LoadConfig loads the configuration from file and environment variables +func LoadConfig() (*Config, error) { + // Set default values + setDefaults() + + // Read config file + if err := viper.ReadInConfig(); err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + var config Config + if err := viper.Unmarshal(&config); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %w", err) + } + + // Validate config + if err := validateConfig(&config); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + + return &config, nil +} + +// setDefaults sets default values for configuration +func setDefaults() { + // Server defaults + viper.SetDefault("server.port", 8080) + viper.SetDefault("server.read_timeout", "15s") + viper.SetDefault("server.write_timeout", "15s") + viper.SetDefault("server.shutdown_timeout", "5s") + viper.SetDefault("server.timeout", "30s") // Default request processing timeout + + // Database defaults + viper.SetDefault("database.driver", "sqlite3") + viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection + viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection + viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling + viper.SetDefault("database.skip_migration", false) + + // JWT defaults + viper.SetDefault("jwt.expire_delta", "24h") + viper.SetDefault("jwt.refresh_delta", "168h") // 7 days + viper.SetDefault("jwt.signing_method", "HS256") + viper.SetDefault("jwt.issuer", "otpm") + viper.SetDefault("jwt.audience", "otpm-client") +} + +// validateConfig validates the configuration +func validateConfig(config *Config) error { + if config.Server.Port < 1 || config.Server.Port > 65535 { + return fmt.Errorf("invalid port number: %d", config.Server.Port) + } + + if config.Database.Driver == "" { + return fmt.Errorf("database driver is required") + } + + if config.Database.DSN == "" { + return fmt.Errorf("database DSN is required") + } + + if config.JWT.Secret == "" { + return fmt.Errorf("JWT secret is required") + } + + if config.WeChat.AppID == "" { + return fmt.Errorf("WeChat AppID is required") + } + + if config.WeChat.AppSecret == "" { + return fmt.Errorf("WeChat AppSecret is required") + } + + return nil +} diff --git a/database/database.go b/database/database.go deleted file mode 100644 index bf4ade8..0000000 --- a/database/database.go +++ /dev/null @@ -1,49 +0,0 @@ -package database - -import ( - _ "embed" - "log" - - _ "github.com/go-sql-driver/mysql" - "github.com/jmoiron/sqlx" - _ "github.com/lib/pq" - "github.com/spf13/viper" - _ "modernc.org/sqlite" -) - -var ( - //go:embed init/users.sql - userTable string - //go:embed init/otp.sql - otpTable string -) - -func InitDB() (*sqlx.DB, error) { - driver := viper.GetString("database.driver") - dsn := viper.GetString("database.dsn") - - db, err := sqlx.Open(driver, dsn) - if err != nil { - return nil, err - } - - if err := db.Ping(); err != nil { - return nil, err - } - - log.Println("Connected to database!") - return db, nil -} - -func MigrateDB(db *sqlx.DB) error { - _, err := db.Exec(userTable) - if err != nil { - return err - } - - _, err = db.Exec(otpTable) - if err != nil { - return err - } - return nil -} diff --git a/database/db.go b/database/db.go new file mode 100644 index 0000000..e882713 --- /dev/null +++ b/database/db.go @@ -0,0 +1,212 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "log" + "strings" + "time" + + "otpm/config" + + "github.com/jmoiron/sqlx" +) + +// DB wraps sqlx.DB to provide additional functionality +type DB struct { + *sqlx.DB +} + +// New creates a new database connection +func New(cfg *config.DatabaseConfig) (*DB, error) { + db, err := sqlx.Open(cfg.Driver, cfg.DSN) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + // Configure connection pool based on database type + if cfg.Driver == "sqlite3" { + // SQLite is a file-based database - simpler connection settings + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(0) // Connections don't need to be recycled + db.SetConnMaxIdleTime(0) + } else { + // For other databases (MySQL, PostgreSQL etc.) + db.SetMaxOpenConns(cfg.MaxOpenConns) + db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections + db.SetConnMaxLifetime(30 * time.Minute) + db.SetConnMaxIdleTime(5 * time.Minute) + } + + // Verify connection with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + log.Println("Successfully connected to database") + return &DB{db}, nil +} + +// WithTx executes a function within a transaction with retry logic +func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error { + var maxRetries int + var lastErr error + + // Adjust retry settings based on database type + if db.DriverName() == "sqlite3" { + maxRetries = 5 // SQLite needs more retries due to busy timeouts + } else { + maxRetries = 3 + } + + // Default transaction options + opts := &sql.TxOptions{ + Isolation: sql.LevelReadCommitted, + } + + for attempt := 1; attempt <= maxRetries; attempt++ { + start := time.Now() + + tx, err := db.BeginTxx(ctx, opts) + if err != nil { + if isRetryableError(err) && attempt < maxRetries { + lastErr = err + time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff + continue + } + return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err) + } + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } + }() + + if err := fn(tx); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err) + } + + if isRetryableError(err) && attempt < maxRetries { + lastErr = err + time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) + continue + } + return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err) + } + + // Log long-running transactions + if elapsed := time.Since(start); elapsed > 500*time.Millisecond { + log.Printf("Transaction completed in %v", elapsed) + } + + if err := tx.Commit(); err != nil { + if isRetryableError(err) && attempt < maxRetries { + lastErr = err + time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) + continue + } + return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err) + } + + return nil + } + + return lastErr +} + +// isRetryableError checks if an error is likely to succeed on retry +func isRetryableError(err error) bool { + if err == nil { + return false + } + + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "deadlock") || + strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "try again") || + strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "busy") || + strings.Contains(errStr, "locked") +} + +// ExecContext executes a query with adaptive timeout +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error { + // Set timeout based on query complexity + timeout := 5 * time.Second + if strings.Contains(strings.ToLower(query), "insert") || + strings.Contains(strings.ToLower(query), "update") || + strings.Contains(strings.ToLower(query), "delete") { + timeout = 10 * time.Second + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + start := time.Now() + _, err := db.DB.ExecContext(ctx, query, args...) + elapsed := time.Since(start) + + // Log slow queries + if elapsed > timeout/2 { + log.Printf("Slow query execution detected: %s (took %v)", query, elapsed) + } + + if err != nil { + return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err) + } + + return nil +} + +// QueryRowContext executes a query that returns a single row with timeout +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + return db.DB.QueryRowxContext(ctx, query, args...) +} + +// QueryContext executes a query that returns multiple rows with adaptive timeout +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + // Set timeout based on query complexity + timeout := 5 * time.Second + if strings.Contains(strings.ToLower(query), "join") || + strings.Contains(strings.ToLower(query), "group by") || + strings.Contains(strings.ToLower(query), "order by") { + timeout = 15 * time.Second + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + start := time.Now() + rows, err := db.DB.QueryxContext(ctx, query, args...) + elapsed := time.Since(start) + + // Log slow queries + if elapsed > timeout/2 { + log.Printf("Slow query detected: %s (took %v)", query, elapsed) + } + + if err != nil { + return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err) + } + + return rows, nil +} + +// Close closes the database connection +func (db *DB) Close() error { + if err := db.DB.Close(); err != nil { + return fmt.Errorf("failed to close database connection: %w", err) + } + return nil +} diff --git a/database/init/otp.sql b/database/init/otp.sql index a9b2442..e0e1ba9 100644 --- a/database/init/otp.sql +++ b/database/init/otp.sql @@ -1,6 +1,26 @@ CREATE TABLE IF NOT EXISTS otp ( - id SERIAL PRIMARY KEY, - openid VARCHAR(255), - num INTEGER, - token VARCHAR(255) -); \ No newline at end of file + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id VARCHAR(255) NOT NULL, + openid VARCHAR(255) NOT NULL, + name VARCHAR(100) NOT NULL, + issuer VARCHAR(255), + secret VARCHAR(255) NOT NULL, + algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1', + digits INTEGER NOT NULL DEFAULT 6, + period INTEGER NOT NULL DEFAULT 30, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, name), + UNIQUE(openid) +); + +-- Add index for faster lookups +CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id); +CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid); + +-- Trigger to update the updated_at timestamp +CREATE TRIGGER IF NOT EXISTS update_otp_timestamp + AFTER UPDATE ON otp +BEGIN + UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; +END; \ No newline at end of file diff --git a/database/init/users.sql b/database/init/users.sql index 4dbedae..ae4532a 100644 --- a/database/init/users.sql +++ b/database/init/users.sql @@ -1,5 +1,6 @@ CREATE TABLE IF NOT EXISTS users ( - id SERIAL PRIMARY KEY, + id INTEGER PRIMARY KEY AUTOINCREMENT, openid VARCHAR(255) UNIQUE NOT NULL, session_key VARCHAR(255) UNIQUE NOT NULL -); \ No newline at end of file +); +CREATE UNIQUE INDEX idx_users_openid ON users(openid); \ No newline at end of file diff --git a/database/migration.go b/database/migration.go new file mode 100644 index 0000000..df15a97 --- /dev/null +++ b/database/migration.go @@ -0,0 +1,160 @@ +package database + +import ( + "context" + "fmt" + "log" + "time" + + _ "embed" + + "github.com/jmoiron/sqlx" +) + +var ( + //go:embed init/users.sql + userTable string + //go:embed init/otp.sql + otpTable string +) + +// Migration represents a database migration +type Migration struct { + Name string + SQL string + Version int +} + +// Migrations is a list of all migrations +var Migrations = []Migration{ + { + Name: "Create users table", + SQL: userTable, + Version: 1, + }, + { + Name: "Create OTP table", + SQL: otpTable, + Version: 2, + }, +} + +// MigrationRecord represents a record in the migrations table +type MigrationRecord struct { + ID int `db:"id"` + Version int `db:"version"` + Name string `db:"name"` + AppliedAt time.Time `db:"applied_at"` +} + +// ensureMigrationsTable ensures that the migrations table exists +func ensureMigrationsTable(db *sqlx.DB) error { + query := ` + CREATE TABLE IF NOT EXISTS migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version INTEGER NOT NULL, + name TEXT NOT NULL, + applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + ` + + _, err := db.Exec(query) + if err != nil { + return fmt.Errorf("failed to create migrations table: %w", err) + } + + return nil +} + +// getAppliedMigrations gets all applied migrations +func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) { + var records []MigrationRecord + if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil { + return nil, fmt.Errorf("failed to get applied migrations: %w", err) + } + + result := make(map[int]MigrationRecord) + for _, record := range records { + result[record.Version] = record + } + + return result, nil +} + +// Migrate runs all pending migrations +func Migrate(db *sqlx.DB, skipMigration bool) error { + if skipMigration { + log.Println("Skipping database migration as configured") + return nil + } + + // Ensure migrations table exists + if err := ensureMigrationsTable(db); err != nil { + return err + } + + // Get applied migrations + appliedMigrations, err := getAppliedMigrations(db) + if err != nil { + return err + } + + // Apply pending migrations + for _, migration := range Migrations { + if _, ok := appliedMigrations[migration.Version]; ok { + log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name) + continue + } + + log.Printf("Applying migration %d: %s", migration.Version, migration.Name) + + // Start a transaction for this migration + tx, err := db.Beginx() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + // Execute migration + if _, err := tx.Exec(migration.SQL); err != nil { + tx.Rollback() + return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err) + } + + // Record migration + if _, err := tx.Exec( + "INSERT INTO migrations (version, name) VALUES (?, ?)", + migration.Version, migration.Name, + ); err != nil { + tx.Rollback() + return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err) + } + + // Commit transaction + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err) + } + + log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name) + } + + return nil +} + +// MigrateWithContext runs all pending migrations with context +func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error { + // Create a channel to signal completion + done := make(chan error, 1) + + // Run migration in a goroutine + go func() { + done <- Migrate(db, skipMigration) + }() + + // Wait for migration to complete or context to be canceled + select { + case err := <-done: + return err + case <-ctx.Done(): + return fmt.Errorf("migration canceled: %w", ctx.Err()) + } +} diff --git a/docs/swagger.go b/docs/swagger.go new file mode 100644 index 0000000..ccf05ac --- /dev/null +++ b/docs/swagger.go @@ -0,0 +1,663 @@ +// Package docs provides API documentation using Swagger/OpenAPI +package docs + +import ( + "encoding/json" + "net/http" +) + +// SwaggerInfo holds the API information +var SwaggerInfo = struct { + Title string + Description string + Version string + Host string + BasePath string + Schemes []string +}{ + Title: "OTPM API", + Description: "API for One-Time Password Manager", + Version: "1.0.0", + Host: "localhost:8080", + BasePath: "/", + Schemes: []string{"http", "https"}, +} + +// SwaggerJSON returns the Swagger JSON +func SwaggerJSON() []byte { + swagger := map[string]interface{}{ + "swagger": "2.0", + "info": map[string]interface{}{ + "title": SwaggerInfo.Title, + "description": SwaggerInfo.Description, + "version": SwaggerInfo.Version, + }, + "host": SwaggerInfo.Host, + "basePath": SwaggerInfo.BasePath, + "schemes": SwaggerInfo.Schemes, + "paths": getPaths(), + "definitions": map[string]interface{}{ + "LoginRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "string", + "description": "WeChat authorization code", + }, + }, + "required": []string{"code"}, + }, + "LoginResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "token": map[string]interface{}{ + "type": "string", + "description": "JWT token", + }, + "openid": map[string]interface{}{ + "type": "string", + "description": "WeChat OpenID", + }, + }, + }, + "CreateOTPRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "OTP name", + }, + "issuer": map[string]interface{}{ + "type": "string", + "description": "OTP issuer", + }, + "secret": map[string]interface{}{ + "type": "string", + "description": "OTP secret", + }, + "algorithm": map[string]interface{}{ + "type": "string", + "description": "OTP algorithm", + "enum": []string{"SHA1", "SHA256", "SHA512"}, + }, + "digits": map[string]interface{}{ + "type": "integer", + "description": "OTP digits", + "enum": []int{6, 8}, + }, + "period": map[string]interface{}{ + "type": "integer", + "description": "OTP period in seconds", + "enum": []int{30, 60}, + }, + }, + "required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"}, + }, + "OTP": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "OTP ID", + }, + "user_id": map[string]interface{}{ + "type": "string", + "description": "User ID", + }, + "name": map[string]interface{}{ + "type": "string", + "description": "OTP name", + }, + "issuer": map[string]interface{}{ + "type": "string", + "description": "OTP issuer", + }, + "algorithm": map[string]interface{}{ + "type": "string", + "description": "OTP algorithm", + }, + "digits": map[string]interface{}{ + "type": "integer", + "description": "OTP digits", + }, + "period": map[string]interface{}{ + "type": "integer", + "description": "OTP period in seconds", + }, + "created_at": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "Creation time", + }, + "updated_at": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "Last update time", + }, + }, + }, + "OTPCodeResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "string", + "description": "OTP code", + }, + "expires_in": map[string]interface{}{ + "type": "integer", + "description": "Seconds until expiration", + }, + }, + }, + "VerifyOTPRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "string", + "description": "OTP code to verify", + }, + }, + "required": []string{"code"}, + }, + "VerifyOTPResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "valid": map[string]interface{}{ + "type": "boolean", + "description": "Whether the code is valid", + }, + }, + }, + "UpdateOTPRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "OTP name", + }, + "issuer": map[string]interface{}{ + "type": "string", + "description": "OTP issuer", + }, + "algorithm": map[string]interface{}{ + "type": "string", + "description": "OTP algorithm", + "enum": []string{"SHA1", "SHA256", "SHA512"}, + }, + "digits": map[string]interface{}{ + "type": "integer", + "description": "OTP digits", + "enum": []int{6, 8}, + }, + "period": map[string]interface{}{ + "type": "integer", + "description": "OTP period in seconds", + "enum": []int{30, 60}, + }, + }, + }, + "ErrorResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "integer", + "description": "Error code", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "Error message", + }, + }, + }, + }, + "securityDefinitions": map[string]interface{}{ + "Bearer": map[string]interface{}{ + "type": "apiKey", + "name": "Authorization", + "in": "header", + "description": "JWT token with Bearer prefix", + }, + }, + } + + data, _ := json.MarshalIndent(swagger, "", " ") + return data +} + +// getPaths returns the API paths +func getPaths() map[string]interface{} { + return map[string]interface{}{ + "/login": map[string]interface{}{ + "post": map[string]interface{}{ + "summary": "Login with WeChat", + "description": "Login with WeChat authorization code", + "tags": []string{"auth"}, + "parameters": []map[string]interface{}{ + { + "name": "body", + "in": "body", + "description": "Login request", + "required": true, + "schema": map[string]interface{}{ + "$ref": "#/definitions/LoginRequest", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Successful login", + "schema": map[string]interface{}{ + "$ref": "#/definitions/LoginResponse", + }, + }, + "400": map[string]interface{}{ + "description": "Invalid request", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/verify-token": map[string]interface{}{ + "post": map[string]interface{}{ + "summary": "Verify token", + "description": "Verify JWT token", + "tags": []string{"auth"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Token is valid", + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "valid": map[string]interface{}{ + "type": "boolean", + "description": "Whether the token is valid", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/otp": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "List OTPs", + "description": "List all OTPs for the authenticated user", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "List of OTPs", + "schema": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "$ref": "#/definitions/OTP", + }, + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + "post": map[string]interface{}{ + "summary": "Create OTP", + "description": "Create a new OTP", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "body", + "in": "body", + "description": "OTP creation request", + "required": true, + "schema": map[string]interface{}{ + "$ref": "#/definitions/CreateOTPRequest", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP created", + "schema": map[string]interface{}{ + "$ref": "#/definitions/OTP", + }, + }, + "400": map[string]interface{}{ + "description": "Invalid request", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/otp/{id}": map[string]interface{}{ + "put": map[string]interface{}{ + "summary": "Update OTP", + "description": "Update an existing OTP", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "description": "OTP ID", + "required": true, + "type": "string", + }, + { + "name": "body", + "in": "body", + "description": "OTP update request", + "required": true, + "schema": map[string]interface{}{ + "$ref": "#/definitions/UpdateOTPRequest", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP updated", + "schema": map[string]interface{}{ + "$ref": "#/definitions/OTP", + }, + }, + "400": map[string]interface{}{ + "description": "Invalid request", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "404": map[string]interface{}{ + "description": "OTP not found", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + "delete": map[string]interface{}{ + "summary": "Delete OTP", + "description": "Delete an existing OTP", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "description": "OTP ID", + "required": true, + "type": "string", + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP deleted", + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "Success message", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "404": map[string]interface{}{ + "description": "OTP not found", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/otp/{id}/code": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Get OTP code", + "description": "Get the current OTP code", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "description": "OTP ID", + "required": true, + "type": "string", + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP code", + "schema": map[string]interface{}{ + "$ref": "#/definitions/OTPCodeResponse", + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "404": map[string]interface{}{ + "description": "OTP not found", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/otp/{id}/verify": map[string]interface{}{ + "post": map[string]interface{}{ + "summary": "Verify OTP code", + "description": "Verify an OTP code", + "tags": []string{"otp"}, + "security": []map[string]interface{}{ + { + "Bearer": []string{}, + }, + }, + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "description": "OTP ID", + "required": true, + "type": "string", + }, + { + "name": "body", + "in": "body", + "description": "OTP verification request", + "required": true, + "schema": map[string]interface{}{ + "$ref": "#/definitions/VerifyOTPRequest", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "OTP verification result", + "schema": map[string]interface{}{ + "$ref": "#/definitions/VerifyOTPResponse", + }, + }, + "400": map[string]interface{}{ + "description": "Invalid request", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "401": map[string]interface{}{ + "description": "Unauthorized", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "404": map[string]interface{}{ + "description": "OTP not found", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + "500": map[string]interface{}{ + "description": "Internal server error", + "schema": map[string]interface{}{ + "$ref": "#/definitions/ErrorResponse", + }, + }, + }, + }, + }, + "/health": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Health check", + "description": "Check if the API is healthy", + "tags": []string{"system"}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "API is healthy", + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "status": map[string]interface{}{ + "type": "string", + "description": "Health status", + }, + "time": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "Current time", + }, + }, + }, + }, + }, + }, + }, + "/metrics": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Metrics", + "description": "Get application metrics", + "tags": []string{"system"}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Application metrics", + }, + }, + }, + }, + "/swagger.json": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Swagger JSON", + "description": "Get Swagger JSON", + "tags": []string{"system"}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Swagger JSON", + }, + }, + }, + }, + } +} + +// Handler returns an HTTP handler for Swagger JSON +func Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(SwaggerJSON()) + } +} diff --git a/go.mod b/go.mod index 57bb068..a55608f 100644 --- a/go.mod +++ b/go.mod @@ -1,31 +1,36 @@ module otpm -go 1.21.1 +go 1.23.0 + +toolchain go1.23.9 require ( - github.com/go-sql-driver/mysql v1.8.1 + github.com/go-playground/validator/v10 v10.26.0 + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.6.0 github.com/jmoiron/sqlx v1.4.0 github.com/julienschmidt/httprouter v1.3.0 - github.com/lib/pq v1.10.9 - github.com/spf13/cobra v1.8.1 + github.com/prometheus/client_golang v1.22.0 github.com/spf13/viper v1.19.0 - modernc.org/sqlite v1.32.0 + golang.org/x/crypto v0.38.0 ) require ( - filippo.io/edwards25519 v1.1.0 // indirect - github.com/dustin/go-humanize v1.0.1 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/gabriel-vasile/mimetype v1.4.8 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -36,14 +41,10 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/net v0.34.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect + google.golang.org/protobuf v1.36.5 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect - modernc.org/libc v1.55.3 // indirect - modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.8.0 // indirect - modernc.org/strutil v1.2.0 // indirect - modernc.org/token v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index 5dc30bc..c30de80 100644 --- a/go.sum +++ b/go.sum @@ -1,60 +1,76 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= +github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= +github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= -github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= -github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= -github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -65,8 +81,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= -github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= @@ -79,56 +93,32 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w= golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= +google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= -modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= -modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y= -modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s= -modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= -modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= -modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw= -modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= -modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= -modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= -modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= -modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w= -modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= -modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= -modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= -modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= -modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= -modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= -modernc.org/sqlite v1.32.0 h1:6BM4uGza7bWypsw4fdLRsLxut6bHe4c58VeqjRgST8s= -modernc.org/sqlite v1.32.0/go.mod h1:UqoylwmTb9F+IqXERT8bW9zzOWN8qwAIcLdzeBZs4hA= -modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= -modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= -modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= -modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/handlers/auth_handler.go b/handlers/auth_handler.go new file mode 100644 index 0000000..f79979f --- /dev/null +++ b/handlers/auth_handler.go @@ -0,0 +1,156 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "time" + + "otpm/api" + "otpm/services" + + "github.com/golang-jwt/jwt" + "github.com/julienschmidt/httprouter" +) + +// AuthHandler handles authentication related requests +type AuthHandler struct { + authService *services.AuthService +} + +// NewAuthHandler creates a new AuthHandler +func NewAuthHandler(authService *services.AuthService) *AuthHandler { + return &AuthHandler{ + authService: authService, + } +} + +// LoginRequest represents a login request +type LoginRequest struct { + Code string `json:"code" validate:"required,min=32,max=128"` +} + +// LoginResponse represents a login response +type LoginResponse struct { + Token string `json:"token"` + OpenID string `json:"openid"` +} + +// TokenRequest represents a token verification request +type TokenRequest struct { + Token string `validate:"required,min=32"` +} + +// Login handles WeChat login +func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + start := time.Now() + + // Limit request body size to prevent DOS + r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request + + // Parse and validate request + var req LoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + fmt.Sprintf("Invalid request body: %v", err)) + log.Printf("Login request parse error: %v", err) + return + } + + // Validate using validator + if err := api.Validate.Struct(req); err != nil { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + fmt.Sprintf("Invalid request parameters: %v", err)) + log.Printf("Login request validation failed: %v", err) + return + } + + // Login with WeChat code + token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code) + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + log.Printf("Login failed for code %s: %v", req.Code, err) + return + } + + // Log successful login + log.Printf("Login successful for code %s (took %v)", + req.Code, time.Since(start)) + + // Return token + api.NewResponseWriter(w).WriteSuccess(LoginResponse{ + Token: token, + }) +} + +// VerifyToken handles token verification +func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + start := time.Now() + + // Get token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Authorization header is required") + log.Printf("Token verification failed: missing Authorization header") + return + } + + // Validate token format + if len(authHeader) < 7 || authHeader[:7] != "Bearer " { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Invalid token format. Expected 'Bearer '") + log.Printf("Token verification failed: invalid token format") + return + } + + token := authHeader[7:] + + // Validate token using validator + tokenReq := TokenRequest{Token: token} + if err := api.Validate.Struct(tokenReq); err != nil { + api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, + "Invalid token format") + log.Printf("Token verification failed: %v", err) + return + } + + // Validate token + claims, err := h.authService.ValidateToken(token) + if err != nil { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + log.Printf("Token verification failed for token %s: %v", + maskToken(token), err) // Mask token in logs + return + } + + // Log successful verification + userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string) + if !ok { + log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start)) + } else { + log.Printf("Token verified for user %s (took %v)", userID, time.Since(start)) + } + + // Token is valid + api.NewResponseWriter(w).WriteSuccess(map[string]bool{ + "valid": true, + }) +} + +// maskToken masks sensitive parts of token for logging +func maskToken(token string) string { + if len(token) < 8 { + return "****" + } + return token[:4] + "****" + token[len(token)-4:] +} + +// Routes returns all routes for the auth handler +func (h *AuthHandler) Routes() map[string]httprouter.Handle { + return map[string]httprouter.Handle{ + "/api/login": h.Login, + "/api/verify-token": h.VerifyToken, + } +} diff --git a/handlers/handler.go b/handlers/handler.go deleted file mode 100644 index db15938..0000000 --- a/handlers/handler.go +++ /dev/null @@ -1,9 +0,0 @@ -package handlers - -import ( - "github.com/jmoiron/sqlx" -) - -type Handler struct { - DB *sqlx.DB -} diff --git a/handlers/login.go b/handlers/login.go deleted file mode 100644 index 30812d7..0000000 --- a/handlers/login.go +++ /dev/null @@ -1,93 +0,0 @@ -package handlers - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - - "github.com/spf13/viper" -) - -var code2Session = "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code" - -type LoginRequest struct { - Code string `json:"code"` -} - -// 封装code2session接口返回数据 -type LoginResponse struct { - OpenId string `json:"openid"` - SessionKey string `json:"session_key"` - UnionId string `json:"unionid"` - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` -} - -func getLoginResponse(code string) (*LoginResponse, error) { - appid := viper.GetString("wechat.appid") - secret := viper.GetString("wechat.secret") - url := fmt.Sprintf(code2Session, appid, secret, code) - resp, err := http.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var loginResponse LoginResponse - if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil { - return nil, err - } - - if loginResponse.ErrCode != 0 { - return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg) - } - return &loginResponse, nil -} - -func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { - - var req LoginRequest - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - if err := json.Unmarshal(body, &req); err != nil { - http.Error(w, "Failed to parse request body", http.StatusBadRequest) - return - } - - loginResponse, err := getLoginResponse(req.Code) - if err != nil { - http.Error(w, "Failed to get session key", http.StatusInternalServerError) - return - } - - // // 插入或更新用户的openid和sessionid - // query := ` - // INSERT INTO users (openid, sessionid) - // VALUES ($1, $2) - // ON CONFLICT (openid) DO UPDATE SET sessionid = $2 - // RETURNING id; - // ` - - // var ID int - // if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil { - // http.Error(w, "Failed to log in user", http.StatusInternalServerError) - // return - // } - - data := map[string]interface{}{ - "openid": loginResponse.OpenId, - "session_key": loginResponse.SessionKey, - } - respData, err := json.Marshal(data) - if err != nil { - http.Error(w, "Failed to marshal response data", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(respData)) -} diff --git a/handlers/otp.go b/handlers/otp.go deleted file mode 100644 index 61230ec..0000000 --- a/handlers/otp.go +++ /dev/null @@ -1,61 +0,0 @@ -package handlers - -import ( - "encoding/json" - "net/http" -) - -type OtpRequest struct { - OpenID string `json:"openid"` - Num int `json:"num"` - Token *[]OTP `json:"token"` -} -type OTP struct { - Issuer string `json:"issuer"` - Remark string `json:"remark"` - Secret string `json:"secret"` -} - -func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) { - var req OtpRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request payload", http.StatusBadRequest) - return - } - - num := len(*req.Token) - - // 插入或更新 OTP 记录 - query := ` - INSERT INTO otp (openid, num, token) - VALUES ($1, $2, $3) - ` - - _, err := h.DB.Exec(query, req.OpenID, req.Token, num) - if err != nil { - http.Error(w, "Failed to update or create OTP", http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusOK) - w.Write([]byte("OTP updated or created successfully")) -} - -func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) { - openid := r.URL.Query().Get("openid") - if openid == "" { - http.Error(w, "未登录", http.StatusBadRequest) - return - } - - var otp OtpRequest - - err := h.DB.Get(&otp, "SELECT token, num, openid FROM otp WHERE openid=$1", openid) - if err != nil { - http.Error(w, "OTP not found", http.StatusNotFound) - return - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(otp) -} diff --git a/handlers/otp_handler.go b/handlers/otp_handler.go new file mode 100644 index 0000000..c1d99bb --- /dev/null +++ b/handlers/otp_handler.go @@ -0,0 +1,114 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "github.com/julienschmidt/httprouter" + + "otpm/api" + "otpm/middleware" + "otpm/models" + "otpm/services" +) + +// OTPHandler handles OTP-related HTTP requests +type OTPHandler struct { + otpService *services.OTPService +} + +// NewOTPHandler creates a new OTPHandler +func NewOTPHandler(otpService *services.OTPService) *OTPHandler { + return &OTPHandler{ + otpService: otpService, + } +} + +// Routes returns the routes for OTP operations +func (h *OTPHandler) Routes() map[string]httprouter.Handle { + return map[string]httprouter.Handle{ + "POST /api/otp": h.CreateOTP, + "GET /api/otps": h.ListOTPs, + "GET /api/otp/:id": h.GetOTP, + } +} + +// CreateOTP handles the creation of a new OTP +func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + // Get user ID from context + userID, ok := r.Context().Value(middleware.UserIDKey).(string) + if !ok { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + return + } + + // Parse request body + var params models.OTPParams + if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { + api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body")) + return + } + + // Validate request + if err := api.Validate.Struct(params); err != nil { + api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error())) + return + } + + // Create OTP + otp, err := h.otpService.CreateOTP(r.Context(), userID, params) + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + return + } + + // Return response + api.NewResponseWriter(w).WriteSuccess(otp) +} + +// ListOTPs handles listing all OTPs for a user +func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + // Get user ID from context + userID, ok := r.Context().Value(middleware.UserIDKey).(string) + if !ok { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + return + } + + // Get OTPs + otps, err := h.otpService.ListOTPs(r.Context(), userID) + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + return + } + + // Return response + api.NewResponseWriter(w).WriteSuccess(otps) +} + +// GetOTP handles getting a specific OTP +func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + // Get user ID from context + userID, ok := r.Context().Value(middleware.UserIDKey).(string) + if !ok { + api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) + return + } + + // Get OTP ID from URL + otpID := ps.ByName("id") + if otpID == "" { + api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID")) + return + } + + // Get OTP + otp, err := h.otpService.GetOTP(r.Context(), otpID, userID) + if err != nil { + api.NewResponseWriter(w).WriteError(api.InternalError(err)) + return + } + + // Return response + api.NewResponseWriter(w).WriteSuccess(otp) +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..d7d76e4 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,204 @@ +package logger + +import ( + "context" + "fmt" + "io" + "os" + "runtime" + "strings" + "time" + + "github.com/google/uuid" +) + +// Level represents a log level +type Level int + +const ( + // DEBUG level + DEBUG Level = iota + // INFO level + INFO + // WARN level + WARN + // ERROR level + ERROR + // FATAL level + FATAL +) + +// String returns the string representation of the log level +func (l Level) String() string { + switch l { + case DEBUG: + return "DEBUG" + case INFO: + return "INFO" + case WARN: + return "WARN" + case ERROR: + return "ERROR" + case FATAL: + return "FATAL" + default: + return "UNKNOWN" + } +} + +// Logger represents a logger +type Logger struct { + level Level + output io.Writer +} + +// contextKey is a type for context keys +type contextKey string + +// requestIDKey is the key for request ID in context +const requestIDKey = contextKey("request_id") + +// New creates a new logger +func New(level Level, output io.Writer) *Logger { + if output == nil { + output = os.Stdout + } + return &Logger{ + level: level, + output: output, + } +} + +// WithLevel creates a new logger with the specified level +func (l *Logger) WithLevel(level Level) *Logger { + return &Logger{ + level: level, + output: l.output, + } +} + +// WithOutput creates a new logger with the specified output +func (l *Logger) WithOutput(output io.Writer) *Logger { + return &Logger{ + level: l.level, + output: output, + } +} + +// log logs a message with the specified level +func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) { + if level < l.level { + return + } + + // Get request ID from context + requestID := getRequestID(ctx) + + // Get caller information + _, file, line, ok := runtime.Caller(2) + if !ok { + file = "unknown" + line = 0 + } + // Extract just the filename + if idx := strings.LastIndex(file, "/"); idx >= 0 { + file = file[idx+1:] + } + + // Format message + message := fmt.Sprintf(format, args...) + + // Format log entry + timestamp := time.Now().Format(time.RFC3339) + logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n", + timestamp, level.String(), file, line, requestID, message) + + // Write log entry + _, _ = l.output.Write([]byte(logEntry)) + + // Exit if fatal + if level == FATAL { + os.Exit(1) + } +} + +// Debug logs a debug message +func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, DEBUG, format, args...) +} + +// Info logs an info message +func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, INFO, format, args...) +} + +// Warn logs a warning message +func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, WARN, format, args...) +} + +// Error logs an error message +func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, ERROR, format, args...) +} + +// Fatal logs a fatal message and exits +func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) { + l.log(ctx, FATAL, format, args...) +} + +// WithRequestID adds a request ID to the context +func WithRequestID(ctx context.Context) context.Context { + requestID := uuid.New().String() + return context.WithValue(ctx, requestIDKey, requestID) +} + +// GetRequestID gets the request ID from the context +func GetRequestID(ctx context.Context) string { + return getRequestID(ctx) +} + +// getRequestID gets the request ID from the context +func getRequestID(ctx context.Context) string { + if ctx == nil { + return "-" + } + requestID, ok := ctx.Value(requestIDKey).(string) + if !ok { + return "-" + } + return requestID +} + +// Default logger +var defaultLogger = New(INFO, os.Stdout) + +// SetDefaultLogger sets the default logger +func SetDefaultLogger(logger *Logger) { + defaultLogger = logger +} + +// Debug logs a debug message using the default logger +func Debug(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Debug(ctx, format, args...) +} + +// Info logs an info message using the default logger +func Info(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Info(ctx, format, args...) +} + +// Warn logs a warning message using the default logger +func Warn(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Warn(ctx, format, args...) +} + +// Error logs an error message using the default logger +func Error(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Error(ctx, format, args...) +} + +// Fatal logs a fatal message and exits using the default logger +func Fatal(ctx context.Context, format string, args ...interface{}) { + defaultLogger.Fatal(ctx, format, args...) +} diff --git a/metrics/metrics.go b/metrics/metrics.go new file mode 100644 index 0000000..8e1d18a --- /dev/null +++ b/metrics/metrics.go @@ -0,0 +1,193 @@ +package metrics + +import ( + "fmt" + "net/http" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +var ( + // Default metrics + requestDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_request_duration_seconds", + Help: "Duration of HTTP requests in seconds", + Buckets: prometheus.DefBuckets, + }, + []string{"method", "path", "status"}, + ) + + requestTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "http_requests_total", + Help: "Total number of HTTP requests", + }, + []string{"method", "path", "status"}, + ) + + otpGenerationTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "otp_generation_total", + Help: "Total number of OTP generations", + }, + []string{"user_id", "otp_id"}, + ) + + otpVerificationTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "otp_verification_total", + Help: "Total number of OTP verifications", + }, + []string{"user_id", "otp_id", "success"}, + ) + + activeUsers = prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "active_users", + Help: "Number of active users", + }, + ) + + cacheHits = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cache_hits_total", + Help: "Total number of cache hits", + }, + []string{"cache"}, + ) + + cacheMisses = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cache_misses_total", + Help: "Total number of cache misses", + }, + []string{"cache"}, + ) +) + +func init() { + // Register metrics with prometheus + prometheus.MustRegister( + requestDuration, + requestTotal, + otpGenerationTotal, + otpVerificationTotal, + activeUsers, + cacheHits, + cacheMisses, + ) +} + +// MetricsService provides metrics functionality +type MetricsService struct { + activeUsersMutex sync.RWMutex + activeUserIDs map[string]bool +} + +// NewMetricsService creates a new MetricsService +func NewMetricsService() *MetricsService { + return &MetricsService{ + activeUserIDs: make(map[string]bool), + } +} + +// Handler returns an HTTP handler for metrics +func (s *MetricsService) Handler() http.Handler { + return promhttp.Handler() +} + +// RecordRequest records metrics for an HTTP request +func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) { + labels := prometheus.Labels{ + "method": method, + "path": path, + "status": fmt.Sprintf("%d", status), + } + + requestDuration.With(labels).Observe(duration.Seconds()) + requestTotal.With(labels).Inc() +} + +// RecordOTPGeneration records metrics for OTP generation +func (s *MetricsService) RecordOTPGeneration(userID, otpID string) { + otpGenerationTotal.With(prometheus.Labels{ + "user_id": userID, + "otp_id": otpID, + }).Inc() +} + +// RecordOTPVerification records metrics for OTP verification +func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) { + otpVerificationTotal.With(prometheus.Labels{ + "user_id": userID, + "otp_id": otpID, + "success": fmt.Sprintf("%t", success), + }).Inc() +} + +// RecordUserActivity records user activity +func (s *MetricsService) RecordUserActivity(userID string) { + s.activeUsersMutex.Lock() + defer s.activeUsersMutex.Unlock() + + if !s.activeUserIDs[userID] { + s.activeUserIDs[userID] = true + activeUsers.Inc() + } +} + +// RecordUserInactivity records user inactivity +func (s *MetricsService) RecordUserInactivity(userID string) { + s.activeUsersMutex.Lock() + defer s.activeUsersMutex.Unlock() + + if s.activeUserIDs[userID] { + delete(s.activeUserIDs, userID) + activeUsers.Dec() + } +} + +// RecordCacheHit records a cache hit +func (s *MetricsService) RecordCacheHit(cache string) { + cacheHits.With(prometheus.Labels{ + "cache": cache, + }).Inc() +} + +// RecordCacheMiss records a cache miss +func (s *MetricsService) RecordCacheMiss(cache string) { + cacheMisses.With(prometheus.Labels{ + "cache": cache, + }).Inc() +} + +// Middleware creates a middleware that records request metrics +func (s *MetricsService) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Create response writer that captures status code + rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} + + // Call next handler + next.ServeHTTP(rw, r) + + // Record metrics + s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start)) + }) +} + +// responseWriter wraps http.ResponseWriter to capture status code +type responseWriter struct { + http.ResponseWriter + status int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.status = code + rw.ResponseWriter.WriteHeader(code) +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..6afce13 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,353 @@ +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "log" + "math/rand" + "net/http" + "runtime/debug" + "strings" + "time" + + "github.com/golang-jwt/jwt" +) + +// Response represents a standard API response +type Response struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// ErrorResponse sends a JSON error response +func ErrorResponse(w http.ResponseWriter, code int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(Response{ + Code: code, + Message: message, + }) +} + +// SuccessResponse sends a JSON success response +func SuccessResponse(w http.ResponseWriter, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(Response{ + Code: http.StatusOK, + Message: "success", + Data: data, + }) +} + +// Logger is a middleware that logs request details with structured format +func Logger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + requestID := r.Header.Get("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + r.Header.Set("X-Request-ID", requestID) + } + + // Create a custom response writer to capture status code + rw := &responseWriter{ + ResponseWriter: w, + status: http.StatusOK, + } + + // Process request + next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID))) + + // Log structured request details + log.Printf( + "method=%s path=%s status=%d duration=%s ip=%s request_id=%s", + r.Method, + r.URL.Path, + rw.status, + time.Since(start).String(), + r.RemoteAddr, + requestID, + ) + }) +} + +// generateRequestID creates a unique request identifier +func generateRequestID() string { + return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8)) +} + +// randomString generates a random string of given length +func randomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, length) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] + } + return string(b) +} + +// Recover is a middleware that recovers from panics with detailed logging +func Recover(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + // Get request ID from context + requestID := "" + if ctx := r.Context(); ctx != nil { + if id, ok := ctx.Value("request_id").(string); ok { + requestID = id + } + } + + // Log error with stack trace and request context + log.Printf( + "panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s", + err, + requestID, + r.Method, + r.URL.Path, + r.RemoteAddr, + debug.Stack(), + ) + + // Determine error type + var message string + var status int + + switch e := err.(type) { + case error: + message = e.Error() + if isClientError(e) { + status = http.StatusBadRequest + } else { + status = http.StatusInternalServerError + } + case string: + message = e + status = http.StatusInternalServerError + default: + message = "Internal Server Error" + status = http.StatusInternalServerError + } + + ErrorResponse(w, status, message) + } + }() + next.ServeHTTP(w, r) + }) +} + +// isClientError checks if error should be treated as client error +func isClientError(err error) bool { + // Add more client error types as needed + return strings.Contains(err.Error(), "validation") || + strings.Contains(err.Error(), "invalid") || + strings.Contains(err.Error(), "missing") +} + +// CORS is a middleware that handles CORS +func CORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +} + +// Timeout is a middleware that safely handles request timeouts +func Timeout(duration time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), duration) + defer cancel() + + // Use buffered channels to prevent goroutine leaks + done := make(chan struct{}, 1) + panicChan := make(chan interface{}, 1) + + // Track request processing in goroutine + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + next.ServeHTTP(w, r.WithContext(ctx)) + done <- struct{}{} + }() + + // Wait for completion, timeout or panic + select { + case <-done: + return + case p := <-panicChan: + panic(p) // Re-throw panic to be caught by Recover middleware + case <-ctx.Done(): + // Get request context for logging + requestID := "" + if ctx := r.Context(); ctx != nil { + if id, ok := ctx.Value("request_id").(string); ok { + requestID = id + } + } + + // Log timeout details + log.Printf( + "request_timeout: request_id=%s method=%s path=%s timeout=%s", + requestID, + r.Method, + r.URL.Path, + duration.String(), + ) + + // Send timeout response + ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf( + "Request timed out after %s", duration.String(), + )) + } + }) + } +} + +// Auth is a middleware that validates JWT tokens with enhanced security +func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get request ID for logging + requestID := "" + if ctx := r.Context(); ctx != nil { + if id, ok := ctx.Value("request_id").(string); ok { + requestID = id + } + } + + // Get token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required") + return + } + + // Validate header format + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer '") + return + } + + tokenString := parts[1] + + // Parse and validate token + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Validate signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(jwtSecret), nil + }) + + if err != nil { + log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err) + ErrorResponse(w, http.StatusUnauthorized, "Invalid token") + return + } + + if !token.Valid { + log.Printf("auth_failed: request_id=%s error=invalid_token", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Invalid token") + return + } + + // Validate claims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims") + return + } + + // Check required claims + userID, ok := claims["user_id"].(string) + if !ok || userID == "" { + log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token") + return + } + + // Check token expiration + if exp, ok := claims["exp"].(float64); ok { + if time.Now().Unix() > int64(exp) { + log.Printf("auth_failed: request_id=%s error=token_expired", requestID) + ErrorResponse(w, http.StatusUnauthorized, "Token has expired") + return + } + } + + // Check required roles if specified + if len(requiredRoles) > 0 { + roles, ok := claims["roles"].([]interface{}) + if !ok { + log.Printf("auth_failed: request_id=%s error=missing_roles", requestID) + ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles") + return + } + + hasRequiredRole := false + for _, requiredRole := range requiredRoles { + for _, role := range roles { + if r, ok := role.(string); ok && r == requiredRole { + hasRequiredRole = true + break + } + } + } + + if !hasRequiredRole { + log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID) + ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions") + return + } + } + + // Add claims to context + ctx := r.Context() + ctx = context.WithValue(ctx, "user_id", userID) + ctx = context.WithValue(ctx, "claims", claims) + + log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// responseWriter is a custom response writer that captures the status code +type responseWriter struct { + http.ResponseWriter + status int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.status = code + rw.ResponseWriter.WriteHeader(code) +} + +// GetUserID gets the user ID from the request context +func GetUserID(r *http.Request) (string, error) { + userID, ok := r.Context().Value("user_id").(string) + if !ok { + return "", fmt.Errorf("user ID not found in context") + } + return userID, nil +} diff --git a/miniprogram-example/app.js b/miniprogram-example/app.js new file mode 100644 index 0000000..34eba08 --- /dev/null +++ b/miniprogram-example/app.js @@ -0,0 +1,50 @@ +// app.js +App({ + onLaunch() { + // 检查更新 + if (wx.canIUse('getUpdateManager')) { + const updateManager = wx.getUpdateManager(); + updateManager.onCheckForUpdate(function (res) { + if (res.hasUpdate) { + updateManager.onUpdateReady(function () { + wx.showModal({ + title: '更新提示', + content: '新版本已经准备好,是否重启应用?', + success: function (res) { + if (res.confirm) { + updateManager.applyUpdate(); + } + } + }); + }); + + updateManager.onUpdateFailed(function () { + wx.showModal({ + title: '更新提示', + content: '新版本下载失败,请检查网络后重试', + showCancel: false + }); + }); + } + }); + } + + // 获取系统信息 + try { + const systemInfo = wx.getSystemInfoSync(); + this.globalData.systemInfo = systemInfo; + + // 计算安全区域 + const { screenHeight, safeArea } = systemInfo; + this.globalData.safeAreaBottom = screenHeight - safeArea.bottom; + } catch (e) { + console.error('获取系统信息失败', e); + } + }, + + globalData: { + userInfo: null, + systemInfo: {}, + safeAreaBottom: 0 + } +}); \ No newline at end of file diff --git a/miniprogram-example/app.json b/miniprogram-example/app.json new file mode 100644 index 0000000..89f79ae --- /dev/null +++ b/miniprogram-example/app.json @@ -0,0 +1,23 @@ +{ + "pages": [ + "pages/login/login", + "pages/otp-list/index", + "pages/otp-add/index" + ], + "window": { + "backgroundTextStyle": "light", + "navigationBarBackgroundColor": "#fff", + "navigationBarTitleText": "OTPM", + "navigationBarTextStyle": "black", + "backgroundColor": "#F8F8F8" + }, + "permission": { + "scope.camera": { + "desc": "需要使用相机扫描二维码" + } + }, + "usingComponents": {}, + "style": "v2", + "sitemapLocation": "sitemap.json", + "lazyCodeLoading": "requiredComponents" +} \ No newline at end of file diff --git a/miniprogram-example/app.wxss b/miniprogram-example/app.wxss new file mode 100644 index 0000000..fb73b4a --- /dev/null +++ b/miniprogram-example/app.wxss @@ -0,0 +1,238 @@ +/**app.wxss**/ +page { + --primary-color: #1890ff; + --danger-color: #ff4d4f; + --success-color: #52c41a; + --warning-color: #faad14; + --text-color: #333333; + --text-color-secondary: #666666; + --text-color-light: #999999; + --border-color: #e8e8e8; + --background-color: #f8f8f8; + --border-radius: 8rpx; + --safe-area-bottom: env(safe-area-inset-bottom); + + font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica, + Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei', + sans-serif; + font-size: 28rpx; + line-height: 1.5; + color: var(--text-color); + background-color: var(--background-color); +} + +/* 清除默认样式 */ +button { + padding: 0; + margin: 0; + background: none; + border: none; + text-align: left; + line-height: inherit; + overflow: visible; +} + +button::after { + border: none; +} + +/* 通用样式类 */ +.container { + min-height: 100vh; + box-sizing: border-box; +} + +.safe-area-bottom { + padding-bottom: var(--safe-area-bottom); +} + +.flex-center { + display: flex; + align-items: center; + justify-content: center; +} + +.flex-between { + display: flex; + align-items: center; + justify-content: space-between; +} + +.flex-column { + display: flex; + flex-direction: column; +} + +.text-primary { + color: var(--primary-color); +} + +.text-danger { + color: var(--danger-color); +} + +.text-success { + color: var(--success-color); +} + +.text-warning { + color: var(--warning-color); +} + +.text-secondary { + color: var(--text-color-secondary); +} + +.text-light { + color: var(--text-color-light); +} + +.text-center { + text-align: center; +} + +.text-left { + text-align: left; +} + +.text-right { + text-align: right; +} + +.text-bold { + font-weight: bold; +} + +.text-ellipsis { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.bg-white { + background-color: #ffffff; +} + +.bg-primary { + background-color: var(--primary-color); +} + +.bg-danger { + background-color: var(--danger-color); +} + +.bg-success { + background-color: var(--success-color); +} + +.bg-warning { + background-color: var(--warning-color); +} + +.shadow { + box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); +} + +.rounded { + border-radius: var(--border-radius); +} + +.border { + border: 2rpx solid var(--border-color); +} + +.border-top { + border-top: 2rpx solid var(--border-color); +} + +.border-bottom { + border-bottom: 2rpx solid var(--border-color); +} + +/* 动画类 */ +.fade-in { + animation: fadeIn 0.3s ease-in-out; +} + +.fade-out { + animation: fadeOut 0.3s ease-in-out; +} + +@keyframes fadeIn { + from { + opacity: 0; + } + to { + opacity: 1; + } +} + +@keyframes fadeOut { + from { + opacity: 1; + } + to { + opacity: 0; + } +} + +/* 间距类 */ +.m-0 { margin: 0; } +.m-1 { margin: 10rpx; } +.m-2 { margin: 20rpx; } +.m-3 { margin: 30rpx; } +.m-4 { margin: 40rpx; } + +.mt-0 { margin-top: 0; } +.mt-1 { margin-top: 10rpx; } +.mt-2 { margin-top: 20rpx; } +.mt-3 { margin-top: 30rpx; } +.mt-4 { margin-top: 40rpx; } + +.mb-0 { margin-bottom: 0; } +.mb-1 { margin-bottom: 10rpx; } +.mb-2 { margin-bottom: 20rpx; } +.mb-3 { margin-bottom: 30rpx; } +.mb-4 { margin-bottom: 40rpx; } + +.ml-0 { margin-left: 0; } +.ml-1 { margin-left: 10rpx; } +.ml-2 { margin-left: 20rpx; } +.ml-3 { margin-left: 30rpx; } +.ml-4 { margin-left: 40rpx; } + +.mr-0 { margin-right: 0; } +.mr-1 { margin-right: 10rpx; } +.mr-2 { margin-right: 20rpx; } +.mr-3 { margin-right: 30rpx; } +.mr-4 { margin-right: 40rpx; } + +.p-0 { padding: 0; } +.p-1 { padding: 10rpx; } +.p-2 { padding: 20rpx; } +.p-3 { padding: 30rpx; } +.p-4 { padding: 40rpx; } + +.pt-0 { padding-top: 0; } +.pt-1 { padding-top: 10rpx; } +.pt-2 { padding-top: 20rpx; } +.pt-3 { padding-top: 30rpx; } +.pt-4 { padding-top: 40rpx; } + +.pb-0 { padding-bottom: 0; } +.pb-1 { padding-bottom: 10rpx; } +.pb-2 { padding-bottom: 20rpx; } +.pb-3 { padding-bottom: 30rpx; } +.pb-4 { padding-bottom: 40rpx; } + +.pl-0 { padding-left: 0; } +.pl-1 { padding-left: 10rpx; } +.pl-2 { padding-left: 20rpx; } +.pl-3 { padding-left: 30rpx; } +.pl-4 { padding-left: 40rpx; } + +.pr-0 { padding-right: 0; } +.pr-1 { padding-right: 10rpx; } +.pr-2 { padding-right: 20rpx; } +.pr-3 { padding-right: 30rpx; } +.pr-4 { padding-right: 40rpx; } \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.js b/miniprogram-example/pages/login/login.js new file mode 100644 index 0000000..e4ad173 --- /dev/null +++ b/miniprogram-example/pages/login/login.js @@ -0,0 +1,48 @@ +// login.js +import { wxLogin } from '../../services/auth'; + +Page({ + data: { + loading: false + }, + + onLoad() { + // 页面加载时检查是否已经登录 + const token = wx.getStorageSync('token'); + if (token) { + this.redirectToHome(); + } + }, + + // 处理登录按钮点击 + handleLogin() { + if (this.data.loading) return; + + this.setData({ loading: true }); + + wxLogin() + .then(() => { + wx.showToast({ + title: '登录成功', + icon: 'success' + }); + this.redirectToHome(); + }) + .catch(err => { + wx.showToast({ + title: err.message || '登录失败', + icon: 'none' + }); + }) + .finally(() => { + this.setData({ loading: false }); + }); + }, + + // 跳转到首页 + redirectToHome() { + wx.reLaunch({ + url: '/pages/otp-list/index' + }); + } +}); \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.json b/miniprogram-example/pages/login/login.json new file mode 100644 index 0000000..8835af0 --- /dev/null +++ b/miniprogram-example/pages/login/login.json @@ -0,0 +1,3 @@ +{ + "usingComponents": {} +} \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.wxml b/miniprogram-example/pages/login/login.wxml new file mode 100644 index 0000000..d73a711 --- /dev/null +++ b/miniprogram-example/pages/login/login.wxml @@ -0,0 +1,30 @@ + + + + + OTPM 小程序 + + + + \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.wxss b/miniprogram-example/pages/login/login.wxss new file mode 100644 index 0000000..26cece6 --- /dev/null +++ b/miniprogram-example/pages/login/login.wxss @@ -0,0 +1,97 @@ +/* login.wxss */ +.container { + display: flex; + flex-direction: column; + align-items: center; + justify-content: space-between; + height: 100vh; + padding: 60rpx 40rpx; + box-sizing: border-box; + background-color: #f8f8f8; +} + +.logo-container { + display: flex; + flex-direction: column; + align-items: center; + margin-top: 80rpx; +} + +.logo { + width: 180rpx; + height: 180rpx; + margin-bottom: 20rpx; +} + +.app-name { + font-size: 36rpx; + font-weight: bold; + color: #333; +} + +.login-container { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; + margin-bottom: 100rpx; +} + +.login-title { + font-size: 48rpx; + font-weight: bold; + color: #333; + margin-bottom: 20rpx; +} + +.login-subtitle { + font-size: 28rpx; + color: #666; + margin-bottom: 80rpx; +} + +.login-button { + width: 80%; + height: 88rpx; + border-radius: 44rpx; + font-size: 32rpx; + display: flex; + align-items: center; + justify-content: center; +} + +.login-button.loading { + background-color: #8cc4ff; +} + +.loading-container { + display: flex; + align-items: center; + justify-content: center; +} + +.loading-icon { + width: 36rpx; + height: 36rpx; + margin-right: 10rpx; + border: 4rpx solid #ffffff; + border-radius: 50%; + border-top-color: transparent; + animation: spin 1s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +.privacy-policy { + margin-top: 40rpx; + font-size: 24rpx; + color: #999; +} + +.policy-link { + color: #1890ff; + display: inline; +} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.js b/miniprogram-example/pages/otp-add/index.js new file mode 100644 index 0000000..e89c623 --- /dev/null +++ b/miniprogram-example/pages/otp-add/index.js @@ -0,0 +1,169 @@ +// otp-add/index.js +import { createOTP } from '../../services/otp'; + +Page({ + data: { + form: { + name: '', + issuer: '', + secret: '', + algorithm: 'SHA1', + digits: 6, + period: 30 + }, + algorithms: ['SHA1', 'SHA256', 'SHA512'], + digitOptions: [6, 8], + periodOptions: [30, 60], + submitting: false, + scanMode: false + }, + + // 处理输入变化 + handleInputChange(e) { + const { field } = e.currentTarget.dataset; + const { value } = e.detail; + + this.setData({ + [`form.${field}`]: value + }); + }, + + // 处理选择器变化 + handlePickerChange(e) { + const { field } = e.currentTarget.dataset; + const { value } = e.detail; + + const options = this.data[`${field}Options`] || this.data[field]; + const selectedValue = options[value]; + + this.setData({ + [`form.${field}`]: selectedValue + }); + }, + + // 扫描二维码 + handleScanQRCode() { + this.setData({ scanMode: true }); + + wx.scanCode({ + scanType: ['qrCode'], + success: (res) => { + try { + // 解析otpauth://协议的URL + const url = res.result; + if (url.startsWith('otpauth://totp/')) { + const parsedUrl = new URL(url); + const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠 + + // 解析路径中的issuer和name + let issuer = ''; + let name = path; + + if (path.includes(':')) { + const parts = path.split(':'); + issuer = parts[0]; + name = parts[1]; + } + + // 从查询参数中获取其他信息 + const secret = parsedUrl.searchParams.get('secret') || ''; + const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1'; + const digits = parseInt(parsedUrl.searchParams.get('digits') || '6'); + const period = parseInt(parsedUrl.searchParams.get('period') || '30'); + + // 如果查询参数中有issuer,优先使用 + if (parsedUrl.searchParams.get('issuer')) { + issuer = parsedUrl.searchParams.get('issuer'); + } + + this.setData({ + form: { + name, + issuer, + secret, + algorithm, + digits, + period + } + }); + + wx.showToast({ + title: '二维码解析成功', + icon: 'success' + }); + } else { + wx.showToast({ + title: '不支持的二维码格式', + icon: 'none' + }); + } + } catch (err) { + wx.showToast({ + title: '二维码解析失败', + icon: 'none' + }); + } + }, + fail: () => { + wx.showToast({ + title: '扫描取消', + icon: 'none' + }); + }, + complete: () => { + this.setData({ scanMode: false }); + } + }); + }, + + // 提交表单 + handleSubmit() { + const { form } = this.data; + + // 表单验证 + if (!form.name) { + wx.showToast({ + title: '请输入名称', + icon: 'none' + }); + return; + } + + if (!form.secret) { + wx.showToast({ + title: '请输入密钥', + icon: 'none' + }); + return; + } + + this.setData({ submitting: true }); + + createOTP(form) + .then(() => { + wx.showToast({ + title: '添加成功', + icon: 'success' + }); + + // 返回上一页 + setTimeout(() => { + wx.navigateBack(); + }, 1500); + }) + .catch(err => { + wx.showToast({ + title: err.message || '添加失败', + icon: 'none' + }); + }) + .finally(() => { + this.setData({ submitting: false }); + }); + }, + + // 取消 + handleCancel() { + wx.navigateBack(); + } +}); \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.json b/miniprogram-example/pages/otp-add/index.json new file mode 100644 index 0000000..8835af0 --- /dev/null +++ b/miniprogram-example/pages/otp-add/index.json @@ -0,0 +1,3 @@ +{ + "usingComponents": {} +} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.wxml b/miniprogram-example/pages/otp-add/index.wxml new file mode 100644 index 0000000..e5a2b01 --- /dev/null +++ b/miniprogram-example/pages/otp-add/index.wxml @@ -0,0 +1,119 @@ + + + + 添加OTP + + + + + 名称 * + + + + + 发行方 + + + + + 密钥 * + + + + 🔍 + + + + + + + + + 算法 + + + {{form.algorithm}} + + + + + + + + 位数 + + + {{form.digits}} + + + + + + + 周期(秒) + + + {{form.period}} + + + + + + + + + + + + + \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.wxss b/miniprogram-example/pages/otp-add/index.wxss new file mode 100644 index 0000000..3906a97 --- /dev/null +++ b/miniprogram-example/pages/otp-add/index.wxss @@ -0,0 +1,176 @@ +/* otp-add/index.wxss */ +.container { + min-height: 100vh; + background-color: #f8f8f8; + padding: 0 0 40rpx 0; +} + +.header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 40rpx 32rpx; + background-color: #ffffff; + position: sticky; + top: 0; + z-index: 100; + box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05); +} + +.title { + font-size: 36rpx; + font-weight: bold; + color: #333333; +} + +.form-container { + background-color: #ffffff; + padding: 32rpx; + margin: 32rpx; + border-radius: 16rpx; + box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); +} + +.form-group { + margin-bottom: 32rpx; +} + +.form-row { + display: flex; + justify-content: space-between; +} + +.form-group.half { + width: 48%; +} + +.form-label { + display: block; + font-size: 28rpx; + color: #666666; + margin-bottom: 12rpx; +} + +.required { + color: #ff4d4f; +} + +.form-input { + width: 100%; + height: 80rpx; + background-color: #f5f5f5; + border-radius: 8rpx; + padding: 0 24rpx; + font-size: 28rpx; + color: #333333; + box-sizing: border-box; +} + +.secret-input-container { + position: relative; +} + +.scan-button { + position: absolute; + right: 20rpx; + top: 50%; + transform: translateY(-50%); + width: 60rpx; + height: 60rpx; + display: flex; + align-items: center; + justify-content: center; +} + +.scan-icon { + font-size: 40rpx; + color: #1890ff; +} + +.scanning-indicator { + position: absolute; + right: 20rpx; + top: 50%; + transform: translateY(-50%); + width: 40rpx; + height: 40rpx; +} + +.scanning-spinner { + width: 40rpx; + height: 40rpx; + border: 4rpx solid #f3f3f3; + border-top: 4rpx solid #1890ff; + border-radius: 50%; + animation: spin 1s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +.picker-view { + width: 100%; + height: 80rpx; + background-color: #f5f5f5; + border-radius: 8rpx; + padding: 0 24rpx; + font-size: 28rpx; + color: #333333; + display: flex; + align-items: center; + justify-content: space-between; + box-sizing: border-box; +} + +.picker-arrow { + font-size: 24rpx; + color: #999999; +} + +.button-group { + display: flex; + justify-content: space-between; + padding: 32rpx; +} + +.cancel-button, .submit-button { + width: 48%; + height: 88rpx; + border-radius: 44rpx; + font-size: 32rpx; + display: flex; + align-items: center; + justify-content: center; +} + +.cancel-button { + background-color: #f5f5f5; + color: #666666; +} + +.submit-button { + background-color: #1890ff; + color: #ffffff; +} + +.submit-button.loading { + background-color: #8cc4ff; +} + +.loading-container { + display: flex; + align-items: center; + justify-content: center; +} + +.loading-icon { + width: 36rpx; + height: 36rpx; + margin-right: 10rpx; + border: 4rpx solid #ffffff; + border-radius: 50%; + border-top-color: transparent; + animation: spin 1s linear infinite; +} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.js b/miniprogram-example/pages/otp-list/index.js new file mode 100644 index 0000000..9ed6fa1 --- /dev/null +++ b/miniprogram-example/pages/otp-list/index.js @@ -0,0 +1,213 @@ +// otp-list/index.js +import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp'; +import { checkLoginStatus } from '../../services/auth'; + +Page({ + data: { + otpList: [], + loading: true, + refreshing: false + }, + + onLoad() { + this.checkLogin(); + }, + + onShow() { + // 每次页面显示时刷新OTP列表 + if (!this.data.loading) { + this.fetchOTPList(); + } + }, + + // 下拉刷新 + onPullDownRefresh() { + this.setData({ refreshing: true }); + this.fetchOTPList().finally(() => { + wx.stopPullDownRefresh(); + this.setData({ refreshing: false }); + }); + }, + + // 检查登录状态 + checkLogin() { + checkLoginStatus().then(isLoggedIn => { + if (isLoggedIn) { + this.fetchOTPList(); + } else { + wx.redirectTo({ + url: '/pages/login/login' + }); + } + }); + }, + + // 获取OTP列表 + fetchOTPList() { + this.setData({ loading: true }); + return getOTPList() + .then(res => { + if (res.data && Array.isArray(res.data)) { + this.setData({ + otpList: res.data, + loading: false + }); + + // 获取每个OTP的当前验证码 + this.refreshOTPCodes(); + } + }) + .catch(err => { + wx.showToast({ + title: '获取OTP列表失败', + icon: 'none' + }); + this.setData({ loading: false }); + }); + }, + + // 刷新所有OTP的验证码 + refreshOTPCodes() { + const { otpList } = this.data; + + // 为每个OTP获取当前验证码 + const promises = otpList.map(otp => { + return getOTPCode(otp.id) + .then(res => { + if (res.data && res.data.code) { + return { + id: otp.id, + code: res.data.code, + expiresIn: res.data.expires_in || 30 + }; + } + return null; + }) + .catch(() => null); + }); + + Promise.all(promises).then(results => { + const updatedList = [...this.data.otpList]; + + results.forEach(result => { + if (result) { + const index = updatedList.findIndex(otp => otp.id === result.id); + if (index !== -1) { + updatedList[index] = { + ...updatedList[index], + currentCode: result.code, + expiresIn: result.expiresIn + }; + } + } + }); + + this.setData({ otpList: updatedList }); + + // 设置定时器,每秒更新倒计时 + this.startCountdown(); + }); + }, + + // 开始倒计时 + startCountdown() { + // 清除之前的定时器 + if (this.countdownTimer) { + clearInterval(this.countdownTimer); + } + + // 创建新的定时器,每秒更新一次 + this.countdownTimer = setInterval(() => { + const { otpList } = this.data; + let needRefresh = false; + + const updatedList = otpList.map(otp => { + if (!otp.countdown) { + otp.countdown = otp.expiresIn || 30; + } + + otp.countdown -= 1; + + // 如果倒计时结束,标记需要刷新 + if (otp.countdown <= 0) { + needRefresh = true; + } + + return otp; + }); + + this.setData({ otpList: updatedList }); + + // 如果有OTP需要刷新,重新获取验证码 + if (needRefresh) { + this.refreshOTPCodes(); + } + }, 1000); + }, + + // 添加新的OTP + handleAddOTP() { + wx.navigateTo({ + url: '/pages/otp-add/index' + }); + }, + + // 编辑OTP + handleEditOTP(e) { + const { id } = e.currentTarget.dataset; + wx.navigateTo({ + url: `/pages/otp-edit/index?id=${id}` + }); + }, + + // 删除OTP + handleDeleteOTP(e) { + const { id, name } = e.currentTarget.dataset; + + wx.showModal({ + title: '确认删除', + content: `确定要删除 ${name} 吗?`, + confirmColor: '#ff4d4f', + success: (res) => { + if (res.confirm) { + deleteOTP(id) + .then(() => { + wx.showToast({ + title: '删除成功', + icon: 'success' + }); + this.fetchOTPList(); + }) + .catch(err => { + wx.showToast({ + title: '删除失败', + icon: 'none' + }); + }); + } + } + }); + }, + + // 复制验证码 + handleCopyCode(e) { + const { code } = e.currentTarget.dataset; + + wx.setClipboardData({ + data: code, + success: () => { + wx.showToast({ + title: '验证码已复制', + icon: 'success' + }); + } + }); + }, + + onUnload() { + // 页面卸载时清除定时器 + if (this.countdownTimer) { + clearInterval(this.countdownTimer); + } + } +}); \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.json b/miniprogram-example/pages/otp-list/index.json new file mode 100644 index 0000000..8835af0 --- /dev/null +++ b/miniprogram-example/pages/otp-list/index.json @@ -0,0 +1,3 @@ +{ + "usingComponents": {} +} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.wxml b/miniprogram-example/pages/otp-list/index.wxml new file mode 100644 index 0000000..27cb5ea --- /dev/null +++ b/miniprogram-example/pages/otp-list/index.wxml @@ -0,0 +1,59 @@ + + + + 我的OTP列表 + + + + + + + + + + 加载中... + + + + + + + + + {{item.name}} + {{item.issuer}} + + + + {{item.currentCode || '******'}} + 点击复制 + + + + + {{item.countdown || 0}}s + + + + + + + + + + + + + + + + + + 暂无OTP,点击右上角添加 + + + \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.wxss b/miniprogram-example/pages/otp-list/index.wxss new file mode 100644 index 0000000..40436d5 --- /dev/null +++ b/miniprogram-example/pages/otp-list/index.wxss @@ -0,0 +1,201 @@ +/* otp-list/index.wxss */ +.container { + min-height: 100vh; + background-color: #f8f8f8; + padding: 0 0 40rpx 0; +} + +.header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 40rpx 32rpx; + background-color: #ffffff; + position: sticky; + top: 0; + z-index: 100; + box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05); +} + +.title { + font-size: 36rpx; + font-weight: bold; + color: #333333; +} + +.add-button { + width: 64rpx; + height: 64rpx; + border-radius: 32rpx; + background-color: #1890ff; + display: flex; + align-items: center; + justify-content: center; + box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3); +} + +.add-icon { + color: #ffffff; + font-size: 40rpx; + line-height: 1; +} + +/* 加载状态 */ +.loading-container { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + padding: 120rpx 0; +} + +.loading-spinner { + width: 64rpx; + height: 64rpx; + border: 6rpx solid #f3f3f3; + border-top: 6rpx solid #1890ff; + border-radius: 50%; + animation: spin 1s linear infinite; +} + +.loading-text { + margin-top: 20rpx; + font-size: 28rpx; + color: #999999; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +/* OTP列表 */ +.otp-list { + padding: 20rpx 32rpx; +} + +.otp-item { + background-color: #ffffff; + border-radius: 16rpx; + padding: 32rpx; + margin-bottom: 20rpx; + display: flex; + justify-content: space-between; + box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); +} + +.otp-info { + flex: 1; + margin-right: 20rpx; +} + +.otp-name-row { + display: flex; + align-items: center; + margin-bottom: 16rpx; +} + +.otp-name { + font-size: 32rpx; + font-weight: bold; + color: #333333; + margin-right: 16rpx; +} + +.otp-issuer { + font-size: 24rpx; + color: #666666; + background-color: #f5f5f5; + padding: 4rpx 12rpx; + border-radius: 8rpx; +} + +.otp-code-row { + display: flex; + align-items: center; + margin-bottom: 20rpx; +} + +.otp-code { + font-size: 44rpx; + font-family: monospace; + font-weight: bold; + color: #1890ff; + letter-spacing: 4rpx; + margin-right: 16rpx; +} + +.copy-hint { + font-size: 24rpx; + color: #999999; +} + +.otp-countdown { + position: relative; + width: 100%; +} + +.countdown-text { + position: absolute; + right: 0; + top: -30rpx; + font-size: 24rpx; + color: #999999; +} + +/* OTP操作按钮 */ +.otp-actions { + display: flex; + flex-direction: column; + justify-content: space-between; +} + +.action-button { + width: 56rpx; + height: 56rpx; + border-radius: 28rpx; + display: flex; + align-items: center; + justify-content: center; + margin-bottom: 16rpx; +} + +.action-button.edit { + background-color: #f0f7ff; +} + +.action-button.delete { + background-color: #fff1f0; +} + +.action-icon { + font-size: 32rpx; +} + +.edit .action-icon { + color: #1890ff; +} + +.delete .action-icon { + color: #ff4d4f; +} + +/* 空状态 */ +.empty-state { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + padding: 120rpx 0; +} + +.empty-image { + width: 240rpx; + height: 240rpx; + margin-bottom: 32rpx; +} + +.empty-text { + font-size: 28rpx; + color: #999999; +} \ No newline at end of file diff --git a/miniprogram-example/project.config.json b/miniprogram-example/project.config.json new file mode 100644 index 0000000..3fb79ad --- /dev/null +++ b/miniprogram-example/project.config.json @@ -0,0 +1,47 @@ +{ + "description": "项目配置文件", + "packOptions": { + "ignore": [], + "include": [] + }, + "miniprogramRoot": "", + "compileType": "miniprogram", + "projectname": "OTPM", + "setting": { + "useCompilerPlugins": [ + "sass" + ], + "babelSetting": { + "ignore": [], + "disablePlugins": [], + "outputPath": "" + }, + "es6": true, + "enhance": true, + "minified": true, + "postcss": true, + "minifyWXSS": true, + "minifyWXML": true, + "uglifyFileName": true, + "packNpmManually": false, + "packNpmRelationList": [], + "ignoreUploadUnusedFiles": true, + "compileWorklet": false, + "uploadWithSourceMap": true, + "localPlugins": false, + "disableUseStrict": false, + "condition": false, + "swc": false, + "disableSWC": true + }, + "simulatorType": "wechat", + "simulatorPluginLibVersion": {}, + "condition": {}, + "srcMiniprogramRoot": "", + "appid": "wxb6599459668b6b55", + "libVersion": "2.30.2", + "editorSetting": { + "tabIndent": "insertSpaces", + "tabSize": 2 + } +} \ No newline at end of file diff --git a/miniprogram-example/project.private.config.json b/miniprogram-example/project.private.config.json new file mode 100644 index 0000000..6b2a738 --- /dev/null +++ b/miniprogram-example/project.private.config.json @@ -0,0 +1,23 @@ +{ + "libVersion": "3.8.5", + "projectname": "OTPM", + "condition": {}, + "setting": { + "urlCheck": true, + "coverView": false, + "lazyloadPlaceholderEnable": false, + "skylineRenderEnable": false, + "preloadBackgroundData": false, + "autoAudits": false, + "useApiHook": true, + "useApiHostProcess": true, + "showShadowRootInWxmlPanel": false, + "useStaticServer": false, + "useLanDebug": false, + "showES6CompileOption": false, + "compileHotReLoad": true, + "checkInvalidKey": true, + "ignoreDevUnusedFiles": true, + "bigPackageSizeSupport": false + } +} \ No newline at end of file diff --git a/miniprogram-example/services/auth.js b/miniprogram-example/services/auth.js new file mode 100644 index 0000000..4c95fa4 --- /dev/null +++ b/miniprogram-example/services/auth.js @@ -0,0 +1,84 @@ +// auth.js - 认证相关服务 + +import request from '../utils/request'; + +/** + * 微信登录 + * 1. 调用wx.login获取code + * 2. 发送code到服务端换取token和openid + * 3. 保存token和openid到本地存储 + */ +export const wxLogin = () => { + return new Promise((resolve, reject) => { + wx.login({ + success: (res) => { + if (res.code) { + // 发送code到服务端 + request({ + url: '/login', + method: 'POST', + data: { + code: res.code + } + }).then(response => { + // 保存token和openid + if (response.data && response.data.token && response.data.openid) { + wx.setStorageSync('token', response.data.token); + wx.setStorageSync('openid', response.data.openid); + resolve(response.data); + } else { + reject(new Error('登录失败,服务器返回数据格式错误')); + } + }).catch(err => { + reject(err); + }); + } else { + reject(new Error('登录失败,获取code失败: ' + res.errMsg)); + } + }, + fail: (err) => { + reject(new Error('微信登录失败: ' + err.errMsg)); + } + }); + }); +}; + +/** + * 检查登录状态 + * 1. 检查本地是否有token和openid + * 2. 如果有,验证token是否有效 + * 3. 如果无效,清除本地存储并返回false + */ +export const checkLoginStatus = () => { + return new Promise((resolve, reject) => { + const token = wx.getStorageSync('token'); + const openid = wx.getStorageSync('openid'); + + if (!token || !openid) { + resolve(false); + return; + } + + // 验证token有效性 + request({ + url: '/verify-token', + method: 'POST' + }).then(() => { + resolve(true); + }).catch(() => { + // token无效,清除本地存储 + wx.removeStorageSync('token'); + wx.removeStorageSync('openid'); + resolve(false); + }); + }); +}; + +/** + * 退出登录 + */ +export const logout = () => { + wx.removeStorageSync('token'); + wx.removeStorageSync('openid'); + return Promise.resolve(); +}; \ No newline at end of file diff --git a/miniprogram-example/services/otp.js b/miniprogram-example/services/otp.js new file mode 100644 index 0000000..9da9925 --- /dev/null +++ b/miniprogram-example/services/otp.js @@ -0,0 +1,119 @@ +// otp.js - OTP相关服务 + +import request from '../utils/request'; + +/** + * 创建新的OTP + * @param {Object} params - 创建OTP的参数 + * @param {string} params.name - OTP名称 + * @param {string} params.issuer - 发行方 + * @param {string} params.secret - 密钥 + * @param {string} params.algorithm - 算法,默认为SHA1 + * @param {number} params.digits - 位数,默认为6 + * @param {number} params.period - 周期,默认为30秒 + * @returns {Promise} - 返回创建结果 + */ +export const createOTP = (params) => { + if (!params || !params.secret) { + return Promise.reject(new Error('缺少必要的参数: secret')); + } + return request({ + url: '/otp', + method: 'POST', + data: { + name: params.name || '', + issuer: params.issuer || '', + secret: params.secret, + algorithm: params.algorithm || 'SHA1', + digits: params.digits || 6, + period: params.period || 30 + } + }).catch(err => { + console.error('创建OTP失败:', err); + throw new Error('创建OTP失败: ' + (err.message || '未知错误')); + }); +}; + +/** + * 获取用户所有OTP列表 + * @returns {Promise} - 返回OTP列表 + */ +export const getOTPList = () => { + return request({ + url: '/otp', + method: 'GET' + }).catch(err => { + console.error('获取OTP列表失败:', err); + throw new Error('获取OTP列表失败: ' + (err.message || '未知错误')); + }); +}; + +/** + * 获取指定OTP的当前验证码 + * @param {string} id - OTP的ID + * @returns {Promise} - 返回当前验证码 + */ +export const getOTPCode = (id) => { + if (!id) { + return Promise.reject(new Error('缺少必要的参数: id')); + } + return request({ + url: `/otp/${id}/code`, + method: 'GET' + }).catch(err => { + console.error('获取OTP代码失败:', err); + throw new Error('获取OTP代码失败: ' + (err.message || '未知错误')); + }); +}; + +/** + * 验证OTP + * @param {string} id - OTP的ID + * @param {string} code - 用户输入的验证码 + * @returns {Promise} - 返回验证结果 + */ +export const verifyOTP = (id, code) => { + if (!id || !code) { + return Promise.reject(new Error('缺少必要的参数: id或code')); + } + return request({ + url: `/otp/${id}/verify`, + method: 'POST', + data: { code } + }).catch(err => { + console.error('验证OTP失败:', err); + throw new Error('验证OTP失败: ' + (err.message || '未知错误')); + }); +}; + +/** + * 更新OTP信息 + * @param {string} id - OTP的ID + * @param {Object} params - 更新的参数 + * @returns {Promise} - 返回更新结果 + */ +export const updateOTP = (id, params) => { + if (!id || !params) { + return Promise.reject(new Error('缺少必要的参数: id或params')); + } + return request({ + url: `/otp/${id}`, + method: 'PUT', + data: params + }).catch(err => { + console.error('更新OTP失败:', err); + throw new Error('更新OTP失败: ' + (err.message || '未知错误')); + }); +}; + +/** + * 删除OTP + * @param {string} id - OTP的ID + * @returns {Promise} - 返回删除结果 + */ +export const deleteOTP = (id) => { + return request({ + url: `/otp/${id}`, + method: 'DELETE' + }); +}; \ No newline at end of file diff --git a/miniprogram-example/utils/request.js b/miniprogram-example/utils/request.js new file mode 100644 index 0000000..9ea7e2e --- /dev/null +++ b/miniprogram-example/utils/request.js @@ -0,0 +1,58 @@ +// request.js - 网络请求工具类 + +const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名 + +// 请求拦截器 +const request = (options) => { + return new Promise((resolve, reject) => { + const token = wx.getStorageSync('token'); + const header = { + 'Content-Type': 'application/json', + ...options.header + }; + + // 如果有token,添加到请求头 + if (token) { + header['Authorization'] = `Bearer ${token}`; + } + + wx.request({ + url: `${BASE_URL}${options.url}`, + method: options.method || 'GET', + data: options.data, + header: header, + success: (res) => { + // 处理业务错误 + if (res.data.code !== 0) { + // token过期,直接清除并跳转登录 + if (res.statusCode === 401) { + wx.removeStorageSync('token'); + wx.removeStorageSync('openid'); + reject(new Error('登录已过期,请重新登录')); + return; + } + reject(new Error(res.data.message || '请求失败')); + return; + } + resolve(res.data); + }, + fail: reject + }); + }); +}; + +// 刷新token +const refreshToken = () => { + return request({ + url: '/refresh-token', + method: 'POST' + }).then(res => { + if (res.data && res.data.token) { + wx.setStorageSync('token', res.data.token); + return res.data.token; + } + throw new Error('Failed to refresh token'); + }); +}; + +export default request; \ No newline at end of file diff --git a/models/otp.go b/models/otp.go new file mode 100644 index 0000000..8eaab88 --- /dev/null +++ b/models/otp.go @@ -0,0 +1,66 @@ +package models + +import ( + "context" + "time" +) + +// OTP represents a TOTP configuration +type OTP struct { + ID int64 `json:"id" db:"id"` + UserID string `json:"user_id" db:"user_id" validate:"required"` + OpenID string `json:"openid" db:"openid" validate:"required"` + Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"` + Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"` + Secret string `json:"secret" db:"secret" validate:"required,otpsecret"` + Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"` + Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"` + Period int `json:"period" db:"period" validate:"required,min=30,max=60"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +// OTPParams represents common OTP parameters used in creation and update +type OTPParams struct { + Name string `json:"name" validate:"required,min=1,max=100,no_xss"` + Issuer string `json:"issuer" validate:"omitempty,issuer"` + Secret string `json:"secret" validate:"required,otpsecret"` + Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"` + Digits int `json:"digits" validate:"omitempty,min=6,max=8"` + Period int `json:"period" validate:"omitempty,min=30,max=60"` +} + +// OTPRepository handles OTP data storage +type OTPRepository struct { + // Add your database connection or ORM here +} + +// Create creates a new OTP record +func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error { + // Implement database creation logic + return nil +} + +// FindByID finds an OTP by ID and user ID +func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) { + // Implement database lookup logic + return nil, nil +} + +// FindAllByUserID finds all OTPs for a user +func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) { + // Implement database query logic + return nil, nil +} + +// Update updates an existing OTP record +func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error { + // Implement database update logic + return nil +} + +// Delete deletes an OTP record +func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error { + // Implement database deletion logic + return nil +} diff --git a/models/user.go b/models/user.go new file mode 100644 index 0000000..f12399b --- /dev/null +++ b/models/user.go @@ -0,0 +1,114 @@ +package models + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/jmoiron/sqlx" +) + +// User represents a user in the system +type User struct { + ID string `db:"id" json:"id"` + OpenID string `db:"openid" json:"openid"` + SessionKey string `db:"session_key" json:"-"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +// UserRepository handles user data operations +type UserRepository struct { + db *sqlx.DB +} + +// NewUserRepository creates a new UserRepository +func NewUserRepository(db *sqlx.DB) *UserRepository { + return &UserRepository{db: db} +} + +// FindByID finds a user by ID +func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) { + var user User + query := `SELECT * FROM users WHERE id = ?` + err := r.db.GetContext(ctx, &user, query, id) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("user not found: %w", err) + } + return nil, fmt.Errorf("failed to find user: %w", err) + } + return &user, nil +} + +// FindByOpenID finds a user by OpenID +func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) { + var user User + query := `SELECT * FROM users WHERE openid = ?` + err := r.db.GetContext(ctx, &user, query, openID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil // User not found, but not an error + } + return nil, fmt.Errorf("failed to find user: %w", err) + } + return &user, nil +} + +// Create creates a new user +func (r *UserRepository) Create(ctx context.Context, user *User) error { + query := ` + INSERT INTO users (id, openid, session_key, created_at, updated_at) + VALUES (?, ?, ?, ?, ?) + ` + now := time.Now() + user.CreatedAt = now + user.UpdatedAt = now + + _, err := r.db.ExecContext( + ctx, + query, + user.ID, + user.OpenID, + user.SessionKey, + user.CreatedAt, + user.UpdatedAt, + ) + if err != nil { + return fmt.Errorf("failed to create user: %w", err) + } + return nil +} + +// Update updates an existing user +func (r *UserRepository) Update(ctx context.Context, user *User) error { + query := ` + UPDATE users + SET session_key = ?, updated_at = ? + WHERE id = ? + ` + user.UpdatedAt = time.Now() + + _, err := r.db.ExecContext( + ctx, + query, + user.SessionKey, + user.UpdatedAt, + user.ID, + ) + if err != nil { + return fmt.Errorf("failed to update user: %w", err) + } + return nil +} + +// Delete deletes a user +func (r *UserRepository) Delete(ctx context.Context, id string) error { + query := `DELETE FROM users WHERE id = ?` + _, err := r.db.ExecContext(ctx, query, id) + if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } + return nil +} diff --git a/security/security.go b/security/security.go new file mode 100644 index 0000000..7b69b81 --- /dev/null +++ b/security/security.go @@ -0,0 +1,332 @@ +package security + +import ( + "context" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + "net/http" + "strings" + "time" + + "golang.org/x/crypto/argon2" +) + +// SecurityService provides security functionality +type SecurityService struct { + config *Config +} + +// Config represents security configuration +type Config struct { + // CSRF protection + CSRFTokenLength int + CSRFTokenExpiry time.Duration + CSRFCookieName string + CSRFHeaderName string + CSRFCookieSecure bool + CSRFCookieHTTPOnly bool + CSRFCookieSameSite http.SameSite + + // Rate limiting + RateLimitRequests int + RateLimitWindow time.Duration + + // Password hashing + Argon2Time uint32 + Argon2Memory uint32 + Argon2Threads uint8 + Argon2KeyLen uint32 + Argon2SaltLen uint32 +} + +// DefaultConfig returns the default security configuration +func DefaultConfig() *Config { + return &Config{ + // CSRF protection + CSRFTokenLength: 32, + CSRFTokenExpiry: 24 * time.Hour, + CSRFCookieName: "csrf_token", + CSRFHeaderName: "X-CSRF-Token", + CSRFCookieSecure: true, + CSRFCookieHTTPOnly: true, + CSRFCookieSameSite: http.SameSiteStrictMode, + + // Rate limiting + RateLimitRequests: 100, + RateLimitWindow: time.Minute, + + // Password hashing + Argon2Time: 1, + Argon2Memory: 64 * 1024, + Argon2Threads: 4, + Argon2KeyLen: 32, + Argon2SaltLen: 16, + } +} + +// NewSecurityService creates a new SecurityService +func NewSecurityService(config *Config) *SecurityService { + if config == nil { + config = DefaultConfig() + } + return &SecurityService{ + config: config, + } +} + +// GenerateCSRFToken generates a CSRF token +func (s *SecurityService) GenerateCSRFToken() (string, error) { + // Generate random bytes + bytes := make([]byte, s.config.CSRFTokenLength) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + + // Encode as base64 + token := base64.StdEncoding.EncodeToString(bytes) + return token, nil +} + +// SetCSRFCookie sets a CSRF cookie +func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) { + http.SetCookie(w, &http.Cookie{ + Name: s.config.CSRFCookieName, + Value: token, + Path: "/", + Expires: time.Now().Add(s.config.CSRFTokenExpiry), + Secure: s.config.CSRFCookieSecure, + HttpOnly: s.config.CSRFCookieHTTPOnly, + SameSite: s.config.CSRFCookieSameSite, + }) +} + +// ValidateCSRFToken validates a CSRF token +func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool { + // Get token from cookie + cookie, err := r.Cookie(s.config.CSRFCookieName) + if err != nil { + return false + } + cookieToken := cookie.Value + + // Get token from header + headerToken := r.Header.Get(s.config.CSRFHeaderName) + + // Compare tokens + return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1 +} + +// CSRFMiddleware creates a middleware that validates CSRF tokens +func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip for GET, HEAD, OPTIONS, TRACE + if r.Method == http.MethodGet || + r.Method == http.MethodHead || + r.Method == http.MethodOptions || + r.Method == http.MethodTrace { + next.ServeHTTP(w, r) + return + } + + // Validate CSRF token + if !s.ValidateCSRFToken(r) { + http.Error(w, "Invalid CSRF token", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} + +// RateLimiter represents a rate limiter +type RateLimiter struct { + requests map[string][]time.Time + config *Config +} + +// NewRateLimiter creates a new RateLimiter +func NewRateLimiter(config *Config) *RateLimiter { + return &RateLimiter{ + requests: make(map[string][]time.Time), + config: config, + } +} + +// Allow checks if a request is allowed +func (r *RateLimiter) Allow(key string) bool { + now := time.Now() + windowStart := now.Add(-r.config.RateLimitWindow) + + // Get requests for key + requests := r.requests[key] + + // Filter out old requests + var newRequests []time.Time + for _, t := range requests { + if t.After(windowStart) { + newRequests = append(newRequests, t) + } + } + + // Check if rate limit is exceeded + if len(newRequests) >= r.config.RateLimitRequests { + return false + } + + // Add current request + newRequests = append(newRequests, now) + r.requests[key] = newRequests + + return true +} + +// RateLimitMiddleware creates a middleware that limits request rate +func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get client IP + ip := getClientIP(r) + + // Check if request is allowed + if !limiter.Allow(ip) { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// getClientIP gets the client IP address +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header + xForwardedFor := r.Header.Get("X-Forwarded-For") + if xForwardedFor != "" { + // X-Forwarded-For can contain multiple IPs, use the first one + ips := strings.Split(xForwardedFor, ",") + return strings.TrimSpace(ips[0]) + } + + // Check X-Real-IP header + xRealIP := r.Header.Get("X-Real-IP") + if xRealIP != "" { + return xRealIP + } + + // Use RemoteAddr + return r.RemoteAddr +} + +// HashPassword hashes a password using Argon2 +func (s *SecurityService) HashPassword(password string) (string, error) { + // Generate salt + salt := make([]byte, s.config.Argon2SaltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("failed to generate salt: %w", err) + } + + // Hash password + hash := argon2.IDKey( + []byte(password), + salt, + s.config.Argon2Time, + s.config.Argon2Memory, + s.config.Argon2Threads, + s.config.Argon2KeyLen, + ) + + // Encode as base64 + saltBase64 := base64.StdEncoding.EncodeToString(salt) + hashBase64 := base64.StdEncoding.EncodeToString(hash) + + // Format as $argon2id$v=19$m=65536,t=1,p=4$$ + return fmt.Sprintf( + "$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s", + s.config.Argon2Memory, + s.config.Argon2Time, + s.config.Argon2Threads, + saltBase64, + hashBase64, + ), nil +} + +// VerifyPassword verifies a password against a hash +func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) { + // Parse encoded hash + parts := strings.Split(encodedHash, "$") + if len(parts) != 6 { + return false, fmt.Errorf("invalid hash format") + } + + // Extract parameters + if parts[1] != "argon2id" { + return false, fmt.Errorf("unsupported hash algorithm") + } + + var memory uint32 + var time uint32 + var threads uint8 + _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) + if err != nil { + return false, fmt.Errorf("failed to parse hash parameters: %w", err) + } + + // Decode salt and hash + salt, err := base64.StdEncoding.DecodeString(parts[4]) + if err != nil { + return false, fmt.Errorf("failed to decode salt: %w", err) + } + + hash, err := base64.StdEncoding.DecodeString(parts[5]) + if err != nil { + return false, fmt.Errorf("failed to decode hash: %w", err) + } + + // Hash password with same parameters + newHash := argon2.IDKey( + []byte(password), + salt, + time, + memory, + threads, + uint32(len(hash)), + ) + + // Compare hashes + return subtle.ConstantTimeCompare(hash, newHash) == 1, nil +} + +// SecureHeadersMiddleware adds security headers to responses +func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add security headers + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-XSS-Protection", "1; mode=block") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + w.Header().Set("Content-Security-Policy", "default-src 'self'") + w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") + + next.ServeHTTP(w, r) + }) +} + +// contextKey is a type for context keys +type contextKey string + +// userIDKey is the key for user ID in context +const userIDKey = contextKey("user_id") + +// WithUserID adds a user ID to the context +func WithUserID(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, userIDKey, userID) +} + +// GetUserID gets the user ID from the context +func GetUserID(ctx context.Context) (string, bool) { + userID, ok := ctx.Value(userIDKey).(string) + return userID, ok +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..a10e990 --- /dev/null +++ b/server/server.go @@ -0,0 +1,190 @@ +package server + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "runtime" + "syscall" + "time" + + "otpm/config" + "otpm/middleware" + + "github.com/julienschmidt/httprouter" +) + +// Server represents the HTTP server +type Server struct { + server *http.Server + router *httprouter.Router + config *config.Config +} + +// New creates a new server +func New(cfg *config.Config) *Server { + router := httprouter.New() + + server := &http.Server{ + Addr: fmt.Sprintf(":%d", cfg.Server.Port), + Handler: router, + ReadTimeout: cfg.Server.ReadTimeout, + WriteTimeout: cfg.Server.WriteTimeout, + IdleTimeout: 120 * time.Second, + } + + return &Server{ + server: server, + router: router, + config: cfg, + } +} + +// Start starts the server +func (s *Server) Start() error { + // Apply global middleware in correct order with enhanced error handling + var handler http.Handler = s.router + + // Logger should be first to capture all request details + handler = middleware.Logger(handler) + + // CORS next to handle pre-flight requests + handler = middleware.CORS(handler) + + // Then Timeout to enforce request deadlines + handler = middleware.Timeout(s.config.Server.Timeout)(handler) + + // Recover should be outermost to catch any panics + handler = middleware.Recover(handler) + + s.server.Handler = handler + + // Log server configuration at startup + log.Printf("Server configuration:\n"+ + "Address: %s\n"+ + "Read Timeout: %v\n"+ + "Write Timeout: %v\n"+ + "Idle Timeout: %v\n"+ + "Request Timeout: %v", + s.server.Addr, + s.server.ReadTimeout, + s.server.WriteTimeout, + s.server.IdleTimeout, + s.config.Server.Timeout, + ) + + // Start server in a goroutine + serverErr := make(chan error, 1) + go func() { + log.Printf("Server starting on %s", s.server.Addr) + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + serverErr <- fmt.Errorf("server error: %w", err) + } + }() + + // Wait for interrupt signal or server error + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + select { + case err := <-serverErr: + return err + case <-quit: + return s.Shutdown() + } +} + +// Shutdown gracefully stops the server +func (s *Server) Shutdown() error { + log.Println("Shutting down server...") + + ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout) + defer cancel() + + if err := s.server.Shutdown(ctx); err != nil { + return fmt.Errorf("graceful shutdown failed: %w", err) + } + + log.Println("Server stopped gracefully") + return nil +} + +// Router returns the router +func (s *Server) Router() *httprouter.Router { + return s.router +} + +// RegisterRoutes registers all routes +func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) { + for pattern, handler := range routes { + s.router.Handle("GET", pattern, handler) + s.router.Handle("POST", pattern, handler) + s.router.Handle("PUT", pattern, handler) + s.router.Handle("DELETE", pattern, handler) + } +} + +// RegisterAuthRoutes registers routes that require authentication +func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) { + for pattern, handler := range routes { + // Apply authentication middleware + authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + // Convert httprouter.Handle to http.HandlerFunc for middleware + wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Store params in request context + ctx := context.WithValue(r.Context(), "params", ps) + handler(w, r.WithContext(ctx), ps) + }) + + // Apply auth middleware + middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r) + } + + s.router.Handle("GET", pattern, authHandler) + s.router.Handle("POST", pattern, authHandler) + s.router.Handle("PUT", pattern, authHandler) + s.router.Handle("DELETE", pattern, authHandler) + } +} + +// RegisterHealthCheck registers an enhanced health check endpoint +func (s *Server) RegisterHealthCheck() { + s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + response := map[string]interface{}{ + "status": "ok", + "timestamp": time.Now().Format(time.RFC3339), + "version": "1.0.0", // Hardcoded version instead of from config + "system": map[string]interface{}{ + "goroutines": runtime.NumGoroutine(), + "memory": getMemoryUsage(), + }, + } + + // Add database status if configured + if s.config.Database.DSN != "" { + dbStatus := "ok" + response["database"] = dbStatus + } + + middleware.SuccessResponse(w, response) + }) +} + +// getMemoryUsage returns current memory usage in MB +func getMemoryUsage() map[string]interface{} { + var m runtime.MemStats + runtime.ReadMemStats(&m) + return map[string]interface{}{ + "alloc_mb": bToMb(m.Alloc), + "total_alloc_mb": bToMb(m.TotalAlloc), + "sys_mb": bToMb(m.Sys), + "num_gc": m.NumGC, + } +} + +func bToMb(b uint64) float64 { + return float64(b) / 1024 / 1024 +} diff --git a/services/auth.go b/services/auth.go new file mode 100644 index 0000000..9ad2a85 --- /dev/null +++ b/services/auth.go @@ -0,0 +1,230 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "time" + + "github.com/golang-jwt/jwt" + "github.com/google/uuid" + + "otpm/config" + "otpm/models" +) + +// WeChatCode2SessionResponse represents the response from WeChat code2session API +type WeChatCode2SessionResponse struct { + OpenID string `json:"openid"` + SessionKey string `json:"session_key"` + UnionID string `json:"unionid"` + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +// AuthService handles authentication related operations +type AuthService struct { + config *config.Config + userRepo *models.UserRepository + httpClient *http.Client +} + +// NewAuthService creates a new AuthService +func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService { + return &AuthService{ + config: cfg, + userRepo: userRepo, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// LoginWithWeChatCode handles WeChat login +func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) { + start := time.Now() + + // Get OpenID and SessionKey from WeChat + sessionInfo, err := s.getWeChatSession(code) + if err != nil { + log.Printf("WeChat login failed for code %s: %v", maskCode(code), err) + return "", fmt.Errorf("failed to get WeChat session: %w", err) + } + log.Printf("WeChat session obtained for code %s (took %v)", + maskCode(code), time.Since(start)) + + // Find or create user + user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID) + if err != nil { + log.Printf("User lookup failed for OpenID %s: %v", + maskOpenID(sessionInfo.OpenID), err) + return "", fmt.Errorf("failed to find user: %w", err) + } + + if user == nil { + // Create new user + user = &models.User{ + ID: uuid.New().String(), + OpenID: sessionInfo.OpenID, + SessionKey: sessionInfo.SessionKey, + } + if err := s.userRepo.Create(ctx, user); err != nil { + log.Printf("User creation failed for OpenID %s: %v", + maskOpenID(sessionInfo.OpenID), err) + return "", fmt.Errorf("failed to create user: %w", err) + } + log.Printf("New user created with ID %s for OpenID %s", + user.ID, maskOpenID(sessionInfo.OpenID)) + } else { + // Update session key + user.SessionKey = sessionInfo.SessionKey + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("User update failed for ID %s: %v", user.ID, err) + return "", fmt.Errorf("failed to update user: %w", err) + } + log.Printf("User %s session key updated", user.ID) + } + + // Generate JWT token + token, err := s.generateToken(user) + if err != nil { + log.Printf("Token generation failed for user %s: %v", user.ID, err) + return "", fmt.Errorf("failed to generate token: %w", err) + } + + log.Printf("WeChat login completed for user %s (total time %v)", + user.ID, time.Since(start)) + return token, nil +} + +// maskCode masks sensitive parts of WeChat code for logging +func maskCode(code string) string { + if len(code) < 8 { + return "****" + } + return code[:2] + "****" + code[len(code)-2:] +} + +// maskOpenID masks sensitive parts of OpenID for logging +func maskOpenID(openID string) string { + if len(openID) < 8 { + return "****" + } + return openID[:2] + "****" + openID[len(openID)-2:] +} + +// getWeChatSession calls WeChat's code2session API +func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) { + url := fmt.Sprintf( + "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", + s.config.WeChat.AppID, + s.config.WeChat.AppSecret, + code, + ) + + resp, err := s.httpClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to call WeChat API: %w", err) + } + defer resp.Body.Close() + + var result WeChatCode2SessionResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode WeChat response: %w", err) + } + + if result.ErrCode != 0 { + return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg) + } + + return &result, nil +} + +// generateToken generates a JWT token for a user +func (s *AuthService) generateToken(user *models.User) (string, error) { + now := time.Now() + claims := jwt.MapClaims{ + "user_id": user.ID, + "exp": now.Add(s.config.JWT.ExpireDelta).Unix(), + "iat": now.Unix(), + "iss": s.config.JWT.Issuer, + "aud": s.config.JWT.Audience, + "token_id": uuid.New().String(), // Unique token ID for tracking + } + + // Use stronger signing method + token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) + signedToken, err := token.SignedString([]byte(s.config.JWT.Secret)) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + log.Printf("Token generated for user %s (expires at %v)", + user.ID, now.Add(s.config.JWT.ExpireDelta)) + return signedToken, nil +} + +// ValidateToken validates a JWT token with additional checks +func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Verify signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(s.config.JWT.Secret), nil + }) + + if err != nil { + if ve, ok := err.(*jwt.ValidationError); ok { + switch { + case ve.Errors&jwt.ValidationErrorMalformed != 0: + return nil, fmt.Errorf("malformed token") + case ve.Errors&jwt.ValidationErrorExpired != 0: + return nil, fmt.Errorf("token expired") + case ve.Errors&jwt.ValidationErrorNotValidYet != 0: + return nil, fmt.Errorf("token not active yet") + default: + return nil, fmt.Errorf("token validation error: %w", err) + } + } + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + // Additional claims validation + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + // Check issuer + if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer { + return nil, fmt.Errorf("invalid token issuer") + } + // Check audience + if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience { + return nil, fmt.Errorf("invalid token audience") + } + } else { + return nil, fmt.Errorf("invalid token claims") + } + + return token, nil +} + +// GetUserFromToken gets user information from a JWT token +func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) { + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("invalid token claims") + } + + userID, ok := claims["user_id"].(string) + if !ok { + return nil, fmt.Errorf("user_id not found in token") + } + + user, err := s.userRepo.FindByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to find user: %w", err) + } + + return user, nil +} diff --git a/services/otp.go b/services/otp.go new file mode 100644 index 0000000..c345bea --- /dev/null +++ b/services/otp.go @@ -0,0 +1,358 @@ +package services + +import ( + "context" + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/base32" + "encoding/binary" + "fmt" + "hash" + "log" + "strings" + "time" + + "otpm/models" + + "github.com/google/uuid" +) + +// OTPService handles OTP related operations +type OTPService struct { + otpRepo *models.OTPRepository +} + +// NewOTPService creates a new OTPService +func NewOTPService(otpRepo *models.OTPRepository) *OTPService { + return &OTPService{ + otpRepo: otpRepo, + } +} + +// CreateOTP creates a new OTP with performance monitoring and logging +func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) { + start := time.Now() + + // Validate input + if err := s.validateOTPInput(input); err != nil { + log.Printf("OTP validation failed for user %s: %v", userID, err) + return nil, err + } + + // Clean and standardize secret + secret := cleanSecret(input.Secret) + + // Set defaults for optional fields + algorithm := strings.ToUpper(input.Algorithm) + if algorithm == "" { + algorithm = "SHA1" + } + + digits := input.Digits + if digits == 0 { + digits = 6 + } + + period := input.Period + if period == 0 { + period = 30 + } + + // Create OTP + otp := &models.OTP{ + ID: uuid.New().String(), + UserID: userID, + Name: input.Name, + Issuer: input.Issuer, + Secret: secret, + Algorithm: algorithm, + Digits: digits, + Period: period, + } + + if err := s.otpRepo.Create(ctx, otp); err != nil { + log.Printf("Failed to create OTP for user %s: %v", userID, err) + return nil, fmt.Errorf("failed to create OTP: %w", err) + } + + // Log successful creation (without exposing secret) + log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)", + otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period) + + return otp, nil +} + +// GetOTP gets an OTP by ID +func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) { + otp, err := s.otpRepo.FindByID(ctx, id, userID) + if err != nil { + return nil, fmt.Errorf("failed to get OTP: %w", err) + } + return otp, nil +} + +// ListOTPs lists all OTPs for a user +func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) { + otps, err := s.otpRepo.FindAllByUserID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to list OTPs: %w", err) + } + return otps, nil +} + +// UpdateOTP updates an OTP +func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) { + // Get existing OTP + otp, err := s.otpRepo.FindByID(ctx, id, userID) + if err != nil { + return nil, fmt.Errorf("failed to get OTP: %w", err) + } + + // Update fields + if input.Name != "" { + otp.Name = input.Name + } + if input.Issuer != "" { + otp.Issuer = input.Issuer + } + if input.Algorithm != "" { + otp.Algorithm = strings.ToUpper(input.Algorithm) + } + if input.Digits > 0 { + otp.Digits = input.Digits + } + if input.Period > 0 { + otp.Period = input.Period + } + + // Validate updated OTP + if err := s.validateOTPInput(models.OTPParams{ + Name: otp.Name, + Issuer: otp.Issuer, + Secret: otp.Secret, + Algorithm: otp.Algorithm, + Digits: otp.Digits, + Period: otp.Period, + }); err != nil { + return nil, err + } + + if err := s.otpRepo.Update(ctx, otp); err != nil { + return nil, fmt.Errorf("failed to update OTP: %w", err) + } + + return otp, nil +} + +// DeleteOTP deletes an OTP +func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error { + if err := s.otpRepo.Delete(ctx, id, userID); err != nil { + return fmt.Errorf("failed to delete OTP: %w", err) + } + return nil +} + +// GenerateCode generates a TOTP code with enhanced logging and error handling +func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) { + start := time.Now() + + otp, err := s.otpRepo.FindByID(ctx, id, userID) + if err != nil { + log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err) + return "", 0, fmt.Errorf("failed to get OTP: %w", err) + } + + // Get current time step + now := time.Now().Unix() + timeStep := now / int64(otp.Period) + + // Generate code + code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits) + if err != nil { + log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err) + return "", 0, fmt.Errorf("failed to generate code: %w", err) + } + + // Calculate remaining seconds + remainingSeconds := otp.Period - int(now%int64(otp.Period)) + + // Log successful generation (without actual code) + log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)", + id, userID, time.Since(start), remainingSeconds) + + return code, remainingSeconds, nil +} + +// VerifyCode verifies a TOTP code with enhanced security and logging +func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) { + start := time.Now() + + // Basic input validation + if len(code) == 0 { + log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID) + return false, fmt.Errorf("code is required") + } + + otp, err := s.otpRepo.FindByID(ctx, id, userID) + if err != nil { + log.Printf("Failed to find OTP %s for user %s during verification: %v", + id, userID, err) + return false, fmt.Errorf("failed to get OTP: %w", err) + } + + // Get current and adjacent time steps + now := time.Now().Unix() + timeSteps := []int64{ + (now - int64(otp.Period)) / int64(otp.Period), + now / int64(otp.Period), + (now + int64(otp.Period)) / int64(otp.Period), + } + + // Check code against all time steps + for _, ts := range timeSteps { + expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits) + if err != nil { + log.Printf("Code generation failed for time step %d: %v", ts, err) + continue + } + if expectedCode == code { + // Log successful verification + log.Printf("Code verified successfully for OTP %s (user %s) in %v", + id, userID, time.Since(start)) + return true, nil + } + } + + // Log failed verification attempt + log.Printf("Invalid code provided for OTP %s (user %s) in %v", + id, userID, time.Since(start)) + + return false, nil +} + +// validateOTPInput validates OTP input with detailed error messages +func (s *OTPService) validateOTPInput(input models.OTPParams) error { + if input.Name == "" { + return fmt.Errorf("name is required") + } + + if len(input.Name) > 100 { + return fmt.Errorf("name is too long (maximum 100 characters)") + } + + if input.Secret == "" { + return fmt.Errorf("secret is required") + } + + if !isValidBase32(input.Secret) { + return fmt.Errorf("invalid secret format: must be a valid base32 string") + } + + // Secret length check (after base32 decoding) + secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "=")) + if len(secretBytes) < 10 { + return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)") + } + + if input.Algorithm != "" { + if !isValidAlgorithm(input.Algorithm) { + return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm) + } + } + + if input.Digits != 0 { + if input.Digits < 6 || input.Digits > 8 { + return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits) + } + } + + if input.Period != 0 { + if input.Period < 30 || input.Period > 60 { + return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period) + } + } + + return nil +} + +// Helper functions + +func cleanSecret(secret string) string { + // Remove spaces and convert to upper case + secret = strings.TrimSpace(strings.ToUpper(secret)) + // Remove any padding characters + return strings.TrimRight(secret, "=") +} + +func isValidBase32(s string) bool { + // Try to decode the secret + _, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "=")) + return err == nil +} + +func isValidAlgorithm(algorithm string) bool { + switch strings.ToUpper(algorithm) { + case "SHA1", "SHA256", "SHA512": + return true + default: + return false + } +} + +func getHasher(algorithm string, key []byte) (hash.Hash, error) { + switch strings.ToUpper(algorithm) { + case "SHA1": + return hmac.New(sha1.New, key), nil + case "SHA256": + return hmac.New(sha256.New, key), nil + case "SHA512": + return hmac.New(sha512.New, key), nil + default: + return nil, fmt.Errorf("unsupported algorithm: %s", algorithm) + } +} + +func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) { + // Decode secret + secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "=")) + if err != nil { + return "", fmt.Errorf("invalid secret: %w", err) + } + + // Get initialized HMAC hasher with secret + hasher, err := getHasher(algorithm, secretBytes) + if err != nil { + return "", err + } + + // Convert time step to bytes + timeBytes := make([]byte, 8) + binary.BigEndian.PutUint64(timeBytes, uint64(timeStep)) + + // Calculate HMAC + hasher.Write(timeBytes) + hash := hasher.Sum(nil) + + // Get offset + offset := hash[len(hash)-1] & 0xf + + // Generate 4-byte code + code := binary.BigEndian.Uint32(hash[offset : offset+4]) + code = code & 0x7fffffff + + // Get the specified number of digits + code = code % uint32(pow10(digits)) + + // Format code with leading zeros + return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil +} + +func pow10(n int) uint32 { + result := uint32(1) + for i := 0; i < n; i++ { + result *= 10 + } + return result +} diff --git a/utils/utils.go b/utils/utils.go index 113ac27..efd6713 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,20 +1,65 @@ package utils import ( + "context" "crypto/aes" "crypto/cipher" "encoding/base64" + "fmt" "net/http" + "strings" + "github.com/golang-jwt/jwt" "github.com/julienschmidt/httprouter" + "github.com/spf13/viper" ) +// AdaptHandler函数将一个http.Handler转换为httprouter.Handle func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle { + // 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数 return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + // 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递 h(w, r) } } +func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized) + return + } + + tokenStr := strings.TrimPrefix(authHeader, "Bearer ") + secret := viper.GetString("auth.secret") + + token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method") + } + return []byte(secret), nil + }) + + if err != nil || !token.Valid { + http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized) + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized) + return + } + + type contextKey string + // 将 openid 存入上下文 + ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"]) + next.ServeHTTP(w, r.WithContext(ctx)) + } +} + +// AesDecrypt 函数用于AES解密 func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) { //Base64解码 keyBytes, err := base64.StdEncoding.DecodeString(sessionKey) @@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) { return origData, nil } +// PKCS7UnPadding 函数用于去除PKCS7填充的密文 func PKCS7UnPadding(plantText []byte) []byte { + // 获取密文的长度 length := len(plantText) + // 如果密文长度大于0 if length > 0 { + // 获取最后一个字节的值,即填充的位数 unPadding := int(plantText[length-1]) + // 返回去除填充后的密文 return plantText[:(length - unPadding)] } + // 如果密文长度为0,则返回原密文 return plantText } diff --git a/validator/validator.go b/validator/validator.go new file mode 100644 index 0000000..13cff58 --- /dev/null +++ b/validator/validator.go @@ -0,0 +1,316 @@ +package validator + +import ( + "encoding/json" + "fmt" + "net/http" + "reflect" + "regexp" + "strings" + + "github.com/go-playground/validator/v10" +) + +var ( + validate *validator.Validate + + // 自定义验证规则 + customValidations = map[string]validator.Func{ + "otpsecret": validateOTPSecret, + "password": validatePassword, + "issuer": validateIssuer, + "otpauth_uri": validateOTPAuthURI, + "no_xss": validateNoXSS, + } + + // 常见的弱密码列表(实际使用时应该使用更完整的列表) + commonPasswords = map[string]bool{ + "password123": true, + "12345678": true, + "qwerty123": true, + "admin123": true, + "letmein": true, + "welcome": true, + "password": true, + "admin": true, + } + + // 预编译的XSS检测正则表达式 + xssPatterns = []*regexp.Regexp{ + regexp.MustCompile(`(?i)]*>.*?`), + regexp.MustCompile(`(?i)javascript:`), + regexp.MustCompile(`(?i)data:text/html`), + regexp.MustCompile(`(?i)on\w+\s*=`), + regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`), + regexp.MustCompile(`(?i)<\s*iframe`), + regexp.MustCompile(`(?i)<\s*object`), + regexp.MustCompile(`(?i)<\s*embed`), + regexp.MustCompile(`(?i)<\s*style`), + regexp.MustCompile(`(?i)<\s*form`), + regexp.MustCompile(`(?i)<\s*applet`), + regexp.MustCompile(`(?i)<\s*meta`), + regexp.MustCompile(`(?i)expression\s*\(`), + regexp.MustCompile(`(?i)url\s*\(`), + } + + // 预编译的正则表达式 + base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`) + issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`) + otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`) + upperRegex = regexp.MustCompile(`[A-Z]`) + lowerRegex = regexp.MustCompile(`[a-z]`) + numberRegex = regexp.MustCompile(`[0-9]`) + specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`) +) + +func init() { + validate = validator.New() + + // 注册自定义验证规则 + for tag, fn := range customValidations { + if err := validate.RegisterValidation(tag, fn); err != nil { + panic(fmt.Sprintf("failed to register validation %s: %v", tag, err)) + } + } + + // 使用json tag作为字段名 + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] + if name == "-" { + return "" + } + return name + }) +} + +// ValidateRequest validates a request body against a struct +func ValidateRequest(r *http.Request, v interface{}) error { + if err := json.NewDecoder(r.Body).Decode(v); err != nil { + return fmt.Errorf("invalid request body: %w", err) + } + + if err := validate.Struct(v); err != nil { + if validationErrors, ok := err.(validator.ValidationErrors); ok { + return NewValidationError(validationErrors) + } + return fmt.Errorf("validation error: %w", err) + } + + return nil +} + +// ValidationError represents a validation error +type ValidationError struct { + Fields map[string]string `json:"fields"` +} + +// Error implements the error interface +func (e *ValidationError) Error() string { + var errors []string + for field, msg := range e.Fields { + errors = append(errors, fmt.Sprintf("%s: %s", field, msg)) + } + return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; ")) +} + +// NewValidationError creates a new ValidationError from validator.ValidationErrors +func NewValidationError(errors validator.ValidationErrors) *ValidationError { + fields := make(map[string]string) + for _, err := range errors { + fields[err.Field()] = getErrorMessage(err) + } + return &ValidationError{Fields: fields} +} + +// getErrorMessage returns a human-readable error message for a validation error +func getErrorMessage(err validator.FieldError) string { + switch err.Tag() { + case "required": + return "此字段为必填项" + case "email": + return "请输入有效的电子邮件地址" + case "min": + if err.Type().Kind() == reflect.String { + return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param()) + } + return fmt.Sprintf("必须大于或等于 %s", err.Param()) + case "max": + if err.Type().Kind() == reflect.String { + return fmt.Sprintf("长度不能超过 %s 个字符", err.Param()) + } + return fmt.Sprintf("必须小于或等于 %s", err.Param()) + case "len": + return fmt.Sprintf("长度必须为 %s 个字符", err.Param()) + case "oneof": + return fmt.Sprintf("必须是以下值之一: %s", err.Param()) + case "otpsecret": + return "OTP密钥格式无效,必须是有效的Base32编码" + case "password": + return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符" + case "issuer": + return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号" + case "otpauth_uri": + return "OTP认证URI格式无效" + case "no_xss": + return "输入包含潜在的不安全内容" + case "numeric": + return "必须是数字" + default: + return fmt.Sprintf("验证失败: %s", err.Tag()) + } +} + +// Custom validation functions + +// validateOTPSecret validates an OTP secret +func validateOTPSecret(fl validator.FieldLevel) bool { + secret := fl.Field().String() + + if secret == "" { + return false + } + + // OTP secret should be base32 encoded + if !base32Regex.MatchString(secret) { + return false + } + + // Check length (typical OTP secrets are 16-64 characters) + validLength := len(secret) >= 16 && len(secret) <= 128 + + return validLength +} + +// validatePassword validates a password +func validatePassword(fl validator.FieldLevel) bool { + password := fl.Field().String() + + // At least 10 characters long + if len(password) < 10 { + return false + } + + // Check if it's a common password + if commonPasswords[strings.ToLower(password)] { + return false + } + + // Check character types + hasUpper := upperRegex.MatchString(password) + hasLower := lowerRegex.MatchString(password) + hasNumber := numberRegex.MatchString(password) + hasSpecial := specialRegex.MatchString(password) + + // Ensure password has enough complexity + complexity := 0 + if hasUpper { + complexity++ + } + if hasLower { + complexity++ + } + if hasNumber { + complexity++ + } + if hasSpecial { + complexity++ + } + + return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial) +} + +// validateIssuer validates an issuer name +func validateIssuer(fl validator.FieldLevel) bool { + issuer := fl.Field().String() + + if issuer == "" { + return false + } + + // Issuer should not contain special characters that could cause problems in URLs + if !issuerRegex.MatchString(issuer) { + return false + } + + // Check length + validLength := len(issuer) >= 1 && len(issuer) <= 100 + + return validLength +} + +// validateOTPAuthURI validates an otpauth URI +func validateOTPAuthURI(fl validator.FieldLevel) bool { + uri := fl.Field().String() + + if uri == "" { + return false + } + + // Basic format check for otpauth URI + // Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD + return otpauthRegex.MatchString(uri) +} + +// validateNoXSS checks if a string contains potential XSS payloads +func validateNoXSS(fl validator.FieldLevel) bool { + value := fl.Field().String() + + // 检查基本的HTML编码 + if strings.Contains(value, "&#") || + strings.Contains(value, "<") || + strings.Contains(value, ">") { + return false + } + + // 检查十六进制编码 + if strings.Contains(strings.ToLower(value), "\\x3c") || // < + strings.Contains(strings.ToLower(value), "\\x3e") { // > + return false + } + + // 检查Unicode编码 + if strings.Contains(strings.ToLower(value), "\\u003c") || // < + strings.Contains(strings.ToLower(value), "\\u003e") { // > + return false + } + + // 使用预编译的正则表达式检查XSS模式 + for _, pattern := range xssPatterns { + if pattern.MatchString(value) { + return false + } + } + + return true +} + +// Request validation structs + +// LoginRequest represents a login request +type LoginRequest struct { + Code string `json:"code" validate:"required,len=6|len=8,numeric"` +} + +// CreateOTPRequest represents a request to create an OTP +type CreateOTPRequest struct { + Name string `json:"name" validate:"required,min=1,max=100,no_xss"` + Issuer string `json:"issuer" validate:"required,issuer,no_xss"` + Secret string `json:"secret" validate:"required,otpsecret"` + Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"` + Digits int `json:"digits" validate:"required,oneof=6 8"` + Period int `json:"period" validate:"required,oneof=30 60"` +} + +// UpdateOTPRequest represents a request to update an OTP +type UpdateOTPRequest struct { + Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"` + Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"` + Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"` + Digits int `json:"digits" validate:"omitempty,oneof=6 8"` + Period int `json:"period" validate:"omitempty,oneof=30 60"` +} + +// VerifyOTPRequest represents a request to verify an OTP code +type VerifyOTPRequest struct { + Code string `json:"code" validate:"required,len=6|len=8,numeric"` +}