Compare commits

...
Sign in to create a new pull request.

8 commits

Author SHA1 Message Date
“xHuPo”
5d370e1077 error 2025-05-27 17:44:24 +08:00
“xHuPo”
44500afd3f beta 2025-05-23 19:09:06 +08:00
“xHuPo”
bcd986e3f7 beta 2025-05-23 18:57:11 +08:00
“xHuPo”
a45ddf13d5 alpha 2025-05-23 14:15:48 +08:00
“xHuPo”
a6461a9a0e alpha 2025-05-23 13:45:53 +08:00
“xHuPo”
2d3698716e alpha 2025-05-23 13:45:37 +08:00
“xHuPo”
25c5f530b8 alpha 2025-05-22 16:07:55 +08:00
“xHuPo”
079542e431 alpha 2025-05-22 12:06:34 +08:00
50 changed files with 6350 additions and 364 deletions

149
api/response.go Normal file
View file

@ -0,0 +1,149 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"net/http"
)
// Common error codes
const (
CodeSuccess = 0
CodeInvalidParams = 400
CodeUnauthorized = 401
CodeForbidden = 403
CodeNotFound = 404
CodeInternalError = 500
CodeServiceUnavail = 503
)
// Error represents an API error
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
}
// Error implements the error interface
func (e *Error) Error() string {
return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message)
}
// NewError creates a new API error
func NewError(code int, message string) *Error {
return &Error{
Code: code,
Message: message,
}
}
// Response represents a standard API response
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// ResponseWriter wraps common response writing functions
type ResponseWriter struct {
http.ResponseWriter
}
// NewResponseWriter creates a new ResponseWriter
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
return &ResponseWriter{w}
}
// WriteJSON writes a JSON response
func (w *ResponseWriter) WriteJSON(code int, data interface{}) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
return json.NewEncoder(w).Encode(data)
}
// WriteSuccess writes a success response
func (w *ResponseWriter) WriteSuccess(data interface{}) error {
return w.WriteJSON(http.StatusOK, Response{
Code: CodeSuccess,
Message: "success",
Data: data,
})
}
// WriteError writes an error response
func (w *ResponseWriter) WriteError(err error) error {
var apiErr *Error
if errors.As(err, &apiErr) {
return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{
Code: apiErr.Code,
Message: apiErr.Message,
})
}
// Handle unknown errors
return w.WriteJSON(http.StatusInternalServerError, Response{
Code: CodeInternalError,
Message: "Internal Server Error",
})
}
// WriteErrorWithCode writes an error response with a specific code
func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error {
return w.WriteJSON(getHTTPStatus(code), Response{
Code: code,
Message: message,
})
}
// getHTTPStatus maps API error codes to HTTP status codes
func getHTTPStatus(code int) int {
switch code {
case CodeSuccess:
return http.StatusOK
case CodeInvalidParams:
return http.StatusBadRequest
case CodeUnauthorized:
return http.StatusUnauthorized
case CodeForbidden:
return http.StatusForbidden
case CodeNotFound:
return http.StatusNotFound
case CodeServiceUnavail:
return http.StatusServiceUnavailable
default:
return http.StatusInternalServerError
}
}
// Common errors
var (
ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters")
ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized")
ErrForbidden = NewError(CodeForbidden, "Forbidden")
ErrNotFound = NewError(CodeNotFound, "Resource not found")
ErrInternalError = NewError(CodeInternalError, "Internal server error")
ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable")
)
// ValidationError creates an error for invalid parameters
func ValidationError(message string) *Error {
return NewError(CodeInvalidParams, message)
}
// NotFoundError creates an error for not found resources
func NotFoundError(resource string) *Error {
return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource))
}
// ForbiddenError creates an error for forbidden actions
func ForbiddenError(message string) *Error {
return NewError(CodeForbidden, message)
}
// InternalError creates an error for internal server errors
func InternalError(err error) *Error {
if err == nil {
return ErrInternalError
}
return NewError(CodeInternalError, err.Error())
}

152
api/validator.go Normal file
View file

@ -0,0 +1,152 @@
package api
import (
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
// Validate is a global validator instance
var Validate = validator.New()
// RegisterCustomValidations registers custom validation functions
func RegisterCustomValidations() {
// Register custom validation for issuer
Validate.RegisterValidation("issuer", validateIssuer)
// Register custom validation for XSS prevention
Validate.RegisterValidation("no_xss", validateNoXSS)
// Register custom validation for OTP secret
Validate.RegisterValidation("otpsecret", validateOTPSecret)
}
// validateOTPSecret validates that the OTP secret is in valid base32 format
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
// Check if the secret is not empty
if secret == "" {
return false
}
// Check if the secret is in base32 format (A-Z, 2-7)
base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
if !base32Regex.MatchString(secret) {
return false
}
// Check if the length is valid (must be at least 16 characters)
if len(secret) < 16 || len(secret) > 128 {
return false
}
return true
}
// validateIssuer validates that the issuer field contains only allowed characters
func validateIssuer(fl validator.FieldLevel) bool {
issuer := fl.Field().String()
// Empty issuer is valid (since it's optional)
if issuer == "" {
return true
}
// Allow alphanumeric characters, spaces, and common punctuation
issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api
import (
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
// Validate is a global validator instance
var Validate = validator.New()
// RegisterCustomValidations registers custom validation functions
func RegisterCustomValidations() {
// Register custom validation for issuer
Validate.RegisterValidation("issuer", validateIssuer)
// Register custom validation for XSS prevention
Validate.RegisterValidation("no_xss", validateNoXSS)
// Register custom validation for OTP secret
Validate.RegisterValidation("otpsecret", validateOTPSecret)
}
// validateOTPSecret validates that the OTP secret is in valid base32 format
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
// Check if the secret is not empty
if secret == "" {
return false
}
// Check if the secret is in base32 format (A-Z, 2-7)
base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
if !base32Regex.MatchString(secret) {
return false
}
// Check if the length is valid (must be at least 16 characters)
if len(secret) < 16 || len(secret) > 128 {
return false
}
return true
}
)
if !issuerRegex.MatchString(issuer) {
return false
}
// Check length
if len(issuer) > 100 {
return false
}
return true
}
// validateNoXSS validates that the field doesn't contain potential XSS payloads
func validateNoXSS(fl validator.FieldLevel) bool {
value := fl.Field().String()
// Check for HTML encoding
if strings.Contains(value, "&#") ||
strings.Contains(value, "&lt;") ||
strings.Contains(value, "&gt;") {
return false
}
// Check for common XSS patterns
suspiciousPatterns := []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)data:text/html`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
regexp.MustCompile(`(?i)<\s*iframe`),
regexp.MustCompile(`(?i)<\s*object`),
regexp.MustCompile(`(?i)<\s*embed`),
regexp.MustCompile(`(?i)<\s*style`),
regexp.MustCompile(`(?i)<\s*form`),
regexp.MustCompile(`(?i)<\s*applet`),
regexp.MustCompile(`(?i)<\s*meta`),
}
for _, pattern := range suspiciousPatterns {
if pattern.MatchString(value) {
return false
}
}
return true
}

206
cache/cache.go vendored Normal file
View file

@ -0,0 +1,206 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
)
// Item represents a cache item
type Item struct {
Value []byte
Expiration int64
}
// Expired returns true if the item has expired
func (item Item) Expired() bool {
if item.Expiration == 0 {
return false
}
return time.Now().UnixNano() > item.Expiration
}
// Cache represents an in-memory cache
type Cache struct {
items map[string]Item
mu sync.RWMutex
defaultExpiration time.Duration
cleanupInterval time.Duration
stopCleanup chan bool
}
// New creates a new cache with the given default expiration and cleanup interval
func New(defaultExpiration, cleanupInterval time.Duration) *Cache {
cache := &Cache{
items: make(map[string]Item),
defaultExpiration: defaultExpiration,
cleanupInterval: cleanupInterval,
stopCleanup: make(chan bool),
}
// Start cleanup goroutine if cleanup interval > 0
if cleanupInterval > 0 {
go cache.startCleanup()
}
return cache
}
// startCleanup starts the cleanup process
func (c *Cache) startCleanup() {
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.DeleteExpired()
case <-c.stopCleanup:
return
}
}
}
// Set adds an item to the cache with the given key and expiration
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error {
// Convert value to bytes
var valueBytes []byte
var err error
switch v := value.(type) {
case []byte:
valueBytes = v
case string:
valueBytes = []byte(v)
default:
valueBytes, err = json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value: %w", err)
}
}
// Calculate expiration
var exp int64
if expiration == 0 {
if c.defaultExpiration > 0 {
exp = time.Now().Add(c.defaultExpiration).UnixNano()
}
} else if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
}
c.mu.Lock()
c.items[key] = Item{
Value: valueBytes,
Expiration: exp,
}
c.mu.Unlock()
return nil
}
// Get gets an item from the cache
func (c *Cache) Get(key string, value interface{}) (bool, error) {
c.mu.RLock()
item, found := c.items[key]
c.mu.RUnlock()
if !found {
return false, nil
}
// Check if item has expired
if item.Expired() {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return false, nil
}
// Unmarshal value
switch v := value.(type) {
case *[]byte:
*v = item.Value
case *string:
*v = string(item.Value)
default:
if err := json.Unmarshal(item.Value, value); err != nil {
return true, fmt.Errorf("failed to unmarshal value: %w", err)
}
}
return true, nil
}
// Delete deletes an item from the cache
func (c *Cache) Delete(key string) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
}
// DeleteExpired deletes all expired items from the cache
func (c *Cache) DeleteExpired() {
now := time.Now().UnixNano()
c.mu.Lock()
for k, v := range c.items {
if v.Expiration > 0 && now > v.Expiration {
delete(c.items, k)
}
}
c.mu.Unlock()
}
// Clear deletes all items from the cache
func (c *Cache) Clear() {
c.mu.Lock()
c.items = make(map[string]Item)
c.mu.Unlock()
}
// Close stops the cleanup goroutine
func (c *Cache) Close() {
if c.cleanupInterval > 0 {
c.stopCleanup <- true
}
}
// CacheService provides caching functionality
type CacheService struct {
cache *Cache
}
// NewCacheService creates a new CacheService
func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService {
return &CacheService{
cache: New(defaultExpiration, cleanupInterval),
}
}
// Set adds an item to the cache
func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
return s.cache.Set(key, value, expiration)
}
// Get gets an item from the cache
func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) {
return s.cache.Get(key, value)
}
// Delete deletes an item from the cache
func (s *CacheService) Delete(ctx context.Context, key string) {
s.cache.Delete(key)
}
// Clear deletes all items from the cache
func (s *CacheService) Clear(ctx context.Context) {
s.cache.Clear()
}
// Close closes the cache
func (s *CacheService) Close() {
s.cache.Close()
}

View file

@ -1,83 +1,144 @@
package cmd
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/spf13/viper"
"otpm/api"
"otpm/config"
"otpm/database"
"otpm/handlers"
"otpm/utils"
"github.com/jmoiron/sqlx"
"github.com/julienschmidt/httprouter"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"otpm/models"
"otpm/server"
"otpm/services"
)
var rootCmd = &cobra.Command{
Use: "otpm",
Short: "otp backend for microapp on wechat",
Run: func(cmd *cobra.Command, args []string) {
startApp()
},
}
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}
func init() {
cobra.OnInitialize(initConfig)
// Set config file with multi-environment support
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
rootCmd.PersistentFlags().StringP("config", "c", "", "config file (default is $HOME/config.yaml)")
rootCmd.PersistentFlags().StringP("driver", "d", "sqlite3", "database driver (sqlite3, postgres, mysql)")
rootCmd.PersistentFlags().StringP("dsn", "s", "", "database connection string")
rootCmd.PersistentFlags().StringP("port", "p", "8080", "port to listen on")
viper.BindPFlag("database.driver", rootCmd.PersistentFlags().Lookup("driver"))
viper.BindPFlag("database.dsn", rootCmd.PersistentFlags().Lookup("dsn"))
viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port"))
}
func initConfig() {
if cfgFile := viper.GetString("config"); cfgFile != "" {
viper.SetConfigFile(cfgFile)
} else {
viper.AddConfigPath(".")
viper.SetConfigName("config")
viper.SetConfigType("yaml")
// Set environment specific config (e.g. config.production.yaml)
env := os.Getenv("OTPM_ENV")
if env != "" {
viper.SetConfigName(fmt.Sprintf("config.%s", env))
}
if err := viper.ReadInConfig(); err != nil {
log.Fatalf("Error reading config file: %v", err)
}
// Set default values
viper.SetDefault("server.port", "8080")
viper.SetDefault("server.timeout.read", "15s")
viper.SetDefault("server.timeout.write", "15s")
viper.SetDefault("server.timeout.idle", "60s")
viper.SetDefault("database.max_open_conns", 25)
viper.SetDefault("database.max_idle_conns", 5)
viper.SetDefault("database.conn_max_lifetime", "5m")
// Set environment variable prefix
viper.SetEnvPrefix("OTPM")
viper.AutomaticEnv()
// Bind environment variables
viper.BindEnv("database.url", "OTPM_DB_URL")
viper.BindEnv("database.password", "OTPM_DB_PASSWORD")
}
func initApp(db *sqlx.DB) {
if err := database.MigrateDB(db); err != nil {
log.Fatalf("Error migrating the database: %v", err)
}
}
func startApp() {
port := viper.GetInt("port")
db, err := database.InitDB()
// Execute is the entry point for the application
func Execute() error {
// Load configuration
cfg, err := config.LoadConfig()
if err != nil {
log.Fatalf("Error connecting to the database: %v", err)
return fmt.Errorf("failed to load config: %w", err)
}
// Create context with cancellation
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Printf("Received signal: %v", sig)
cancel()
}()
// Initialize database
db, err := database.New(&cfg.Database)
if err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
defer db.Close()
initApp(db)
handler := &handlers.Handler{DB: db}
router := httprouter.New()
router.POST("/login", utils.AdaptHandler(handler.Login))
router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp))
router.GET("/get", utils.AdaptHandler(handler.GetOtp))
// Run database migrations
if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil {
return fmt.Errorf("failed to run migrations: %w", err)
}
log.Println("Starting server on :8080")
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), router))
// Initialize repositories
userRepo := models.NewUserRepository(db.DB)
otpRepo := models.NewOTPRepository(db.DB)
// Initialize services
authService := services.NewAuthService(cfg, userRepo)
otpService := services.NewOTPService(otpRepo)
// Register custom validations
api.RegisterCustomValidations()
// Initialize handlers
authHandler := handlers.NewAuthHandler(authService)
otpHandler := handlers.NewOTPHandler(otpService)
// Create and configure server
srv := server.New(cfg)
// Register health check endpoint
srv.RegisterHealthCheck()
// Register public routes with type conversion
authRoutes := make(map[string]http.Handler)
for path, handler := range authHandler.Routes() {
authRoutes[path] = http.HandlerFunc(handler)
}
srv.RegisterRoutes(authRoutes)
// Register authenticated routes with type conversion
otpRoutes := make(map[string]http.Handler)
for path, handler := range otpHandler.Routes() {
otpRoutes[path] = http.HandlerFunc(handler)
}
srv.RegisterAuthRoutes(otpRoutes)
// Start server in goroutine
serverErr := make(chan error, 1)
go func() {
log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port)
if err := srv.Start(); err != nil {
serverErr <- fmt.Errorf("server error: %w", err)
}
}()
// Wait for shutdown signal or server error
select {
case err := <-serverErr:
return err
case <-ctx.Done():
// Graceful shutdown with timeout
log.Println("Shutting down server...")
if err := srv.Shutdown(); err != nil {
return fmt.Errorf("server shutdown error: %w", err)
}
log.Println("Server stopped gracefully")
}
return nil
}

View file

@ -1,8 +1,23 @@
server:
port: 8080
read_timeout: 15s
write_timeout: 15s
shutdown_timeout: 5s
database:
driver: sqlite
driver: sqlite3
dsn: otpm.sqlite
port: 8080
max_open_conns: 25
max_idle_conns: 25
max_lifetime: 5m
skip_migration: false
jwt:
secret: "your-jwt-secret-key-change-this-in-production"
expire_delta: 24h
refresh_delta: 168h
signing_method: HS256
wechat:
appid: "wx57d1033974eb5250"
secret: "be494c2a81df685a40b9a74e1736b15d"
app_id: "your-wechat-app-id"
app_secret: "your-wechat-app-secret"

128
config/config.go Normal file
View file

@ -0,0 +1,128 @@
package config
import (
"fmt"
"time"
"github.com/spf13/viper"
)
// Config holds all configuration for the application
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
JWT JWTConfig `mapstructure:"jwt"`
WeChat WeChatConfig `mapstructure:"wechat"`
}
// ServerConfig holds all server related configuration
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
ReadTimeout time.Duration `mapstructure:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout"`
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout
}
// DatabaseConfig holds all database related configuration
type DatabaseConfig struct {
Driver string `mapstructure:"driver"`
DSN string `mapstructure:"dsn"`
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
MaxLifetime time.Duration `mapstructure:"max_lifetime"`
SkipMigration bool `mapstructure:"skip_migration"`
}
// JWTConfig holds all JWT related configuration
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireDelta time.Duration `mapstructure:"expire_delta"`
RefreshDelta time.Duration `mapstructure:"refresh_delta"`
SigningMethod string `mapstructure:"signing_method"`
Issuer string `mapstructure:"issuer"`
Audience string `mapstructure:"audience"`
}
// WeChatConfig holds all WeChat related configuration
type WeChatConfig struct {
AppID string `mapstructure:"app_id"`
AppSecret string `mapstructure:"app_secret"`
}
// LoadConfig loads the configuration from file and environment variables
func LoadConfig() (*Config, error) {
// Set default values
setDefaults()
// Read config file
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// Validate config
if err := validateConfig(&config); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
return &config, nil
}
// setDefaults sets default values for configuration
func setDefaults() {
// Server defaults
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.read_timeout", "15s")
viper.SetDefault("server.write_timeout", "15s")
viper.SetDefault("server.shutdown_timeout", "5s")
viper.SetDefault("server.timeout", "30s") // Default request processing timeout
// Database defaults
viper.SetDefault("database.driver", "sqlite3")
viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection
viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection
viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling
viper.SetDefault("database.skip_migration", false)
// JWT defaults
viper.SetDefault("jwt.expire_delta", "24h")
viper.SetDefault("jwt.refresh_delta", "168h") // 7 days
viper.SetDefault("jwt.signing_method", "HS256")
viper.SetDefault("jwt.issuer", "otpm")
viper.SetDefault("jwt.audience", "otpm-client")
}
// validateConfig validates the configuration
func validateConfig(config *Config) error {
if config.Server.Port < 1 || config.Server.Port > 65535 {
return fmt.Errorf("invalid port number: %d", config.Server.Port)
}
if config.Database.Driver == "" {
return fmt.Errorf("database driver is required")
}
if config.Database.DSN == "" {
return fmt.Errorf("database DSN is required")
}
if config.JWT.Secret == "" {
return fmt.Errorf("JWT secret is required")
}
if config.WeChat.AppID == "" {
return fmt.Errorf("WeChat AppID is required")
}
if config.WeChat.AppSecret == "" {
return fmt.Errorf("WeChat AppSecret is required")
}
return nil
}

View file

@ -1,49 +0,0 @@
package database
import (
_ "embed"
"log"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/spf13/viper"
_ "modernc.org/sqlite"
)
var (
//go:embed init/users.sql
userTable string
//go:embed init/otp.sql
otpTable string
)
func InitDB() (*sqlx.DB, error) {
driver := viper.GetString("database.driver")
dsn := viper.GetString("database.dsn")
db, err := sqlx.Open(driver, dsn)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, err
}
log.Println("Connected to database!")
return db, nil
}
func MigrateDB(db *sqlx.DB) error {
_, err := db.Exec(userTable)
if err != nil {
return err
}
_, err = db.Exec(otpTable)
if err != nil {
return err
}
return nil
}

212
database/db.go Normal file
View file

@ -0,0 +1,212 @@
package database
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"time"
"otpm/config"
"github.com/jmoiron/sqlx"
)
// DB wraps sqlx.DB to provide additional functionality
type DB struct {
*sqlx.DB
}
// New creates a new database connection
func New(cfg *config.DatabaseConfig) (*DB, error) {
db, err := sqlx.Open(cfg.Driver, cfg.DSN)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Configure connection pool based on database type
if cfg.Driver == "sqlite3" {
// SQLite is a file-based database - simpler connection settings
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(0) // Connections don't need to be recycled
db.SetConnMaxIdleTime(0)
} else {
// For other databases (MySQL, PostgreSQL etc.)
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
db.SetConnMaxLifetime(30 * time.Minute)
db.SetConnMaxIdleTime(5 * time.Minute)
}
// Verify connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
log.Println("Successfully connected to database")
return &DB{db}, nil
}
// WithTx executes a function within a transaction with retry logic
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
var maxRetries int
var lastErr error
// Adjust retry settings based on database type
if db.DriverName() == "sqlite3" {
maxRetries = 5 // SQLite needs more retries due to busy timeouts
} else {
maxRetries = 3
}
// Default transaction options
opts := &sql.TxOptions{
Isolation: sql.LevelReadCommitted,
}
for attempt := 1; attempt <= maxRetries; attempt++ {
start := time.Now()
tx, err := db.BeginTxx(ctx, opts)
if err != nil {
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
continue
}
return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
}
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
}
// Log long-running transactions
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
log.Printf("Transaction completed in %v", elapsed)
}
if err := tx.Commit(); err != nil {
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
}
return nil
}
return lastErr
}
// isRetryableError checks if an error is likely to succeed on retry
func isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := strings.ToLower(err.Error())
return strings.Contains(errStr, "deadlock") ||
strings.Contains(errStr, "timeout") ||
strings.Contains(errStr, "try again") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "busy") ||
strings.Contains(errStr, "locked")
}
// ExecContext executes a query with adaptive timeout
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
// Set timeout based on query complexity
timeout := 5 * time.Second
if strings.Contains(strings.ToLower(query), "insert") ||
strings.Contains(strings.ToLower(query), "update") ||
strings.Contains(strings.ToLower(query), "delete") {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
_, err := db.DB.ExecContext(ctx, query, args...)
elapsed := time.Since(start)
// Log slow queries
if elapsed > timeout/2 {
log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
}
if err != nil {
return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
}
return nil
}
// QueryRowContext executes a query that returns a single row with timeout
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
return db.DB.QueryRowxContext(ctx, query, args...)
}
// QueryContext executes a query that returns multiple rows with adaptive timeout
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
// Set timeout based on query complexity
timeout := 5 * time.Second
if strings.Contains(strings.ToLower(query), "join") ||
strings.Contains(strings.ToLower(query), "group by") ||
strings.Contains(strings.ToLower(query), "order by") {
timeout = 15 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
rows, err := db.DB.QueryxContext(ctx, query, args...)
elapsed := time.Since(start)
// Log slow queries
if elapsed > timeout/2 {
log.Printf("Slow query detected: %s (took %v)", query, elapsed)
}
if err != nil {
return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
}
return rows, nil
}
// Close closes the database connection
func (db *DB) Close() error {
if err := db.DB.Close(); err != nil {
return fmt.Errorf("failed to close database connection: %w", err)
}
return nil
}

View file

@ -1,6 +1,26 @@
CREATE TABLE IF NOT EXISTS otp (
id SERIAL PRIMARY KEY,
openid VARCHAR(255),
num INTEGER,
token VARCHAR(255)
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id VARCHAR(255) NOT NULL,
openid VARCHAR(255) NOT NULL,
name VARCHAR(100) NOT NULL,
issuer VARCHAR(255),
secret VARCHAR(255) NOT NULL,
algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1',
digits INTEGER NOT NULL DEFAULT 6,
period INTEGER NOT NULL DEFAULT 30,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, name),
UNIQUE(openid)
);
-- Add index for faster lookups
CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id);
CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid);
-- Trigger to update the updated_at timestamp
CREATE TRIGGER IF NOT EXISTS update_otp_timestamp
AFTER UPDATE ON otp
BEGIN
UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END;

View file

@ -1,5 +1,6 @@
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
id INTEGER PRIMARY KEY AUTOINCREMENT,
openid VARCHAR(255) UNIQUE NOT NULL,
session_key VARCHAR(255) UNIQUE NOT NULL
);
CREATE UNIQUE INDEX idx_users_openid ON users(openid);

160
database/migration.go Normal file
View file

@ -0,0 +1,160 @@
package database
import (
"context"
"fmt"
"log"
"time"
_ "embed"
"github.com/jmoiron/sqlx"
)
var (
//go:embed init/users.sql
userTable string
//go:embed init/otp.sql
otpTable string
)
// Migration represents a database migration
type Migration struct {
Name string
SQL string
Version int
}
// Migrations is a list of all migrations
var Migrations = []Migration{
{
Name: "Create users table",
SQL: userTable,
Version: 1,
},
{
Name: "Create OTP table",
SQL: otpTable,
Version: 2,
},
}
// MigrationRecord represents a record in the migrations table
type MigrationRecord struct {
ID int `db:"id"`
Version int `db:"version"`
Name string `db:"name"`
AppliedAt time.Time `db:"applied_at"`
}
// ensureMigrationsTable ensures that the migrations table exists
func ensureMigrationsTable(db *sqlx.DB) error {
query := `
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version INTEGER NOT NULL,
name TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
_, err := db.Exec(query)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
return nil
}
// getAppliedMigrations gets all applied migrations
func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
var records []MigrationRecord
if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
}
result := make(map[int]MigrationRecord)
for _, record := range records {
result[record.Version] = record
}
return result, nil
}
// Migrate runs all pending migrations
func Migrate(db *sqlx.DB, skipMigration bool) error {
if skipMigration {
log.Println("Skipping database migration as configured")
return nil
}
// Ensure migrations table exists
if err := ensureMigrationsTable(db); err != nil {
return err
}
// Get applied migrations
appliedMigrations, err := getAppliedMigrations(db)
if err != nil {
return err
}
// Apply pending migrations
for _, migration := range Migrations {
if _, ok := appliedMigrations[migration.Version]; ok {
log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
continue
}
log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
// Start a transaction for this migration
tx, err := db.Beginx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
// Execute migration
if _, err := tx.Exec(migration.SQL); err != nil {
tx.Rollback()
return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Record migration
if _, err := tx.Exec(
"INSERT INTO migrations (version, name) VALUES (?, ?)",
migration.Version, migration.Name,
); err != nil {
tx.Rollback()
return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Commit transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
}
log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
}
return nil
}
// MigrateWithContext runs all pending migrations with context
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
// Create a channel to signal completion
done := make(chan error, 1)
// Run migration in a goroutine
go func() {
done <- Migrate(db, skipMigration)
}()
// Wait for migration to complete or context to be canceled
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("migration canceled: %w", ctx.Err())
}
}

663
docs/swagger.go Normal file
View file

@ -0,0 +1,663 @@
// Package docs provides API documentation using Swagger/OpenAPI
package docs
import (
"encoding/json"
"net/http"
)
// SwaggerInfo holds the API information
var SwaggerInfo = struct {
Title string
Description string
Version string
Host string
BasePath string
Schemes []string
}{
Title: "OTPM API",
Description: "API for One-Time Password Manager",
Version: "1.0.0",
Host: "localhost:8080",
BasePath: "/",
Schemes: []string{"http", "https"},
}
// SwaggerJSON returns the Swagger JSON
func SwaggerJSON() []byte {
swagger := map[string]interface{}{
"swagger": "2.0",
"info": map[string]interface{}{
"title": SwaggerInfo.Title,
"description": SwaggerInfo.Description,
"version": SwaggerInfo.Version,
},
"host": SwaggerInfo.Host,
"basePath": SwaggerInfo.BasePath,
"schemes": SwaggerInfo.Schemes,
"paths": getPaths(),
"definitions": map[string]interface{}{
"LoginRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "WeChat authorization code",
},
},
"required": []string{"code"},
},
"LoginResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"token": map[string]interface{}{
"type": "string",
"description": "JWT token",
},
"openid": map[string]interface{}{
"type": "string",
"description": "WeChat OpenID",
},
},
},
"CreateOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"secret": map[string]interface{}{
"type": "string",
"description": "OTP secret",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
"enum": []string{"SHA1", "SHA256", "SHA512"},
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
"enum": []int{6, 8},
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
"enum": []int{30, 60},
},
},
"required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"},
},
"OTP": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"id": map[string]interface{}{
"type": "string",
"description": "OTP ID",
},
"user_id": map[string]interface{}{
"type": "string",
"description": "User ID",
},
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
},
"created_at": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Creation time",
},
"updated_at": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Last update time",
},
},
},
"OTPCodeResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "OTP code",
},
"expires_in": map[string]interface{}{
"type": "integer",
"description": "Seconds until expiration",
},
},
},
"VerifyOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "OTP code to verify",
},
},
"required": []string{"code"},
},
"VerifyOTPResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"valid": map[string]interface{}{
"type": "boolean",
"description": "Whether the code is valid",
},
},
},
"UpdateOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
"enum": []string{"SHA1", "SHA256", "SHA512"},
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
"enum": []int{6, 8},
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
"enum": []int{30, 60},
},
},
},
"ErrorResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "integer",
"description": "Error code",
},
"message": map[string]interface{}{
"type": "string",
"description": "Error message",
},
},
},
},
"securityDefinitions": map[string]interface{}{
"Bearer": map[string]interface{}{
"type": "apiKey",
"name": "Authorization",
"in": "header",
"description": "JWT token with Bearer prefix",
},
},
}
data, _ := json.MarshalIndent(swagger, "", " ")
return data
}
// getPaths returns the API paths
func getPaths() map[string]interface{} {
return map[string]interface{}{
"/login": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Login with WeChat",
"description": "Login with WeChat authorization code",
"tags": []string{"auth"},
"parameters": []map[string]interface{}{
{
"name": "body",
"in": "body",
"description": "Login request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/LoginRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Successful login",
"schema": map[string]interface{}{
"$ref": "#/definitions/LoginResponse",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/verify-token": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Verify token",
"description": "Verify JWT token",
"tags": []string{"auth"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Token is valid",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"valid": map[string]interface{}{
"type": "boolean",
"description": "Whether the token is valid",
},
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp": map[string]interface{}{
"get": map[string]interface{}{
"summary": "List OTPs",
"description": "List all OTPs for the authenticated user",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "List of OTPs",
"schema": map[string]interface{}{
"type": "array",
"items": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
"post": map[string]interface{}{
"summary": "Create OTP",
"description": "Create a new OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "body",
"in": "body",
"description": "OTP creation request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/CreateOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP created",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}": map[string]interface{}{
"put": map[string]interface{}{
"summary": "Update OTP",
"description": "Update an existing OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
{
"name": "body",
"in": "body",
"description": "OTP update request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/UpdateOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP updated",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
"delete": map[string]interface{}{
"summary": "Delete OTP",
"description": "Delete an existing OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP deleted",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "Success message",
},
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}/code": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Get OTP code",
"description": "Get the current OTP code",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP code",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTPCodeResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}/verify": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Verify OTP code",
"description": "Verify an OTP code",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
{
"name": "body",
"in": "body",
"description": "OTP verification request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/VerifyOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP verification result",
"schema": map[string]interface{}{
"$ref": "#/definitions/VerifyOTPResponse",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/health": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Health check",
"description": "Check if the API is healthy",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "API is healthy",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"status": map[string]interface{}{
"type": "string",
"description": "Health status",
},
"time": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Current time",
},
},
},
},
},
},
},
"/metrics": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Metrics",
"description": "Get application metrics",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Application metrics",
},
},
},
},
"/swagger.json": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Swagger JSON",
"description": "Get Swagger JSON",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Swagger JSON",
},
},
},
},
}
}
// Handler returns an HTTP handler for Swagger JSON
func Handler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(SwaggerJSON())
}
}

43
go.mod
View file

@ -1,31 +1,36 @@
module otpm
go 1.21.1
go 1.23.0
toolchain go1.23.9
require (
github.com/go-sql-driver/mysql v1.8.1
github.com/go-playground/validator/v10 v10.26.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/jmoiron/sqlx v1.4.0
github.com/julienschmidt/httprouter v1.3.0
github.com/lib/pq v1.10.9
github.com/spf13/cobra v1.8.1
github.com/prometheus/client_golang v1.22.0
github.com/spf13/viper v1.19.0
modernc.org/sqlite v1.32.0
golang.org/x/crypto v0.38.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
@ -36,14 +41,10 @@ require (
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
golang.org/x/sys v0.22.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
modernc.org/libc v1.55.3 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.8.0 // indirect
modernc.org/strutil v1.2.0 // indirect
modernc.org/token v1.1.0 // indirect
)

110
go.sum
View file

@ -1,60 +1,76 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
@ -65,8 +81,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
@ -79,56 +93,32 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ=
modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ=
modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y=
modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s=
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw=
modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU=
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI=
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4=
modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U=
modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w=
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E=
modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU=
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc=
modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss=
modernc.org/sqlite v1.32.0 h1:6BM4uGza7bWypsw4fdLRsLxut6bHe4c58VeqjRgST8s=
modernc.org/sqlite v1.32.0/go.mod h1:UqoylwmTb9F+IqXERT8bW9zzOWN8qwAIcLdzeBZs4hA=
modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA=
modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

156
handlers/auth_handler.go Normal file
View file

@ -0,0 +1,156 @@
package handlers
import (
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"otpm/api"
"otpm/services"
"github.com/golang-jwt/jwt"
"github.com/julienschmidt/httprouter"
)
// AuthHandler handles authentication related requests
type AuthHandler struct {
authService *services.AuthService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code" validate:"required,min=32,max=128"`
}
// LoginResponse represents a login response
type LoginResponse struct {
Token string `json:"token"`
OpenID string `json:"openid"`
}
// TokenRequest represents a token verification request
type TokenRequest struct {
Token string `validate:"required,min=32"`
}
// Login handles WeChat login
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()
// Limit request body size to prevent DOS
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
// Parse and validate request
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request body: %v", err))
log.Printf("Login request parse error: %v", err)
return
}
// Validate using validator
if err := api.Validate.Struct(req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request parameters: %v", err))
log.Printf("Login request validation failed: %v", err)
return
}
// Login with WeChat code
token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
log.Printf("Login failed for code %s: %v", req.Code, err)
return
}
// Log successful login
log.Printf("Login successful for code %s (took %v)",
req.Code, time.Since(start))
// Return token
api.NewResponseWriter(w).WriteSuccess(LoginResponse{
Token: token,
})
}
// VerifyToken handles token verification
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()
// Get token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Authorization header is required")
log.Printf("Token verification failed: missing Authorization header")
return
}
// Validate token format
if len(authHeader) < 7 || authHeader[:7] != "Bearer " {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token format. Expected 'Bearer <token>'")
log.Printf("Token verification failed: invalid token format")
return
}
token := authHeader[7:]
// Validate token using validator
tokenReq := TokenRequest{Token: token}
if err := api.Validate.Struct(tokenReq); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token format")
log.Printf("Token verification failed: %v", err)
return
}
// Validate token
claims, err := h.authService.ValidateToken(token)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("Token verification failed for token %s: %v",
maskToken(token), err) // Mask token in logs
return
}
// Log successful verification
userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string)
if !ok {
log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
} else {
log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
}
// Token is valid
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
"valid": true,
})
}
// maskToken masks sensitive parts of token for logging
func maskToken(token string) string {
if len(token) < 8 {
return "****"
}
return token[:4] + "****" + token[len(token)-4:]
}
// Routes returns all routes for the auth handler
func (h *AuthHandler) Routes() map[string]httprouter.Handle {
return map[string]httprouter.Handle{
"/api/login": h.Login,
"/api/verify-token": h.VerifyToken,
}
}

View file

@ -1,9 +0,0 @@
package handlers
import (
"github.com/jmoiron/sqlx"
)
type Handler struct {
DB *sqlx.DB
}

View file

@ -1,93 +0,0 @@
package handlers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/spf13/viper"
)
var code2Session = "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code"
type LoginRequest struct {
Code string `json:"code"`
}
// 封装code2session接口返回数据
type LoginResponse struct {
OpenId string `json:"openid"`
SessionKey string `json:"session_key"`
UnionId string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
func getLoginResponse(code string) (*LoginResponse, error) {
appid := viper.GetString("wechat.appid")
secret := viper.GetString("wechat.secret")
url := fmt.Sprintf(code2Session, appid, secret, code)
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var loginResponse LoginResponse
if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
return nil, err
}
if loginResponse.ErrCode != 0 {
return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
}
return &loginResponse, nil
}
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
var req LoginRequest
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
if err := json.Unmarshal(body, &req); err != nil {
http.Error(w, "Failed to parse request body", http.StatusBadRequest)
return
}
loginResponse, err := getLoginResponse(req.Code)
if err != nil {
http.Error(w, "Failed to get session key", http.StatusInternalServerError)
return
}
// // 插入或更新用户的openid和sessionid
// query := `
// INSERT INTO users (openid, sessionid)
// VALUES ($1, $2)
// ON CONFLICT (openid) DO UPDATE SET sessionid = $2
// RETURNING id;
// `
// var ID int
// if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil {
// http.Error(w, "Failed to log in user", http.StatusInternalServerError)
// return
// }
data := map[string]interface{}{
"openid": loginResponse.OpenId,
"session_key": loginResponse.SessionKey,
}
respData, err := json.Marshal(data)
if err != nil {
http.Error(w, "Failed to marshal response data", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(respData))
}

View file

@ -1,61 +0,0 @@
package handlers
import (
"encoding/json"
"net/http"
)
type OtpRequest struct {
OpenID string `json:"openid"`
Num int `json:"num"`
Token *[]OTP `json:"token"`
}
type OTP struct {
Issuer string `json:"issuer"`
Remark string `json:"remark"`
Secret string `json:"secret"`
}
func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) {
var req OtpRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request payload", http.StatusBadRequest)
return
}
num := len(*req.Token)
// 插入或更新 OTP 记录
query := `
INSERT INTO otp (openid, num, token)
VALUES ($1, $2, $3)
`
_, err := h.DB.Exec(query, req.OpenID, req.Token, num)
if err != nil {
http.Error(w, "Failed to update or create OTP", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("OTP updated or created successfully"))
}
func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
openid := r.URL.Query().Get("openid")
if openid == "" {
http.Error(w, "未登录", http.StatusBadRequest)
return
}
var otp OtpRequest
err := h.DB.Get(&otp, "SELECT token, num, openid FROM otp WHERE openid=$1", openid)
if err != nil {
http.Error(w, "OTP not found", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(otp)
}

114
handlers/otp_handler.go Normal file
View file

@ -0,0 +1,114 @@
package handlers
import (
"encoding/json"
"net/http"
"github.com/julienschmidt/httprouter"
"otpm/api"
"otpm/middleware"
"otpm/models"
"otpm/services"
)
// OTPHandler handles OTP-related HTTP requests
type OTPHandler struct {
otpService *services.OTPService
}
// NewOTPHandler creates a new OTPHandler
func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
return &OTPHandler{
otpService: otpService,
}
}
// Routes returns the routes for OTP operations
func (h *OTPHandler) Routes() map[string]httprouter.Handle {
return map[string]httprouter.Handle{
"POST /api/otp": h.CreateOTP,
"GET /api/otps": h.ListOTPs,
"GET /api/otp/:id": h.GetOTP,
}
}
// CreateOTP handles the creation of a new OTP
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// Get user ID from context
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Parse request body
var params models.OTPParams
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body"))
return
}
// Validate request
if err := api.Validate.Struct(params); err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
return
}
// Create OTP
otp, err := h.otpService.CreateOTP(r.Context(), userID, params)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
// Return response
api.NewResponseWriter(w).WriteSuccess(otp)
}
// ListOTPs handles listing all OTPs for a user
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// Get user ID from context
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTPs
otps, err := h.otpService.ListOTPs(r.Context(), userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
// Return response
api.NewResponseWriter(w).WriteSuccess(otps)
}
// GetOTP handles getting a specific OTP
func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Get user ID from context
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := ps.ByName("id")
if otpID == "" {
api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID"))
return
}
// Get OTP
otp, err := h.otpService.GetOTP(r.Context(), otpID, userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
// Return response
api.NewResponseWriter(w).WriteSuccess(otp)
}

204
logger/logger.go Normal file
View file

@ -0,0 +1,204 @@
package logger
import (
"context"
"fmt"
"io"
"os"
"runtime"
"strings"
"time"
"github.com/google/uuid"
)
// Level represents a log level
type Level int
const (
// DEBUG level
DEBUG Level = iota
// INFO level
INFO
// WARN level
WARN
// ERROR level
ERROR
// FATAL level
FATAL
)
// String returns the string representation of the log level
func (l Level) String() string {
switch l {
case DEBUG:
return "DEBUG"
case INFO:
return "INFO"
case WARN:
return "WARN"
case ERROR:
return "ERROR"
case FATAL:
return "FATAL"
default:
return "UNKNOWN"
}
}
// Logger represents a logger
type Logger struct {
level Level
output io.Writer
}
// contextKey is a type for context keys
type contextKey string
// requestIDKey is the key for request ID in context
const requestIDKey = contextKey("request_id")
// New creates a new logger
func New(level Level, output io.Writer) *Logger {
if output == nil {
output = os.Stdout
}
return &Logger{
level: level,
output: output,
}
}
// WithLevel creates a new logger with the specified level
func (l *Logger) WithLevel(level Level) *Logger {
return &Logger{
level: level,
output: l.output,
}
}
// WithOutput creates a new logger with the specified output
func (l *Logger) WithOutput(output io.Writer) *Logger {
return &Logger{
level: l.level,
output: output,
}
}
// log logs a message with the specified level
func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) {
if level < l.level {
return
}
// Get request ID from context
requestID := getRequestID(ctx)
// Get caller information
_, file, line, ok := runtime.Caller(2)
if !ok {
file = "unknown"
line = 0
}
// Extract just the filename
if idx := strings.LastIndex(file, "/"); idx >= 0 {
file = file[idx+1:]
}
// Format message
message := fmt.Sprintf(format, args...)
// Format log entry
timestamp := time.Now().Format(time.RFC3339)
logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n",
timestamp, level.String(), file, line, requestID, message)
// Write log entry
_, _ = l.output.Write([]byte(logEntry))
// Exit if fatal
if level == FATAL {
os.Exit(1)
}
}
// Debug logs a debug message
func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, DEBUG, format, args...)
}
// Info logs an info message
func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, INFO, format, args...)
}
// Warn logs a warning message
func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, WARN, format, args...)
}
// Error logs an error message
func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, ERROR, format, args...)
}
// Fatal logs a fatal message and exits
func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, FATAL, format, args...)
}
// WithRequestID adds a request ID to the context
func WithRequestID(ctx context.Context) context.Context {
requestID := uuid.New().String()
return context.WithValue(ctx, requestIDKey, requestID)
}
// GetRequestID gets the request ID from the context
func GetRequestID(ctx context.Context) string {
return getRequestID(ctx)
}
// getRequestID gets the request ID from the context
func getRequestID(ctx context.Context) string {
if ctx == nil {
return "-"
}
requestID, ok := ctx.Value(requestIDKey).(string)
if !ok {
return "-"
}
return requestID
}
// Default logger
var defaultLogger = New(INFO, os.Stdout)
// SetDefaultLogger sets the default logger
func SetDefaultLogger(logger *Logger) {
defaultLogger = logger
}
// Debug logs a debug message using the default logger
func Debug(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Debug(ctx, format, args...)
}
// Info logs an info message using the default logger
func Info(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Info(ctx, format, args...)
}
// Warn logs a warning message using the default logger
func Warn(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Warn(ctx, format, args...)
}
// Error logs an error message using the default logger
func Error(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Error(ctx, format, args...)
}
// Fatal logs a fatal message and exits using the default logger
func Fatal(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Fatal(ctx, format, args...)
}

193
metrics/metrics.go Normal file
View file

@ -0,0 +1,193 @@
package metrics
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
var (
// Default metrics
requestDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "Duration of HTTP requests in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path", "status"},
)
requestTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Total number of HTTP requests",
},
[]string{"method", "path", "status"},
)
otpGenerationTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "otp_generation_total",
Help: "Total number of OTP generations",
},
[]string{"user_id", "otp_id"},
)
otpVerificationTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "otp_verification_total",
Help: "Total number of OTP verifications",
},
[]string{"user_id", "otp_id", "success"},
)
activeUsers = prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "active_users",
Help: "Number of active users",
},
)
cacheHits = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_hits_total",
Help: "Total number of cache hits",
},
[]string{"cache"},
)
cacheMisses = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_misses_total",
Help: "Total number of cache misses",
},
[]string{"cache"},
)
)
func init() {
// Register metrics with prometheus
prometheus.MustRegister(
requestDuration,
requestTotal,
otpGenerationTotal,
otpVerificationTotal,
activeUsers,
cacheHits,
cacheMisses,
)
}
// MetricsService provides metrics functionality
type MetricsService struct {
activeUsersMutex sync.RWMutex
activeUserIDs map[string]bool
}
// NewMetricsService creates a new MetricsService
func NewMetricsService() *MetricsService {
return &MetricsService{
activeUserIDs: make(map[string]bool),
}
}
// Handler returns an HTTP handler for metrics
func (s *MetricsService) Handler() http.Handler {
return promhttp.Handler()
}
// RecordRequest records metrics for an HTTP request
func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) {
labels := prometheus.Labels{
"method": method,
"path": path,
"status": fmt.Sprintf("%d", status),
}
requestDuration.With(labels).Observe(duration.Seconds())
requestTotal.With(labels).Inc()
}
// RecordOTPGeneration records metrics for OTP generation
func (s *MetricsService) RecordOTPGeneration(userID, otpID string) {
otpGenerationTotal.With(prometheus.Labels{
"user_id": userID,
"otp_id": otpID,
}).Inc()
}
// RecordOTPVerification records metrics for OTP verification
func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) {
otpVerificationTotal.With(prometheus.Labels{
"user_id": userID,
"otp_id": otpID,
"success": fmt.Sprintf("%t", success),
}).Inc()
}
// RecordUserActivity records user activity
func (s *MetricsService) RecordUserActivity(userID string) {
s.activeUsersMutex.Lock()
defer s.activeUsersMutex.Unlock()
if !s.activeUserIDs[userID] {
s.activeUserIDs[userID] = true
activeUsers.Inc()
}
}
// RecordUserInactivity records user inactivity
func (s *MetricsService) RecordUserInactivity(userID string) {
s.activeUsersMutex.Lock()
defer s.activeUsersMutex.Unlock()
if s.activeUserIDs[userID] {
delete(s.activeUserIDs, userID)
activeUsers.Dec()
}
}
// RecordCacheHit records a cache hit
func (s *MetricsService) RecordCacheHit(cache string) {
cacheHits.With(prometheus.Labels{
"cache": cache,
}).Inc()
}
// RecordCacheMiss records a cache miss
func (s *MetricsService) RecordCacheMiss(cache string) {
cacheMisses.With(prometheus.Labels{
"cache": cache,
}).Inc()
}
// Middleware creates a middleware that records request metrics
func (s *MetricsService) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Create response writer that captures status code
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
// Call next handler
next.ServeHTTP(rw, r)
// Record metrics
s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start))
})
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
status int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}

353
middleware/middleware.go Normal file
View file

@ -0,0 +1,353 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/golang-jwt/jwt"
)
// Response represents a standard API response
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// ErrorResponse sends a JSON error response
func ErrorResponse(w http.ResponseWriter, code int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(Response{
Code: code,
Message: message,
})
}
// SuccessResponse sends a JSON success response
func SuccessResponse(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(Response{
Code: http.StatusOK,
Message: "success",
Data: data,
})
}
// Logger is a middleware that logs request details with structured format
func Logger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
r.Header.Set("X-Request-ID", requestID)
}
// Create a custom response writer to capture status code
rw := &responseWriter{
ResponseWriter: w,
status: http.StatusOK,
}
// Process request
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID)))
// Log structured request details
log.Printf(
"method=%s path=%s status=%d duration=%s ip=%s request_id=%s",
r.Method,
r.URL.Path,
rw.status,
time.Since(start).String(),
r.RemoteAddr,
requestID,
)
})
}
// generateRequestID creates a unique request identifier
func generateRequestID() string {
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
}
// randomString generates a random string of given length
func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return string(b)
}
// Recover is a middleware that recovers from panics with detailed logging
func Recover(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
// Get request ID from context
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Log error with stack trace and request context
log.Printf(
"panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s",
err,
requestID,
r.Method,
r.URL.Path,
r.RemoteAddr,
debug.Stack(),
)
// Determine error type
var message string
var status int
switch e := err.(type) {
case error:
message = e.Error()
if isClientError(e) {
status = http.StatusBadRequest
} else {
status = http.StatusInternalServerError
}
case string:
message = e
status = http.StatusInternalServerError
default:
message = "Internal Server Error"
status = http.StatusInternalServerError
}
ErrorResponse(w, status, message)
}
}()
next.ServeHTTP(w, r)
})
}
// isClientError checks if error should be treated as client error
func isClientError(err error) bool {
// Add more client error types as needed
return strings.Contains(err.Error(), "validation") ||
strings.Contains(err.Error(), "invalid") ||
strings.Contains(err.Error(), "missing")
}
// CORS is a middleware that handles CORS
func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
// Timeout is a middleware that safely handles request timeouts
func Timeout(duration time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), duration)
defer cancel()
// Use buffered channels to prevent goroutine leaks
done := make(chan struct{}, 1)
panicChan := make(chan interface{}, 1)
// Track request processing in goroutine
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
next.ServeHTTP(w, r.WithContext(ctx))
done <- struct{}{}
}()
// Wait for completion, timeout or panic
select {
case <-done:
return
case p := <-panicChan:
panic(p) // Re-throw panic to be caught by Recover middleware
case <-ctx.Done():
// Get request context for logging
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Log timeout details
log.Printf(
"request_timeout: request_id=%s method=%s path=%s timeout=%s",
requestID,
r.Method,
r.URL.Path,
duration.String(),
)
// Send timeout response
ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf(
"Request timed out after %s", duration.String(),
))
}
})
}
}
// Auth is a middleware that validates JWT tokens with enhanced security
func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get request ID for logging
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Get token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required")
return
}
// Validate header format
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer <token>'")
return
}
tokenString := parts[1]
// Parse and validate token
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Validate signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(jwtSecret), nil
})
if err != nil {
log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
return
}
if !token.Valid {
log.Printf("auth_failed: request_id=%s error=invalid_token", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
return
}
// Validate claims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims")
return
}
// Check required claims
userID, ok := claims["user_id"].(string)
if !ok || userID == "" {
log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token")
return
}
// Check token expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
log.Printf("auth_failed: request_id=%s error=token_expired", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Token has expired")
return
}
}
// Check required roles if specified
if len(requiredRoles) > 0 {
roles, ok := claims["roles"].([]interface{})
if !ok {
log.Printf("auth_failed: request_id=%s error=missing_roles", requestID)
ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles")
return
}
hasRequiredRole := false
for _, requiredRole := range requiredRoles {
for _, role := range roles {
if r, ok := role.(string); ok && r == requiredRole {
hasRequiredRole = true
break
}
}
}
if !hasRequiredRole {
log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID)
ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions")
return
}
}
// Add claims to context
ctx := r.Context()
ctx = context.WithValue(ctx, "user_id", userID)
ctx = context.WithValue(ctx, "claims", claims)
log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// responseWriter is a custom response writer that captures the status code
type responseWriter struct {
http.ResponseWriter
status int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}
// GetUserID gets the user ID from the request context
func GetUserID(r *http.Request) (string, error) {
userID, ok := r.Context().Value("user_id").(string)
if !ok {
return "", fmt.Errorf("user ID not found in context")
}
return userID, nil
}

View file

@ -0,0 +1,50 @@
// app.js
App({
onLaunch() {
// 检查更新
if (wx.canIUse('getUpdateManager')) {
const updateManager = wx.getUpdateManager();
updateManager.onCheckForUpdate(function (res) {
if (res.hasUpdate) {
updateManager.onUpdateReady(function () {
wx.showModal({
title: '更新提示',
content: '新版本已经准备好,是否重启应用?',
success: function (res) {
if (res.confirm) {
updateManager.applyUpdate();
}
}
});
});
updateManager.onUpdateFailed(function () {
wx.showModal({
title: '更新提示',
content: '新版本下载失败,请检查网络后重试',
showCancel: false
});
});
}
});
}
// 获取系统信息
try {
const systemInfo = wx.getSystemInfoSync();
this.globalData.systemInfo = systemInfo;
// 计算安全区域
const { screenHeight, safeArea } = systemInfo;
this.globalData.safeAreaBottom = screenHeight - safeArea.bottom;
} catch (e) {
console.error('获取系统信息失败', e);
}
},
globalData: {
userInfo: null,
systemInfo: {},
safeAreaBottom: 0
}
});

View file

@ -0,0 +1,23 @@
{
"pages": [
"pages/login/login",
"pages/otp-list/index",
"pages/otp-add/index"
],
"window": {
"backgroundTextStyle": "light",
"navigationBarBackgroundColor": "#fff",
"navigationBarTitleText": "OTPM",
"navigationBarTextStyle": "black",
"backgroundColor": "#F8F8F8"
},
"permission": {
"scope.camera": {
"desc": "需要使用相机扫描二维码"
}
},
"usingComponents": {},
"style": "v2",
"sitemapLocation": "sitemap.json",
"lazyCodeLoading": "requiredComponents"
}

View file

@ -0,0 +1,238 @@
/**app.wxss**/
page {
--primary-color: #1890ff;
--danger-color: #ff4d4f;
--success-color: #52c41a;
--warning-color: #faad14;
--text-color: #333333;
--text-color-secondary: #666666;
--text-color-light: #999999;
--border-color: #e8e8e8;
--background-color: #f8f8f8;
--border-radius: 8rpx;
--safe-area-bottom: env(safe-area-inset-bottom);
font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica,
Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei',
sans-serif;
font-size: 28rpx;
line-height: 1.5;
color: var(--text-color);
background-color: var(--background-color);
}
/* 清除默认样式 */
button {
padding: 0;
margin: 0;
background: none;
border: none;
text-align: left;
line-height: inherit;
overflow: visible;
}
button::after {
border: none;
}
/* 通用样式类 */
.container {
min-height: 100vh;
box-sizing: border-box;
}
.safe-area-bottom {
padding-bottom: var(--safe-area-bottom);
}
.flex-center {
display: flex;
align-items: center;
justify-content: center;
}
.flex-between {
display: flex;
align-items: center;
justify-content: space-between;
}
.flex-column {
display: flex;
flex-direction: column;
}
.text-primary {
color: var(--primary-color);
}
.text-danger {
color: var(--danger-color);
}
.text-success {
color: var(--success-color);
}
.text-warning {
color: var(--warning-color);
}
.text-secondary {
color: var(--text-color-secondary);
}
.text-light {
color: var(--text-color-light);
}
.text-center {
text-align: center;
}
.text-left {
text-align: left;
}
.text-right {
text-align: right;
}
.text-bold {
font-weight: bold;
}
.text-ellipsis {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.bg-white {
background-color: #ffffff;
}
.bg-primary {
background-color: var(--primary-color);
}
.bg-danger {
background-color: var(--danger-color);
}
.bg-success {
background-color: var(--success-color);
}
.bg-warning {
background-color: var(--warning-color);
}
.shadow {
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.rounded {
border-radius: var(--border-radius);
}
.border {
border: 2rpx solid var(--border-color);
}
.border-top {
border-top: 2rpx solid var(--border-color);
}
.border-bottom {
border-bottom: 2rpx solid var(--border-color);
}
/* 动画类 */
.fade-in {
animation: fadeIn 0.3s ease-in-out;
}
.fade-out {
animation: fadeOut 0.3s ease-in-out;
}
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
@keyframes fadeOut {
from {
opacity: 1;
}
to {
opacity: 0;
}
}
/* 间距类 */
.m-0 { margin: 0; }
.m-1 { margin: 10rpx; }
.m-2 { margin: 20rpx; }
.m-3 { margin: 30rpx; }
.m-4 { margin: 40rpx; }
.mt-0 { margin-top: 0; }
.mt-1 { margin-top: 10rpx; }
.mt-2 { margin-top: 20rpx; }
.mt-3 { margin-top: 30rpx; }
.mt-4 { margin-top: 40rpx; }
.mb-0 { margin-bottom: 0; }
.mb-1 { margin-bottom: 10rpx; }
.mb-2 { margin-bottom: 20rpx; }
.mb-3 { margin-bottom: 30rpx; }
.mb-4 { margin-bottom: 40rpx; }
.ml-0 { margin-left: 0; }
.ml-1 { margin-left: 10rpx; }
.ml-2 { margin-left: 20rpx; }
.ml-3 { margin-left: 30rpx; }
.ml-4 { margin-left: 40rpx; }
.mr-0 { margin-right: 0; }
.mr-1 { margin-right: 10rpx; }
.mr-2 { margin-right: 20rpx; }
.mr-3 { margin-right: 30rpx; }
.mr-4 { margin-right: 40rpx; }
.p-0 { padding: 0; }
.p-1 { padding: 10rpx; }
.p-2 { padding: 20rpx; }
.p-3 { padding: 30rpx; }
.p-4 { padding: 40rpx; }
.pt-0 { padding-top: 0; }
.pt-1 { padding-top: 10rpx; }
.pt-2 { padding-top: 20rpx; }
.pt-3 { padding-top: 30rpx; }
.pt-4 { padding-top: 40rpx; }
.pb-0 { padding-bottom: 0; }
.pb-1 { padding-bottom: 10rpx; }
.pb-2 { padding-bottom: 20rpx; }
.pb-3 { padding-bottom: 30rpx; }
.pb-4 { padding-bottom: 40rpx; }
.pl-0 { padding-left: 0; }
.pl-1 { padding-left: 10rpx; }
.pl-2 { padding-left: 20rpx; }
.pl-3 { padding-left: 30rpx; }
.pl-4 { padding-left: 40rpx; }
.pr-0 { padding-right: 0; }
.pr-1 { padding-right: 10rpx; }
.pr-2 { padding-right: 20rpx; }
.pr-3 { padding-right: 30rpx; }
.pr-4 { padding-right: 40rpx; }

View file

@ -0,0 +1,48 @@
// login.js
import { wxLogin } from '../../services/auth';
Page({
data: {
loading: false
},
onLoad() {
// 页面加载时检查是否已经登录
const token = wx.getStorageSync('token');
if (token) {
this.redirectToHome();
}
},
// 处理登录按钮点击
handleLogin() {
if (this.data.loading) return;
this.setData({ loading: true });
wxLogin()
.then(() => {
wx.showToast({
title: '登录成功',
icon: 'success'
});
this.redirectToHome();
})
.catch(err => {
wx.showToast({
title: err.message || '登录失败',
icon: 'none'
});
})
.finally(() => {
this.setData({ loading: false });
});
},
// 跳转到首页
redirectToHome() {
wx.reLaunch({
url: '/pages/otp-list/index'
});
}
});

View file

@ -0,0 +1,3 @@
{
"usingComponents": {}
}

View file

@ -0,0 +1,30 @@
<!-- login.wxml -->
<view class="container">
<view class="logo-container">
<image class="logo" src="/assets/images/logo.png" mode="aspectFit"></image>
<text class="app-name">OTPM 小程序</text>
</view>
<view class="login-container">
<text class="login-title">欢迎使用 OTPM</text>
<text class="login-subtitle">一次性密码管理工具</text>
<button
class="login-button {{loading ? 'loading' : ''}}"
type="primary"
bindtap="handleLogin"
disabled="{{loading}}"
>
<text wx:if="{{!loading}}">微信一键登录</text>
<view wx:else class="loading-container">
<view class="loading-icon"></view>
<text>登录中...</text>
</view>
</button>
<view class="privacy-policy">
<text>登录即表示您同意</text>
<navigator url="/pages/privacy/index" class="policy-link">《隐私政策》</navigator>
</view>
</view>
</view>

View file

@ -0,0 +1,97 @@
/* login.wxss */
.container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: space-between;
height: 100vh;
padding: 60rpx 40rpx;
box-sizing: border-box;
background-color: #f8f8f8;
}
.logo-container {
display: flex;
flex-direction: column;
align-items: center;
margin-top: 80rpx;
}
.logo {
width: 180rpx;
height: 180rpx;
margin-bottom: 20rpx;
}
.app-name {
font-size: 36rpx;
font-weight: bold;
color: #333;
}
.login-container {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
margin-bottom: 100rpx;
}
.login-title {
font-size: 48rpx;
font-weight: bold;
color: #333;
margin-bottom: 20rpx;
}
.login-subtitle {
font-size: 28rpx;
color: #666;
margin-bottom: 80rpx;
}
.login-button {
width: 80%;
height: 88rpx;
border-radius: 44rpx;
font-size: 32rpx;
display: flex;
align-items: center;
justify-content: center;
}
.login-button.loading {
background-color: #8cc4ff;
}
.loading-container {
display: flex;
align-items: center;
justify-content: center;
}
.loading-icon {
width: 36rpx;
height: 36rpx;
margin-right: 10rpx;
border: 4rpx solid #ffffff;
border-radius: 50%;
border-top-color: transparent;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.privacy-policy {
margin-top: 40rpx;
font-size: 24rpx;
color: #999;
}
.policy-link {
color: #1890ff;
display: inline;
}

View file

@ -0,0 +1,169 @@
// otp-add/index.js
import { createOTP } from '../../services/otp';
Page({
data: {
form: {
name: '',
issuer: '',
secret: '',
algorithm: 'SHA1',
digits: 6,
period: 30
},
algorithms: ['SHA1', 'SHA256', 'SHA512'],
digitOptions: [6, 8],
periodOptions: [30, 60],
submitting: false,
scanMode: false
},
// 处理输入变化
handleInputChange(e) {
const { field } = e.currentTarget.dataset;
const { value } = e.detail;
this.setData({
[`form.${field}`]: value
});
},
// 处理选择器变化
handlePickerChange(e) {
const { field } = e.currentTarget.dataset;
const { value } = e.detail;
const options = this.data[`${field}Options`] || this.data[field];
const selectedValue = options[value];
this.setData({
[`form.${field}`]: selectedValue
});
},
// 扫描二维码
handleScanQRCode() {
this.setData({ scanMode: true });
wx.scanCode({
scanType: ['qrCode'],
success: (res) => {
try {
// 解析otpauth://协议的URL
const url = res.result;
if (url.startsWith('otpauth://totp/')) {
const parsedUrl = new URL(url);
const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠
// 解析路径中的issuer和name
let issuer = '';
let name = path;
if (path.includes(':')) {
const parts = path.split(':');
issuer = parts[0];
name = parts[1];
}
// 从查询参数中获取其他信息
const secret = parsedUrl.searchParams.get('secret') || '';
const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1';
const digits = parseInt(parsedUrl.searchParams.get('digits') || '6');
const period = parseInt(parsedUrl.searchParams.get('period') || '30');
// 如果查询参数中有issuer优先使用
if (parsedUrl.searchParams.get('issuer')) {
issuer = parsedUrl.searchParams.get('issuer');
}
this.setData({
form: {
name,
issuer,
secret,
algorithm,
digits,
period
}
});
wx.showToast({
title: '二维码解析成功',
icon: 'success'
});
} else {
wx.showToast({
title: '不支持的二维码格式',
icon: 'none'
});
}
} catch (err) {
wx.showToast({
title: '二维码解析失败',
icon: 'none'
});
}
},
fail: () => {
wx.showToast({
title: '扫描取消',
icon: 'none'
});
},
complete: () => {
this.setData({ scanMode: false });
}
});
},
// 提交表单
handleSubmit() {
const { form } = this.data;
// 表单验证
if (!form.name) {
wx.showToast({
title: '请输入名称',
icon: 'none'
});
return;
}
if (!form.secret) {
wx.showToast({
title: '请输入密钥',
icon: 'none'
});
return;
}
this.setData({ submitting: true });
createOTP(form)
.then(() => {
wx.showToast({
title: '添加成功',
icon: 'success'
});
// 返回上一页
setTimeout(() => {
wx.navigateBack();
}, 1500);
})
.catch(err => {
wx.showToast({
title: err.message || '添加失败',
icon: 'none'
});
})
.finally(() => {
this.setData({ submitting: false });
});
},
// 取消
handleCancel() {
wx.navigateBack();
}
});

View file

@ -0,0 +1,3 @@
{
"usingComponents": {}
}

View file

@ -0,0 +1,119 @@
<!-- otp-add/index.wxml -->
<view class="container">
<view class="header">
<text class="title">添加OTP</text>
</view>
<view class="form-container">
<view class="form-group">
<text class="form-label">名称 <text class="required">*</text></text>
<input
class="form-input"
placeholder="请输入OTP名称"
value="{{form.name}}"
bindinput="handleInputChange"
data-field="name"
/>
</view>
<view class="form-group">
<text class="form-label">发行方</text>
<input
class="form-input"
placeholder="请输入发行方名称"
value="{{form.issuer}}"
bindinput="handleInputChange"
data-field="issuer"
/>
</view>
<view class="form-group">
<text class="form-label">密钥 <text class="required">*</text></text>
<view class="secret-input-container">
<input
class="form-input"
placeholder="请输入密钥或扫描二维码"
value="{{form.secret}}"
bindinput="handleInputChange"
data-field="secret"
/>
<view class="scan-button" bindtap="handleScanQRCode" wx:if="{{!scanMode}}">
<text class="scan-icon">🔍</text>
</view>
<view class="scanning-indicator" wx:else>
<view class="scanning-spinner"></view>
</view>
</view>
</view>
<view class="form-group">
<text class="form-label">算法</text>
<picker
mode="selector"
range="{{algorithms}}"
value="{{algorithms.indexOf(form.algorithm)}}"
bindchange="handlePickerChange"
data-field="algorithm"
>
<view class="picker-view">
<text>{{form.algorithm}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
<view class="form-row">
<view class="form-group half">
<text class="form-label">位数</text>
<picker
mode="selector"
range="{{digitOptions}}"
value="{{digitOptions.indexOf(form.digits)}}"
bindchange="handlePickerChange"
data-field="digits"
>
<view class="picker-view">
<text>{{form.digits}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
<view class="form-group half">
<text class="form-label">周期(秒)</text>
<picker
mode="selector"
range="{{periodOptions}}"
value="{{periodOptions.indexOf(form.period)}}"
bindchange="handlePickerChange"
data-field="period"
>
<view class="picker-view">
<text>{{form.period}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
</view>
</view>
<view class="button-group">
<button
class="cancel-button"
bindtap="handleCancel"
disabled="{{submitting}}"
>取消</button>
<button
class="submit-button {{submitting ? 'loading' : ''}}"
bindtap="handleSubmit"
disabled="{{submitting}}"
>
<text wx:if="{{!submitting}}">保存</text>
<view wx:else class="loading-container">
<view class="loading-icon"></view>
<text>保存中...</text>
</view>
</button>
</view>
</view>

View file

@ -0,0 +1,176 @@
/* otp-add/index.wxss */
.container {
min-height: 100vh;
background-color: #f8f8f8;
padding: 0 0 40rpx 0;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 40rpx 32rpx;
background-color: #ffffff;
position: sticky;
top: 0;
z-index: 100;
box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
}
.title {
font-size: 36rpx;
font-weight: bold;
color: #333333;
}
.form-container {
background-color: #ffffff;
padding: 32rpx;
margin: 32rpx;
border-radius: 16rpx;
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.form-group {
margin-bottom: 32rpx;
}
.form-row {
display: flex;
justify-content: space-between;
}
.form-group.half {
width: 48%;
}
.form-label {
display: block;
font-size: 28rpx;
color: #666666;
margin-bottom: 12rpx;
}
.required {
color: #ff4d4f;
}
.form-input {
width: 100%;
height: 80rpx;
background-color: #f5f5f5;
border-radius: 8rpx;
padding: 0 24rpx;
font-size: 28rpx;
color: #333333;
box-sizing: border-box;
}
.secret-input-container {
position: relative;
}
.scan-button {
position: absolute;
right: 20rpx;
top: 50%;
transform: translateY(-50%);
width: 60rpx;
height: 60rpx;
display: flex;
align-items: center;
justify-content: center;
}
.scan-icon {
font-size: 40rpx;
color: #1890ff;
}
.scanning-indicator {
position: absolute;
right: 20rpx;
top: 50%;
transform: translateY(-50%);
width: 40rpx;
height: 40rpx;
}
.scanning-spinner {
width: 40rpx;
height: 40rpx;
border: 4rpx solid #f3f3f3;
border-top: 4rpx solid #1890ff;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.picker-view {
width: 100%;
height: 80rpx;
background-color: #f5f5f5;
border-radius: 8rpx;
padding: 0 24rpx;
font-size: 28rpx;
color: #333333;
display: flex;
align-items: center;
justify-content: space-between;
box-sizing: border-box;
}
.picker-arrow {
font-size: 24rpx;
color: #999999;
}
.button-group {
display: flex;
justify-content: space-between;
padding: 32rpx;
}
.cancel-button, .submit-button {
width: 48%;
height: 88rpx;
border-radius: 44rpx;
font-size: 32rpx;
display: flex;
align-items: center;
justify-content: center;
}
.cancel-button {
background-color: #f5f5f5;
color: #666666;
}
.submit-button {
background-color: #1890ff;
color: #ffffff;
}
.submit-button.loading {
background-color: #8cc4ff;
}
.loading-container {
display: flex;
align-items: center;
justify-content: center;
}
.loading-icon {
width: 36rpx;
height: 36rpx;
margin-right: 10rpx;
border: 4rpx solid #ffffff;
border-radius: 50%;
border-top-color: transparent;
animation: spin 1s linear infinite;
}

View file

@ -0,0 +1,213 @@
// otp-list/index.js
import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp';
import { checkLoginStatus } from '../../services/auth';
Page({
data: {
otpList: [],
loading: true,
refreshing: false
},
onLoad() {
this.checkLogin();
},
onShow() {
// 每次页面显示时刷新OTP列表
if (!this.data.loading) {
this.fetchOTPList();
}
},
// 下拉刷新
onPullDownRefresh() {
this.setData({ refreshing: true });
this.fetchOTPList().finally(() => {
wx.stopPullDownRefresh();
this.setData({ refreshing: false });
});
},
// 检查登录状态
checkLogin() {
checkLoginStatus().then(isLoggedIn => {
if (isLoggedIn) {
this.fetchOTPList();
} else {
wx.redirectTo({
url: '/pages/login/login'
});
}
});
},
// 获取OTP列表
fetchOTPList() {
this.setData({ loading: true });
return getOTPList()
.then(res => {
if (res.data && Array.isArray(res.data)) {
this.setData({
otpList: res.data,
loading: false
});
// 获取每个OTP的当前验证码
this.refreshOTPCodes();
}
})
.catch(err => {
wx.showToast({
title: '获取OTP列表失败',
icon: 'none'
});
this.setData({ loading: false });
});
},
// 刷新所有OTP的验证码
refreshOTPCodes() {
const { otpList } = this.data;
// 为每个OTP获取当前验证码
const promises = otpList.map(otp => {
return getOTPCode(otp.id)
.then(res => {
if (res.data && res.data.code) {
return {
id: otp.id,
code: res.data.code,
expiresIn: res.data.expires_in || 30
};
}
return null;
})
.catch(() => null);
});
Promise.all(promises).then(results => {
const updatedList = [...this.data.otpList];
results.forEach(result => {
if (result) {
const index = updatedList.findIndex(otp => otp.id === result.id);
if (index !== -1) {
updatedList[index] = {
...updatedList[index],
currentCode: result.code,
expiresIn: result.expiresIn
};
}
}
});
this.setData({ otpList: updatedList });
// 设置定时器,每秒更新倒计时
this.startCountdown();
});
},
// 开始倒计时
startCountdown() {
// 清除之前的定时器
if (this.countdownTimer) {
clearInterval(this.countdownTimer);
}
// 创建新的定时器,每秒更新一次
this.countdownTimer = setInterval(() => {
const { otpList } = this.data;
let needRefresh = false;
const updatedList = otpList.map(otp => {
if (!otp.countdown) {
otp.countdown = otp.expiresIn || 30;
}
otp.countdown -= 1;
// 如果倒计时结束,标记需要刷新
if (otp.countdown <= 0) {
needRefresh = true;
}
return otp;
});
this.setData({ otpList: updatedList });
// 如果有OTP需要刷新重新获取验证码
if (needRefresh) {
this.refreshOTPCodes();
}
}, 1000);
},
// 添加新的OTP
handleAddOTP() {
wx.navigateTo({
url: '/pages/otp-add/index'
});
},
// 编辑OTP
handleEditOTP(e) {
const { id } = e.currentTarget.dataset;
wx.navigateTo({
url: `/pages/otp-edit/index?id=${id}`
});
},
// 删除OTP
handleDeleteOTP(e) {
const { id, name } = e.currentTarget.dataset;
wx.showModal({
title: '确认删除',
content: `确定要删除 ${name} 吗?`,
confirmColor: '#ff4d4f',
success: (res) => {
if (res.confirm) {
deleteOTP(id)
.then(() => {
wx.showToast({
title: '删除成功',
icon: 'success'
});
this.fetchOTPList();
})
.catch(err => {
wx.showToast({
title: '删除失败',
icon: 'none'
});
});
}
}
});
},
// 复制验证码
handleCopyCode(e) {
const { code } = e.currentTarget.dataset;
wx.setClipboardData({
data: code,
success: () => {
wx.showToast({
title: '验证码已复制',
icon: 'success'
});
}
});
},
onUnload() {
// 页面卸载时清除定时器
if (this.countdownTimer) {
clearInterval(this.countdownTimer);
}
}
});

View file

@ -0,0 +1,3 @@
{
"usingComponents": {}
}

View file

@ -0,0 +1,59 @@
<!-- otp-list/index.wxml -->
<view class="container">
<view class="header">
<text class="title">我的OTP列表</text>
<view class="add-button" bindtap="handleAddOTP">
<text class="add-icon">+</text>
</view>
</view>
<!-- 加载中 -->
<view class="loading-container" wx:if="{{loading}}">
<view class="loading-spinner"></view>
<text class="loading-text">加载中...</text>
</view>
<!-- OTP列表 -->
<view class="otp-list" wx:else>
<block wx:if="{{otpList.length > 0}}">
<view class="otp-item" wx:for="{{otpList}}" wx:key="id">
<view class="otp-info">
<view class="otp-name-row">
<text class="otp-name">{{item.name}}</text>
<text class="otp-issuer">{{item.issuer}}</text>
</view>
<view class="otp-code-row" bindtap="handleCopyCode" data-code="{{item.currentCode}}">
<text class="otp-code">{{item.currentCode || '******'}}</text>
<text class="copy-hint">点击复制</text>
</view>
<view class="otp-countdown">
<progress
percent="{{(item.countdown / item.expiresIn) * 100}}"
stroke-width="3"
activeColor="{{item.countdown < 10 ? '#ff4d4f' : '#1890ff'}}"
backgroundColor="#e9e9e9"
/>
<text class="countdown-text">{{item.countdown || 0}}s</text>
</view>
</view>
<view class="otp-actions">
<view class="action-button edit" bindtap="handleEditOTP" data-id="{{item.id}}">
<text class="action-icon">✎</text>
</view>
<view class="action-button delete" bindtap="handleDeleteOTP" data-id="{{item.id}}" data-name="{{item.name}}">
<text class="action-icon">✕</text>
</view>
</view>
</view>
</block>
<!-- 空状态 -->
<view class="empty-state" wx:else>
<image class="empty-image" src="/assets/images/empty.png" mode="aspectFit"></image>
<text class="empty-text">暂无OTP点击右上角添加</text>
</view>
</view>
</view>

View file

@ -0,0 +1,201 @@
/* otp-list/index.wxss */
.container {
min-height: 100vh;
background-color: #f8f8f8;
padding: 0 0 40rpx 0;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 40rpx 32rpx;
background-color: #ffffff;
position: sticky;
top: 0;
z-index: 100;
box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
}
.title {
font-size: 36rpx;
font-weight: bold;
color: #333333;
}
.add-button {
width: 64rpx;
height: 64rpx;
border-radius: 32rpx;
background-color: #1890ff;
display: flex;
align-items: center;
justify-content: center;
box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3);
}
.add-icon {
color: #ffffff;
font-size: 40rpx;
line-height: 1;
}
/* 加载状态 */
.loading-container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 120rpx 0;
}
.loading-spinner {
width: 64rpx;
height: 64rpx;
border: 6rpx solid #f3f3f3;
border-top: 6rpx solid #1890ff;
border-radius: 50%;
animation: spin 1s linear infinite;
}
.loading-text {
margin-top: 20rpx;
font-size: 28rpx;
color: #999999;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
/* OTP列表 */
.otp-list {
padding: 20rpx 32rpx;
}
.otp-item {
background-color: #ffffff;
border-radius: 16rpx;
padding: 32rpx;
margin-bottom: 20rpx;
display: flex;
justify-content: space-between;
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.otp-info {
flex: 1;
margin-right: 20rpx;
}
.otp-name-row {
display: flex;
align-items: center;
margin-bottom: 16rpx;
}
.otp-name {
font-size: 32rpx;
font-weight: bold;
color: #333333;
margin-right: 16rpx;
}
.otp-issuer {
font-size: 24rpx;
color: #666666;
background-color: #f5f5f5;
padding: 4rpx 12rpx;
border-radius: 8rpx;
}
.otp-code-row {
display: flex;
align-items: center;
margin-bottom: 20rpx;
}
.otp-code {
font-size: 44rpx;
font-family: monospace;
font-weight: bold;
color: #1890ff;
letter-spacing: 4rpx;
margin-right: 16rpx;
}
.copy-hint {
font-size: 24rpx;
color: #999999;
}
.otp-countdown {
position: relative;
width: 100%;
}
.countdown-text {
position: absolute;
right: 0;
top: -30rpx;
font-size: 24rpx;
color: #999999;
}
/* OTP操作按钮 */
.otp-actions {
display: flex;
flex-direction: column;
justify-content: space-between;
}
.action-button {
width: 56rpx;
height: 56rpx;
border-radius: 28rpx;
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 16rpx;
}
.action-button.edit {
background-color: #f0f7ff;
}
.action-button.delete {
background-color: #fff1f0;
}
.action-icon {
font-size: 32rpx;
}
.edit .action-icon {
color: #1890ff;
}
.delete .action-icon {
color: #ff4d4f;
}
/* 空状态 */
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 120rpx 0;
}
.empty-image {
width: 240rpx;
height: 240rpx;
margin-bottom: 32rpx;
}
.empty-text {
font-size: 28rpx;
color: #999999;
}

View file

@ -0,0 +1,47 @@
{
"description": "项目配置文件",
"packOptions": {
"ignore": [],
"include": []
},
"miniprogramRoot": "",
"compileType": "miniprogram",
"projectname": "OTPM",
"setting": {
"useCompilerPlugins": [
"sass"
],
"babelSetting": {
"ignore": [],
"disablePlugins": [],
"outputPath": ""
},
"es6": true,
"enhance": true,
"minified": true,
"postcss": true,
"minifyWXSS": true,
"minifyWXML": true,
"uglifyFileName": true,
"packNpmManually": false,
"packNpmRelationList": [],
"ignoreUploadUnusedFiles": true,
"compileWorklet": false,
"uploadWithSourceMap": true,
"localPlugins": false,
"disableUseStrict": false,
"condition": false,
"swc": false,
"disableSWC": true
},
"simulatorType": "wechat",
"simulatorPluginLibVersion": {},
"condition": {},
"srcMiniprogramRoot": "",
"appid": "wxb6599459668b6b55",
"libVersion": "2.30.2",
"editorSetting": {
"tabIndent": "insertSpaces",
"tabSize": 2
}
}

View file

@ -0,0 +1,23 @@
{
"libVersion": "3.8.5",
"projectname": "OTPM",
"condition": {},
"setting": {
"urlCheck": true,
"coverView": false,
"lazyloadPlaceholderEnable": false,
"skylineRenderEnable": false,
"preloadBackgroundData": false,
"autoAudits": false,
"useApiHook": true,
"useApiHostProcess": true,
"showShadowRootInWxmlPanel": false,
"useStaticServer": false,
"useLanDebug": false,
"showES6CompileOption": false,
"compileHotReLoad": true,
"checkInvalidKey": true,
"ignoreDevUnusedFiles": true,
"bigPackageSizeSupport": false
}
}

View file

@ -0,0 +1,84 @@
// auth.js - 认证相关服务
import request from '../utils/request';
/**
* 微信登录
* 1. 调用wx.login获取code
* 2. 发送code到服务端换取token和openid
* 3. 保存token和openid到本地存储
*/
export const wxLogin = () => {
return new Promise((resolve, reject) => {
wx.login({
success: (res) => {
if (res.code) {
// 发送code到服务端
request({
url: '/login',
method: 'POST',
data: {
code: res.code
}
}).then(response => {
// 保存token和openid
if (response.data && response.data.token && response.data.openid) {
wx.setStorageSync('token', response.data.token);
wx.setStorageSync('openid', response.data.openid);
resolve(response.data);
} else {
reject(new Error('登录失败,服务器返回数据格式错误'));
}
}).catch(err => {
reject(err);
});
} else {
reject(new Error('登录失败获取code失败: ' + res.errMsg));
}
},
fail: (err) => {
reject(new Error('微信登录失败: ' + err.errMsg));
}
});
});
};
/**
* 检查登录状态
* 1. 检查本地是否有token和openid
* 2. 如果有验证token是否有效
* 3. 如果无效清除本地存储并返回false
*/
export const checkLoginStatus = () => {
return new Promise((resolve, reject) => {
const token = wx.getStorageSync('token');
const openid = wx.getStorageSync('openid');
if (!token || !openid) {
resolve(false);
return;
}
// 验证token有效性
request({
url: '/verify-token',
method: 'POST'
}).then(() => {
resolve(true);
}).catch(() => {
// token无效清除本地存储
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
resolve(false);
});
});
};
/**
* 退出登录
*/
export const logout = () => {
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
return Promise.resolve();
};

View file

@ -0,0 +1,119 @@
// otp.js - OTP相关服务
import request from '../utils/request';
/**
* 创建新的OTP
* @param {Object} params - 创建OTP的参数
* @param {string} params.name - OTP名称
* @param {string} params.issuer - 发行方
* @param {string} params.secret - 密钥
* @param {string} params.algorithm - 算法默认为SHA1
* @param {number} params.digits - 位数默认为6
* @param {number} params.period - 周期默认为30秒
* @returns {Promise} - 返回创建结果
*/
export const createOTP = (params) => {
if (!params || !params.secret) {
return Promise.reject(new Error('缺少必要的参数: secret'));
}
return request({
url: '/otp',
method: 'POST',
data: {
name: params.name || '',
issuer: params.issuer || '',
secret: params.secret,
algorithm: params.algorithm || 'SHA1',
digits: params.digits || 6,
period: params.period || 30
}
}).catch(err => {
console.error('创建OTP失败:', err);
throw new Error('创建OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 获取用户所有OTP列表
* @returns {Promise} - 返回OTP列表
*/
export const getOTPList = () => {
return request({
url: '/otp',
method: 'GET'
}).catch(err => {
console.error('获取OTP列表失败:', err);
throw new Error('获取OTP列表失败: ' + (err.message || '未知错误'));
});
};
/**
* 获取指定OTP的当前验证码
* @param {string} id - OTP的ID
* @returns {Promise} - 返回当前验证码
*/
export const getOTPCode = (id) => {
if (!id) {
return Promise.reject(new Error('缺少必要的参数: id'));
}
return request({
url: `/otp/${id}/code`,
method: 'GET'
}).catch(err => {
console.error('获取OTP代码失败:', err);
throw new Error('获取OTP代码失败: ' + (err.message || '未知错误'));
});
};
/**
* 验证OTP
* @param {string} id - OTP的ID
* @param {string} code - 用户输入的验证码
* @returns {Promise} - 返回验证结果
*/
export const verifyOTP = (id, code) => {
if (!id || !code) {
return Promise.reject(new Error('缺少必要的参数: id或code'));
}
return request({
url: `/otp/${id}/verify`,
method: 'POST',
data: { code }
}).catch(err => {
console.error('验证OTP失败:', err);
throw new Error('验证OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 更新OTP信息
* @param {string} id - OTP的ID
* @param {Object} params - 更新的参数
* @returns {Promise} - 返回更新结果
*/
export const updateOTP = (id, params) => {
if (!id || !params) {
return Promise.reject(new Error('缺少必要的参数: id或params'));
}
return request({
url: `/otp/${id}`,
method: 'PUT',
data: params
}).catch(err => {
console.error('更新OTP失败:', err);
throw new Error('更新OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 删除OTP
* @param {string} id - OTP的ID
* @returns {Promise} - 返回删除结果
*/
export const deleteOTP = (id) => {
return request({
url: `/otp/${id}`,
method: 'DELETE'
});
};

View file

@ -0,0 +1,58 @@
// request.js - 网络请求工具类
const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名
// 请求拦截器
const request = (options) => {
return new Promise((resolve, reject) => {
const token = wx.getStorageSync('token');
const header = {
'Content-Type': 'application/json',
...options.header
};
// 如果有token添加到请求头
if (token) {
header['Authorization'] = `Bearer ${token}`;
}
wx.request({
url: `${BASE_URL}${options.url}`,
method: options.method || 'GET',
data: options.data,
header: header,
success: (res) => {
// 处理业务错误
if (res.data.code !== 0) {
// token过期直接清除并跳转登录
if (res.statusCode === 401) {
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
reject(new Error('登录已过期,请重新登录'));
return;
}
reject(new Error(res.data.message || '请求失败'));
return;
}
resolve(res.data);
},
fail: reject
});
});
};
// 刷新token
const refreshToken = () => {
return request({
url: '/refresh-token',
method: 'POST'
}).then(res => {
if (res.data && res.data.token) {
wx.setStorageSync('token', res.data.token);
return res.data.token;
}
throw new Error('Failed to refresh token');
});
};
export default request;

66
models/otp.go Normal file
View file

@ -0,0 +1,66 @@
package models
import (
"context"
"time"
)
// OTP represents a TOTP configuration
type OTP struct {
ID int64 `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id" validate:"required"`
OpenID string `json:"openid" db:"openid" validate:"required"`
Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"`
Secret string `json:"secret" db:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"`
Period int `json:"period" db:"period" validate:"required,min=30,max=60"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// OTPParams represents common OTP parameters used in creation and update
type OTPParams struct {
Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"omitempty,issuer"`
Secret string `json:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"omitempty,min=6,max=8"`
Period int `json:"period" validate:"omitempty,min=30,max=60"`
}
// OTPRepository handles OTP data storage
type OTPRepository struct {
// Add your database connection or ORM here
}
// Create creates a new OTP record
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
// Implement database creation logic
return nil
}
// FindByID finds an OTP by ID and user ID
func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
// Implement database lookup logic
return nil, nil
}
// FindAllByUserID finds all OTPs for a user
func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
// Implement database query logic
return nil, nil
}
// Update updates an existing OTP record
func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
// Implement database update logic
return nil
}
// Delete deletes an OTP record
func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
// Implement database deletion logic
return nil
}

114
models/user.go Normal file
View file

@ -0,0 +1,114 @@
package models
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/jmoiron/sqlx"
)
// User represents a user in the system
type User struct {
ID string `db:"id" json:"id"`
OpenID string `db:"openid" json:"openid"`
SessionKey string `db:"session_key" json:"-"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// UserRepository handles user data operations
type UserRepository struct {
db *sqlx.DB
}
// NewUserRepository creates a new UserRepository
func NewUserRepository(db *sqlx.DB) *UserRepository {
return &UserRepository{db: db}
}
// FindByID finds a user by ID
func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) {
var user User
query := `SELECT * FROM users WHERE id = ?`
err := r.db.GetContext(ctx, &user, query, id)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found: %w", err)
}
return nil, fmt.Errorf("failed to find user: %w", err)
}
return &user, nil
}
// FindByOpenID finds a user by OpenID
func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) {
var user User
query := `SELECT * FROM users WHERE openid = ?`
err := r.db.GetContext(ctx, &user, query, openID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil // User not found, but not an error
}
return nil, fmt.Errorf("failed to find user: %w", err)
}
return &user, nil
}
// Create creates a new user
func (r *UserRepository) Create(ctx context.Context, user *User) error {
query := `
INSERT INTO users (id, openid, session_key, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
`
now := time.Now()
user.CreatedAt = now
user.UpdatedAt = now
_, err := r.db.ExecContext(
ctx,
query,
user.ID,
user.OpenID,
user.SessionKey,
user.CreatedAt,
user.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// Update updates an existing user
func (r *UserRepository) Update(ctx context.Context, user *User) error {
query := `
UPDATE users
SET session_key = ?, updated_at = ?
WHERE id = ?
`
user.UpdatedAt = time.Now()
_, err := r.db.ExecContext(
ctx,
query,
user.SessionKey,
user.UpdatedAt,
user.ID,
)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// Delete deletes a user
func (r *UserRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM users WHERE id = ?`
_, err := r.db.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}

332
security/security.go Normal file
View file

@ -0,0 +1,332 @@
package security
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"net/http"
"strings"
"time"
"golang.org/x/crypto/argon2"
)
// SecurityService provides security functionality
type SecurityService struct {
config *Config
}
// Config represents security configuration
type Config struct {
// CSRF protection
CSRFTokenLength int
CSRFTokenExpiry time.Duration
CSRFCookieName string
CSRFHeaderName string
CSRFCookieSecure bool
CSRFCookieHTTPOnly bool
CSRFCookieSameSite http.SameSite
// Rate limiting
RateLimitRequests int
RateLimitWindow time.Duration
// Password hashing
Argon2Time uint32
Argon2Memory uint32
Argon2Threads uint8
Argon2KeyLen uint32
Argon2SaltLen uint32
}
// DefaultConfig returns the default security configuration
func DefaultConfig() *Config {
return &Config{
// CSRF protection
CSRFTokenLength: 32,
CSRFTokenExpiry: 24 * time.Hour,
CSRFCookieName: "csrf_token",
CSRFHeaderName: "X-CSRF-Token",
CSRFCookieSecure: true,
CSRFCookieHTTPOnly: true,
CSRFCookieSameSite: http.SameSiteStrictMode,
// Rate limiting
RateLimitRequests: 100,
RateLimitWindow: time.Minute,
// Password hashing
Argon2Time: 1,
Argon2Memory: 64 * 1024,
Argon2Threads: 4,
Argon2KeyLen: 32,
Argon2SaltLen: 16,
}
}
// NewSecurityService creates a new SecurityService
func NewSecurityService(config *Config) *SecurityService {
if config == nil {
config = DefaultConfig()
}
return &SecurityService{
config: config,
}
}
// GenerateCSRFToken generates a CSRF token
func (s *SecurityService) GenerateCSRFToken() (string, error) {
// Generate random bytes
bytes := make([]byte, s.config.CSRFTokenLength)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
// Encode as base64
token := base64.StdEncoding.EncodeToString(bytes)
return token, nil
}
// SetCSRFCookie sets a CSRF cookie
func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
http.SetCookie(w, &http.Cookie{
Name: s.config.CSRFCookieName,
Value: token,
Path: "/",
Expires: time.Now().Add(s.config.CSRFTokenExpiry),
Secure: s.config.CSRFCookieSecure,
HttpOnly: s.config.CSRFCookieHTTPOnly,
SameSite: s.config.CSRFCookieSameSite,
})
}
// ValidateCSRFToken validates a CSRF token
func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
// Get token from cookie
cookie, err := r.Cookie(s.config.CSRFCookieName)
if err != nil {
return false
}
cookieToken := cookie.Value
// Get token from header
headerToken := r.Header.Get(s.config.CSRFHeaderName)
// Compare tokens
return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1
}
// CSRFMiddleware creates a middleware that validates CSRF tokens
func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip for GET, HEAD, OPTIONS, TRACE
if r.Method == http.MethodGet ||
r.Method == http.MethodHead ||
r.Method == http.MethodOptions ||
r.Method == http.MethodTrace {
next.ServeHTTP(w, r)
return
}
// Validate CSRF token
if !s.ValidateCSRFToken(r) {
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
// RateLimiter represents a rate limiter
type RateLimiter struct {
requests map[string][]time.Time
config *Config
}
// NewRateLimiter creates a new RateLimiter
func NewRateLimiter(config *Config) *RateLimiter {
return &RateLimiter{
requests: make(map[string][]time.Time),
config: config,
}
}
// Allow checks if a request is allowed
func (r *RateLimiter) Allow(key string) bool {
now := time.Now()
windowStart := now.Add(-r.config.RateLimitWindow)
// Get requests for key
requests := r.requests[key]
// Filter out old requests
var newRequests []time.Time
for _, t := range requests {
if t.After(windowStart) {
newRequests = append(newRequests, t)
}
}
// Check if rate limit is exceeded
if len(newRequests) >= r.config.RateLimitRequests {
return false
}
// Add current request
newRequests = append(newRequests, now)
r.requests[key] = newRequests
return true
}
// RateLimitMiddleware creates a middleware that limits request rate
func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get client IP
ip := getClientIP(r)
// Check if request is allowed
if !limiter.Allow(ip) {
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// getClientIP gets the client IP address
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header
xForwardedFor := r.Header.Get("X-Forwarded-For")
if xForwardedFor != "" {
// X-Forwarded-For can contain multiple IPs, use the first one
ips := strings.Split(xForwardedFor, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
xRealIP := r.Header.Get("X-Real-IP")
if xRealIP != "" {
return xRealIP
}
// Use RemoteAddr
return r.RemoteAddr
}
// HashPassword hashes a password using Argon2
func (s *SecurityService) HashPassword(password string) (string, error) {
// Generate salt
salt := make([]byte, s.config.Argon2SaltLen)
if _, err := rand.Read(salt); err != nil {
return "", fmt.Errorf("failed to generate salt: %w", err)
}
// Hash password
hash := argon2.IDKey(
[]byte(password),
salt,
s.config.Argon2Time,
s.config.Argon2Memory,
s.config.Argon2Threads,
s.config.Argon2KeyLen,
)
// Encode as base64
saltBase64 := base64.StdEncoding.EncodeToString(salt)
hashBase64 := base64.StdEncoding.EncodeToString(hash)
// Format as $argon2id$v=19$m=65536,t=1,p=4$<salt>$<hash>
return fmt.Sprintf(
"$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
s.config.Argon2Memory,
s.config.Argon2Time,
s.config.Argon2Threads,
saltBase64,
hashBase64,
), nil
}
// VerifyPassword verifies a password against a hash
func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
// Parse encoded hash
parts := strings.Split(encodedHash, "$")
if len(parts) != 6 {
return false, fmt.Errorf("invalid hash format")
}
// Extract parameters
if parts[1] != "argon2id" {
return false, fmt.Errorf("unsupported hash algorithm")
}
var memory uint32
var time uint32
var threads uint8
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
if err != nil {
return false, fmt.Errorf("failed to parse hash parameters: %w", err)
}
// Decode salt and hash
salt, err := base64.StdEncoding.DecodeString(parts[4])
if err != nil {
return false, fmt.Errorf("failed to decode salt: %w", err)
}
hash, err := base64.StdEncoding.DecodeString(parts[5])
if err != nil {
return false, fmt.Errorf("failed to decode hash: %w", err)
}
// Hash password with same parameters
newHash := argon2.IDKey(
[]byte(password),
salt,
time,
memory,
threads,
uint32(len(hash)),
)
// Compare hashes
return subtle.ConstantTimeCompare(hash, newHash) == 1, nil
}
// SecureHeadersMiddleware adds security headers to responses
func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add security headers
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Content-Security-Policy", "default-src 'self'")
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
next.ServeHTTP(w, r)
})
}
// contextKey is a type for context keys
type contextKey string
// userIDKey is the key for user ID in context
const userIDKey = contextKey("user_id")
// WithUserID adds a user ID to the context
func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, userIDKey, userID)
}
// GetUserID gets the user ID from the context
func GetUserID(ctx context.Context) (string, bool) {
userID, ok := ctx.Value(userIDKey).(string)
return userID, ok
}

190
server/server.go Normal file
View file

@ -0,0 +1,190 @@
package server
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"otpm/config"
"otpm/middleware"
"github.com/julienschmidt/httprouter"
)
// Server represents the HTTP server
type Server struct {
server *http.Server
router *httprouter.Router
config *config.Config
}
// New creates a new server
func New(cfg *config.Config) *Server {
router := httprouter.New()
server := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
Handler: router,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: 120 * time.Second,
}
return &Server{
server: server,
router: router,
config: cfg,
}
}
// Start starts the server
func (s *Server) Start() error {
// Apply global middleware in correct order with enhanced error handling
var handler http.Handler = s.router
// Logger should be first to capture all request details
handler = middleware.Logger(handler)
// CORS next to handle pre-flight requests
handler = middleware.CORS(handler)
// Then Timeout to enforce request deadlines
handler = middleware.Timeout(s.config.Server.Timeout)(handler)
// Recover should be outermost to catch any panics
handler = middleware.Recover(handler)
s.server.Handler = handler
// Log server configuration at startup
log.Printf("Server configuration:\n"+
"Address: %s\n"+
"Read Timeout: %v\n"+
"Write Timeout: %v\n"+
"Idle Timeout: %v\n"+
"Request Timeout: %v",
s.server.Addr,
s.server.ReadTimeout,
s.server.WriteTimeout,
s.server.IdleTimeout,
s.config.Server.Timeout,
)
// Start server in a goroutine
serverErr := make(chan error, 1)
go func() {
log.Printf("Server starting on %s", s.server.Addr)
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
serverErr <- fmt.Errorf("server error: %w", err)
}
}()
// Wait for interrupt signal or server error
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
select {
case err := <-serverErr:
return err
case <-quit:
return s.Shutdown()
}
}
// Shutdown gracefully stops the server
func (s *Server) Shutdown() error {
log.Println("Shutting down server...")
ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout)
defer cancel()
if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("graceful shutdown failed: %w", err)
}
log.Println("Server stopped gracefully")
return nil
}
// Router returns the router
func (s *Server) Router() *httprouter.Router {
return s.router
}
// RegisterRoutes registers all routes
func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) {
for pattern, handler := range routes {
s.router.Handle("GET", pattern, handler)
s.router.Handle("POST", pattern, handler)
s.router.Handle("PUT", pattern, handler)
s.router.Handle("DELETE", pattern, handler)
}
}
// RegisterAuthRoutes registers routes that require authentication
func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) {
for pattern, handler := range routes {
// Apply authentication middleware
authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Convert httprouter.Handle to http.HandlerFunc for middleware
wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Store params in request context
ctx := context.WithValue(r.Context(), "params", ps)
handler(w, r.WithContext(ctx), ps)
})
// Apply auth middleware
middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r)
}
s.router.Handle("GET", pattern, authHandler)
s.router.Handle("POST", pattern, authHandler)
s.router.Handle("PUT", pattern, authHandler)
s.router.Handle("DELETE", pattern, authHandler)
}
}
// RegisterHealthCheck registers an enhanced health check endpoint
func (s *Server) RegisterHealthCheck() {
s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
response := map[string]interface{}{
"status": "ok",
"timestamp": time.Now().Format(time.RFC3339),
"version": "1.0.0", // Hardcoded version instead of from config
"system": map[string]interface{}{
"goroutines": runtime.NumGoroutine(),
"memory": getMemoryUsage(),
},
}
// Add database status if configured
if s.config.Database.DSN != "" {
dbStatus := "ok"
response["database"] = dbStatus
}
middleware.SuccessResponse(w, response)
})
}
// getMemoryUsage returns current memory usage in MB
func getMemoryUsage() map[string]interface{} {
var m runtime.MemStats
runtime.ReadMemStats(&m)
return map[string]interface{}{
"alloc_mb": bToMb(m.Alloc),
"total_alloc_mb": bToMb(m.TotalAlloc),
"sys_mb": bToMb(m.Sys),
"num_gc": m.NumGC,
}
}
func bToMb(b uint64) float64 {
return float64(b) / 1024 / 1024
}

230
services/auth.go Normal file
View file

@ -0,0 +1,230 @@
package services
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"otpm/config"
"otpm/models"
)
// WeChatCode2SessionResponse represents the response from WeChat code2session API
type WeChatCode2SessionResponse struct {
OpenID string `json:"openid"`
SessionKey string `json:"session_key"`
UnionID string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
// AuthService handles authentication related operations
type AuthService struct {
config *config.Config
userRepo *models.UserRepository
httpClient *http.Client
}
// NewAuthService creates a new AuthService
func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
return &AuthService{
config: cfg,
userRepo: userRepo,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// LoginWithWeChatCode handles WeChat login
func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
start := time.Now()
// Get OpenID and SessionKey from WeChat
sessionInfo, err := s.getWeChatSession(code)
if err != nil {
log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
return "", fmt.Errorf("failed to get WeChat session: %w", err)
}
log.Printf("WeChat session obtained for code %s (took %v)",
maskCode(code), time.Since(start))
// Find or create user
user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
if err != nil {
log.Printf("User lookup failed for OpenID %s: %v",
maskOpenID(sessionInfo.OpenID), err)
return "", fmt.Errorf("failed to find user: %w", err)
}
if user == nil {
// Create new user
user = &models.User{
ID: uuid.New().String(),
OpenID: sessionInfo.OpenID,
SessionKey: sessionInfo.SessionKey,
}
if err := s.userRepo.Create(ctx, user); err != nil {
log.Printf("User creation failed for OpenID %s: %v",
maskOpenID(sessionInfo.OpenID), err)
return "", fmt.Errorf("failed to create user: %w", err)
}
log.Printf("New user created with ID %s for OpenID %s",
user.ID, maskOpenID(sessionInfo.OpenID))
} else {
// Update session key
user.SessionKey = sessionInfo.SessionKey
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("User update failed for ID %s: %v", user.ID, err)
return "", fmt.Errorf("failed to update user: %w", err)
}
log.Printf("User %s session key updated", user.ID)
}
// Generate JWT token
token, err := s.generateToken(user)
if err != nil {
log.Printf("Token generation failed for user %s: %v", user.ID, err)
return "", fmt.Errorf("failed to generate token: %w", err)
}
log.Printf("WeChat login completed for user %s (total time %v)",
user.ID, time.Since(start))
return token, nil
}
// maskCode masks sensitive parts of WeChat code for logging
func maskCode(code string) string {
if len(code) < 8 {
return "****"
}
return code[:2] + "****" + code[len(code)-2:]
}
// maskOpenID masks sensitive parts of OpenID for logging
func maskOpenID(openID string) string {
if len(openID) < 8 {
return "****"
}
return openID[:2] + "****" + openID[len(openID)-2:]
}
// getWeChatSession calls WeChat's code2session API
func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
url := fmt.Sprintf(
"https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
s.config.WeChat.AppID,
s.config.WeChat.AppSecret,
code,
)
resp, err := s.httpClient.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to call WeChat API: %w", err)
}
defer resp.Body.Close()
var result WeChatCode2SessionResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
}
if result.ErrCode != 0 {
return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
}
return &result, nil
}
// generateToken generates a JWT token for a user
func (s *AuthService) generateToken(user *models.User) (string, error) {
now := time.Now()
claims := jwt.MapClaims{
"user_id": user.ID,
"exp": now.Add(s.config.JWT.ExpireDelta).Unix(),
"iat": now.Unix(),
"iss": s.config.JWT.Issuer,
"aud": s.config.JWT.Audience,
"token_id": uuid.New().String(), // Unique token ID for tracking
}
// Use stronger signing method
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
log.Printf("Token generated for user %s (expires at %v)",
user.ID, now.Add(s.config.JWT.ExpireDelta))
return signedToken, nil
}
// ValidateToken validates a JWT token with additional checks
func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Verify signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.config.JWT.Secret), nil
})
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
switch {
case ve.Errors&jwt.ValidationErrorMalformed != 0:
return nil, fmt.Errorf("malformed token")
case ve.Errors&jwt.ValidationErrorExpired != 0:
return nil, fmt.Errorf("token expired")
case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
return nil, fmt.Errorf("token not active yet")
default:
return nil, fmt.Errorf("token validation error: %w", err)
}
}
return nil, fmt.Errorf("failed to parse token: %w", err)
}
// Additional claims validation
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// Check issuer
if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
return nil, fmt.Errorf("invalid token issuer")
}
// Check audience
if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
return nil, fmt.Errorf("invalid token audience")
}
} else {
return nil, fmt.Errorf("invalid token claims")
}
return token, nil
}
// GetUserFromToken gets user information from a JWT token
func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid token claims")
}
userID, ok := claims["user_id"].(string)
if !ok {
return nil, fmt.Errorf("user_id not found in token")
}
user, err := s.userRepo.FindByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to find user: %w", err)
}
return user, nil
}

358
services/otp.go Normal file
View file

@ -0,0 +1,358 @@
package services
import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base32"
"encoding/binary"
"fmt"
"hash"
"log"
"strings"
"time"
"otpm/models"
"github.com/google/uuid"
)
// OTPService handles OTP related operations
type OTPService struct {
otpRepo *models.OTPRepository
}
// NewOTPService creates a new OTPService
func NewOTPService(otpRepo *models.OTPRepository) *OTPService {
return &OTPService{
otpRepo: otpRepo,
}
}
// CreateOTP creates a new OTP with performance monitoring and logging
func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) {
start := time.Now()
// Validate input
if err := s.validateOTPInput(input); err != nil {
log.Printf("OTP validation failed for user %s: %v", userID, err)
return nil, err
}
// Clean and standardize secret
secret := cleanSecret(input.Secret)
// Set defaults for optional fields
algorithm := strings.ToUpper(input.Algorithm)
if algorithm == "" {
algorithm = "SHA1"
}
digits := input.Digits
if digits == 0 {
digits = 6
}
period := input.Period
if period == 0 {
period = 30
}
// Create OTP
otp := &models.OTP{
ID: uuid.New().String(),
UserID: userID,
Name: input.Name,
Issuer: input.Issuer,
Secret: secret,
Algorithm: algorithm,
Digits: digits,
Period: period,
}
if err := s.otpRepo.Create(ctx, otp); err != nil {
log.Printf("Failed to create OTP for user %s: %v", userID, err)
return nil, fmt.Errorf("failed to create OTP: %w", err)
}
// Log successful creation (without exposing secret)
log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)",
otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period)
return otp, nil
}
// GetOTP gets an OTP by ID
func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) {
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
return nil, fmt.Errorf("failed to get OTP: %w", err)
}
return otp, nil
}
// ListOTPs lists all OTPs for a user
func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) {
otps, err := s.otpRepo.FindAllByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to list OTPs: %w", err)
}
return otps, nil
}
// UpdateOTP updates an OTP
func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) {
// Get existing OTP
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
return nil, fmt.Errorf("failed to get OTP: %w", err)
}
// Update fields
if input.Name != "" {
otp.Name = input.Name
}
if input.Issuer != "" {
otp.Issuer = input.Issuer
}
if input.Algorithm != "" {
otp.Algorithm = strings.ToUpper(input.Algorithm)
}
if input.Digits > 0 {
otp.Digits = input.Digits
}
if input.Period > 0 {
otp.Period = input.Period
}
// Validate updated OTP
if err := s.validateOTPInput(models.OTPParams{
Name: otp.Name,
Issuer: otp.Issuer,
Secret: otp.Secret,
Algorithm: otp.Algorithm,
Digits: otp.Digits,
Period: otp.Period,
}); err != nil {
return nil, err
}
if err := s.otpRepo.Update(ctx, otp); err != nil {
return nil, fmt.Errorf("failed to update OTP: %w", err)
}
return otp, nil
}
// DeleteOTP deletes an OTP
func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error {
if err := s.otpRepo.Delete(ctx, id, userID); err != nil {
return fmt.Errorf("failed to delete OTP: %w", err)
}
return nil
}
// GenerateCode generates a TOTP code with enhanced logging and error handling
func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) {
start := time.Now()
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err)
return "", 0, fmt.Errorf("failed to get OTP: %w", err)
}
// Get current time step
now := time.Now().Unix()
timeStep := now / int64(otp.Period)
// Generate code
code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits)
if err != nil {
log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err)
return "", 0, fmt.Errorf("failed to generate code: %w", err)
}
// Calculate remaining seconds
remainingSeconds := otp.Period - int(now%int64(otp.Period))
// Log successful generation (without actual code)
log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)",
id, userID, time.Since(start), remainingSeconds)
return code, remainingSeconds, nil
}
// VerifyCode verifies a TOTP code with enhanced security and logging
func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) {
start := time.Now()
// Basic input validation
if len(code) == 0 {
log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID)
return false, fmt.Errorf("code is required")
}
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
log.Printf("Failed to find OTP %s for user %s during verification: %v",
id, userID, err)
return false, fmt.Errorf("failed to get OTP: %w", err)
}
// Get current and adjacent time steps
now := time.Now().Unix()
timeSteps := []int64{
(now - int64(otp.Period)) / int64(otp.Period),
now / int64(otp.Period),
(now + int64(otp.Period)) / int64(otp.Period),
}
// Check code against all time steps
for _, ts := range timeSteps {
expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits)
if err != nil {
log.Printf("Code generation failed for time step %d: %v", ts, err)
continue
}
if expectedCode == code {
// Log successful verification
log.Printf("Code verified successfully for OTP %s (user %s) in %v",
id, userID, time.Since(start))
return true, nil
}
}
// Log failed verification attempt
log.Printf("Invalid code provided for OTP %s (user %s) in %v",
id, userID, time.Since(start))
return false, nil
}
// validateOTPInput validates OTP input with detailed error messages
func (s *OTPService) validateOTPInput(input models.OTPParams) error {
if input.Name == "" {
return fmt.Errorf("name is required")
}
if len(input.Name) > 100 {
return fmt.Errorf("name is too long (maximum 100 characters)")
}
if input.Secret == "" {
return fmt.Errorf("secret is required")
}
if !isValidBase32(input.Secret) {
return fmt.Errorf("invalid secret format: must be a valid base32 string")
}
// Secret length check (after base32 decoding)
secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "="))
if len(secretBytes) < 10 {
return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)")
}
if input.Algorithm != "" {
if !isValidAlgorithm(input.Algorithm) {
return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm)
}
}
if input.Digits != 0 {
if input.Digits < 6 || input.Digits > 8 {
return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits)
}
}
if input.Period != 0 {
if input.Period < 30 || input.Period > 60 {
return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period)
}
}
return nil
}
// Helper functions
func cleanSecret(secret string) string {
// Remove spaces and convert to upper case
secret = strings.TrimSpace(strings.ToUpper(secret))
// Remove any padding characters
return strings.TrimRight(secret, "=")
}
func isValidBase32(s string) bool {
// Try to decode the secret
_, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "="))
return err == nil
}
func isValidAlgorithm(algorithm string) bool {
switch strings.ToUpper(algorithm) {
case "SHA1", "SHA256", "SHA512":
return true
default:
return false
}
}
func getHasher(algorithm string, key []byte) (hash.Hash, error) {
switch strings.ToUpper(algorithm) {
case "SHA1":
return hmac.New(sha1.New, key), nil
case "SHA256":
return hmac.New(sha256.New, key), nil
case "SHA512":
return hmac.New(sha512.New, key), nil
default:
return nil, fmt.Errorf("unsupported algorithm: %s", algorithm)
}
}
func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) {
// Decode secret
secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "="))
if err != nil {
return "", fmt.Errorf("invalid secret: %w", err)
}
// Get initialized HMAC hasher with secret
hasher, err := getHasher(algorithm, secretBytes)
if err != nil {
return "", err
}
// Convert time step to bytes
timeBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timeBytes, uint64(timeStep))
// Calculate HMAC
hasher.Write(timeBytes)
hash := hasher.Sum(nil)
// Get offset
offset := hash[len(hash)-1] & 0xf
// Generate 4-byte code
code := binary.BigEndian.Uint32(hash[offset : offset+4])
code = code & 0x7fffffff
// Get the specified number of digits
code = code % uint32(pow10(digits))
// Format code with leading zeros
return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil
}
func pow10(n int) uint32 {
result := uint32(1)
for i := 0; i < n; i++ {
result *= 10
}
return result
}

View file

@ -1,20 +1,65 @@
package utils
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
"net/http"
"strings"
"github.com/golang-jwt/jwt"
"github.com/julienschmidt/httprouter"
"github.com/spf13/viper"
)
// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
// 返回一个httprouter.Handle函数该函数接受http.ResponseWriter和*http.Request作为参数
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// 调用传入的http.Handler函数将http.ResponseWriter和*http.Request作为参数传递
h(w, r)
}
}
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
return
}
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
secret := viper.GetString("auth.secret")
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method")
}
return []byte(secret), nil
})
if err != nil || !token.Valid {
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
return
}
type contextKey string
// 将 openid 存入上下文
ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
next.ServeHTTP(w, r.WithContext(ctx))
}
}
// AesDecrypt 函数用于AES解密
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
//Base64解码
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
return origData, nil
}
// PKCS7UnPadding 函数用于去除PKCS7填充的密文
func PKCS7UnPadding(plantText []byte) []byte {
// 获取密文的长度
length := len(plantText)
// 如果密文长度大于0
if length > 0 {
// 获取最后一个字节的值,即填充的位数
unPadding := int(plantText[length-1])
// 返回去除填充后的密文
return plantText[:(length - unPadding)]
}
// 如果密文长度为0则返回原密文
return plantText
}

316
validator/validator.go Normal file
View file

@ -0,0 +1,316 @@
package validator
import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
var (
validate *validator.Validate
// 自定义验证规则
customValidations = map[string]validator.Func{
"otpsecret": validateOTPSecret,
"password": validatePassword,
"issuer": validateIssuer,
"otpauth_uri": validateOTPAuthURI,
"no_xss": validateNoXSS,
}
// 常见的弱密码列表(实际使用时应该使用更完整的列表)
commonPasswords = map[string]bool{
"password123": true,
"12345678": true,
"qwerty123": true,
"admin123": true,
"letmein": true,
"welcome": true,
"password": true,
"admin": true,
}
// 预编译的XSS检测正则表达式
xssPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)data:text/html`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
regexp.MustCompile(`(?i)<\s*iframe`),
regexp.MustCompile(`(?i)<\s*object`),
regexp.MustCompile(`(?i)<\s*embed`),
regexp.MustCompile(`(?i)<\s*style`),
regexp.MustCompile(`(?i)<\s*form`),
regexp.MustCompile(`(?i)<\s*applet`),
regexp.MustCompile(`(?i)<\s*meta`),
regexp.MustCompile(`(?i)expression\s*\(`),
regexp.MustCompile(`(?i)url\s*\(`),
}
// 预编译的正则表达式
base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`)
upperRegex = regexp.MustCompile(`[A-Z]`)
lowerRegex = regexp.MustCompile(`[a-z]`)
numberRegex = regexp.MustCompile(`[0-9]`)
specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
)
func init() {
validate = validator.New()
// 注册自定义验证规则
for tag, fn := range customValidations {
if err := validate.RegisterValidation(tag, fn); err != nil {
panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
}
}
// 使用json tag作为字段名
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
}
// ValidateRequest validates a request body against a struct
func ValidateRequest(r *http.Request, v interface{}) error {
if err := json.NewDecoder(r.Body).Decode(v); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
if err := validate.Struct(v); err != nil {
if validationErrors, ok := err.(validator.ValidationErrors); ok {
return NewValidationError(validationErrors)
}
return fmt.Errorf("validation error: %w", err)
}
return nil
}
// ValidationError represents a validation error
type ValidationError struct {
Fields map[string]string `json:"fields"`
}
// Error implements the error interface
func (e *ValidationError) Error() string {
var errors []string
for field, msg := range e.Fields {
errors = append(errors, fmt.Sprintf("%s: %s", field, msg))
}
return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; "))
}
// NewValidationError creates a new ValidationError from validator.ValidationErrors
func NewValidationError(errors validator.ValidationErrors) *ValidationError {
fields := make(map[string]string)
for _, err := range errors {
fields[err.Field()] = getErrorMessage(err)
}
return &ValidationError{Fields: fields}
}
// getErrorMessage returns a human-readable error message for a validation error
func getErrorMessage(err validator.FieldError) string {
switch err.Tag() {
case "required":
return "此字段为必填项"
case "email":
return "请输入有效的电子邮件地址"
case "min":
if err.Type().Kind() == reflect.String {
return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param())
}
return fmt.Sprintf("必须大于或等于 %s", err.Param())
case "max":
if err.Type().Kind() == reflect.String {
return fmt.Sprintf("长度不能超过 %s 个字符", err.Param())
}
return fmt.Sprintf("必须小于或等于 %s", err.Param())
case "len":
return fmt.Sprintf("长度必须为 %s 个字符", err.Param())
case "oneof":
return fmt.Sprintf("必须是以下值之一: %s", err.Param())
case "otpsecret":
return "OTP密钥格式无效必须是有效的Base32编码"
case "password":
return "密码必须至少10个字符并包含大写字母、小写字母以及数字或特殊字符"
case "issuer":
return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
case "otpauth_uri":
return "OTP认证URI格式无效"
case "no_xss":
return "输入包含潜在的不安全内容"
case "numeric":
return "必须是数字"
default:
return fmt.Sprintf("验证失败: %s", err.Tag())
}
}
// Custom validation functions
// validateOTPSecret validates an OTP secret
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
if secret == "" {
return false
}
// OTP secret should be base32 encoded
if !base32Regex.MatchString(secret) {
return false
}
// Check length (typical OTP secrets are 16-64 characters)
validLength := len(secret) >= 16 && len(secret) <= 128
return validLength
}
// validatePassword validates a password
func validatePassword(fl validator.FieldLevel) bool {
password := fl.Field().String()
// At least 10 characters long
if len(password) < 10 {
return false
}
// Check if it's a common password
if commonPasswords[strings.ToLower(password)] {
return false
}
// Check character types
hasUpper := upperRegex.MatchString(password)
hasLower := lowerRegex.MatchString(password)
hasNumber := numberRegex.MatchString(password)
hasSpecial := specialRegex.MatchString(password)
// Ensure password has enough complexity
complexity := 0
if hasUpper {
complexity++
}
if hasLower {
complexity++
}
if hasNumber {
complexity++
}
if hasSpecial {
complexity++
}
return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial)
}
// validateIssuer validates an issuer name
func validateIssuer(fl validator.FieldLevel) bool {
issuer := fl.Field().String()
if issuer == "" {
return false
}
// Issuer should not contain special characters that could cause problems in URLs
if !issuerRegex.MatchString(issuer) {
return false
}
// Check length
validLength := len(issuer) >= 1 && len(issuer) <= 100
return validLength
}
// validateOTPAuthURI validates an otpauth URI
func validateOTPAuthURI(fl validator.FieldLevel) bool {
uri := fl.Field().String()
if uri == "" {
return false
}
// Basic format check for otpauth URI
// Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
return otpauthRegex.MatchString(uri)
}
// validateNoXSS checks if a string contains potential XSS payloads
func validateNoXSS(fl validator.FieldLevel) bool {
value := fl.Field().String()
// 检查基本的HTML编码
if strings.Contains(value, "&#") ||
strings.Contains(value, "&lt;") ||
strings.Contains(value, "&gt;") {
return false
}
// 检查十六进制编码
if strings.Contains(strings.ToLower(value), "\\x3c") || // <
strings.Contains(strings.ToLower(value), "\\x3e") { // >
return false
}
// 检查Unicode编码
if strings.Contains(strings.ToLower(value), "\\u003c") || // <
strings.Contains(strings.ToLower(value), "\\u003e") { // >
return false
}
// 使用预编译的正则表达式检查XSS模式
for _, pattern := range xssPatterns {
if pattern.MatchString(value) {
return false
}
}
return true
}
// Request validation structs
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}
// CreateOTPRequest represents a request to create an OTP
type CreateOTPRequest struct {
Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"required,issuer,no_xss"`
Secret string `json:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"required,oneof=6 8"`
Period int `json:"period" validate:"required,oneof=30 60"`
}
// UpdateOTPRequest represents a request to update an OTP
type UpdateOTPRequest struct {
Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
Period int `json:"period" validate:"omitempty,oneof=30 60"`
}
// VerifyOTPRequest represents a request to verify an OTP code
type VerifyOTPRequest struct {
Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}