diff --git a/api/response.go b/api/response.go
deleted file mode 100644
index 8e24daf..0000000
--- a/api/response.go
+++ /dev/null
@@ -1,149 +0,0 @@
-package api
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
-)
-
-// Common error codes
-const (
- CodeSuccess = 0
- CodeInvalidParams = 400
- CodeUnauthorized = 401
- CodeForbidden = 403
- CodeNotFound = 404
- CodeInternalError = 500
- CodeServiceUnavail = 503
-)
-
-// Error represents an API error
-type Error struct {
- Code int `json:"code"`
- Message string `json:"message"`
-}
-
-// Error implements the error interface
-func (e *Error) Error() string {
- return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message)
-}
-
-// NewError creates a new API error
-func NewError(code int, message string) *Error {
- return &Error{
- Code: code,
- Message: message,
- }
-}
-
-// Response represents a standard API response
-type Response struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Data interface{} `json:"data,omitempty"`
-}
-
-// ResponseWriter wraps common response writing functions
-type ResponseWriter struct {
- http.ResponseWriter
-}
-
-// NewResponseWriter creates a new ResponseWriter
-func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
- return &ResponseWriter{w}
-}
-
-// WriteJSON writes a JSON response
-func (w *ResponseWriter) WriteJSON(code int, data interface{}) error {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(code)
- return json.NewEncoder(w).Encode(data)
-}
-
-// WriteSuccess writes a success response
-func (w *ResponseWriter) WriteSuccess(data interface{}) error {
- return w.WriteJSON(http.StatusOK, Response{
- Code: CodeSuccess,
- Message: "success",
- Data: data,
- })
-}
-
-// WriteError writes an error response
-func (w *ResponseWriter) WriteError(err error) error {
- var apiErr *Error
- if errors.As(err, &apiErr) {
- return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{
- Code: apiErr.Code,
- Message: apiErr.Message,
- })
- }
-
- // Handle unknown errors
- return w.WriteJSON(http.StatusInternalServerError, Response{
- Code: CodeInternalError,
- Message: "Internal Server Error",
- })
-}
-
-// WriteErrorWithCode writes an error response with a specific code
-func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error {
- return w.WriteJSON(getHTTPStatus(code), Response{
- Code: code,
- Message: message,
- })
-}
-
-// getHTTPStatus maps API error codes to HTTP status codes
-func getHTTPStatus(code int) int {
- switch code {
- case CodeSuccess:
- return http.StatusOK
- case CodeInvalidParams:
- return http.StatusBadRequest
- case CodeUnauthorized:
- return http.StatusUnauthorized
- case CodeForbidden:
- return http.StatusForbidden
- case CodeNotFound:
- return http.StatusNotFound
- case CodeServiceUnavail:
- return http.StatusServiceUnavailable
- default:
- return http.StatusInternalServerError
- }
-}
-
-// Common errors
-var (
- ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters")
- ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized")
- ErrForbidden = NewError(CodeForbidden, "Forbidden")
- ErrNotFound = NewError(CodeNotFound, "Resource not found")
- ErrInternalError = NewError(CodeInternalError, "Internal server error")
- ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable")
-)
-
-// ValidationError creates an error for invalid parameters
-func ValidationError(message string) *Error {
- return NewError(CodeInvalidParams, message)
-}
-
-// NotFoundError creates an error for not found resources
-func NotFoundError(resource string) *Error {
- return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource))
-}
-
-// ForbiddenError creates an error for forbidden actions
-func ForbiddenError(message string) *Error {
- return NewError(CodeForbidden, message)
-}
-
-// InternalError creates an error for internal server errors
-func InternalError(err error) *Error {
- if err == nil {
- return ErrInternalError
- }
- return NewError(CodeInternalError, err.Error())
-}
diff --git a/api/validator.go b/api/validator.go
deleted file mode 100644
index a18a552..0000000
--- a/api/validator.go
+++ /dev/null
@@ -1,152 +0,0 @@
-package api
-
-import (
- "regexp"
- "strings"
-
- "github.com/go-playground/validator/v10"
-)
-
-// Validate is a global validator instance
-var Validate = validator.New()
-
-// RegisterCustomValidations registers custom validation functions
-func RegisterCustomValidations() {
- // Register custom validation for issuer
- Validate.RegisterValidation("issuer", validateIssuer)
-
- // Register custom validation for XSS prevention
- Validate.RegisterValidation("no_xss", validateNoXSS)
-
- // Register custom validation for OTP secret
- Validate.RegisterValidation("otpsecret", validateOTPSecret)
-}
-
-// validateOTPSecret validates that the OTP secret is in valid base32 format
-func validateOTPSecret(fl validator.FieldLevel) bool {
- secret := fl.Field().String()
-
- // Check if the secret is not empty
- if secret == "" {
- return false
- }
-
- // Check if the secret is in base32 format (A-Z, 2-7)
- base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
- if !base32Regex.MatchString(secret) {
- return false
- }
-
- // Check if the length is valid (must be at least 16 characters)
- if len(secret) < 16 || len(secret) > 128 {
- return false
- }
-
- return true
-}
-
-// validateIssuer validates that the issuer field contains only allowed characters
-func validateIssuer(fl validator.FieldLevel) bool {
- issuer := fl.Field().String()
-
- // Empty issuer is valid (since it's optional)
- if issuer == "" {
- return true
- }
-
- // Allow alphanumeric characters, spaces, and common punctuation
- issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api
-
-import (
- "regexp"
- "strings"
-
- "github.com/go-playground/validator/v10"
-)
-
-// Validate is a global validator instance
-var Validate = validator.New()
-
-// RegisterCustomValidations registers custom validation functions
-func RegisterCustomValidations() {
- // Register custom validation for issuer
- Validate.RegisterValidation("issuer", validateIssuer)
-
- // Register custom validation for XSS prevention
- Validate.RegisterValidation("no_xss", validateNoXSS)
-
- // Register custom validation for OTP secret
- Validate.RegisterValidation("otpsecret", validateOTPSecret)
-}
-
-// validateOTPSecret validates that the OTP secret is in valid base32 format
-func validateOTPSecret(fl validator.FieldLevel) bool {
- secret := fl.Field().String()
-
- // Check if the secret is not empty
- if secret == "" {
- return false
- }
-
- // Check if the secret is in base32 format (A-Z, 2-7)
- base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
- if !base32Regex.MatchString(secret) {
- return false
- }
-
- // Check if the length is valid (must be at least 16 characters)
- if len(secret) < 16 || len(secret) > 128 {
- return false
- }
-
- return true
-}
-
-)
- if !issuerRegex.MatchString(issuer) {
- return false
- }
-
- // Check length
- if len(issuer) > 100 {
- return false
- }
-
- return true
-}
-
-// validateNoXSS validates that the field doesn't contain potential XSS payloads
-func validateNoXSS(fl validator.FieldLevel) bool {
- value := fl.Field().String()
-
- // Check for HTML encoding
- if strings.Contains(value, "") ||
- strings.Contains(value, "<") ||
- strings.Contains(value, ">") {
- return false
- }
-
- // Check for common XSS patterns
- suspiciousPatterns := []*regexp.Regexp{
- regexp.MustCompile(`(?i)`),
- regexp.MustCompile(`(?i)javascript:`),
- regexp.MustCompile(`(?i)data:text/html`),
- regexp.MustCompile(`(?i)on\w+\s*=`),
- regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
- regexp.MustCompile(`(?i)<\s*iframe`),
- regexp.MustCompile(`(?i)<\s*object`),
- regexp.MustCompile(`(?i)<\s*embed`),
- regexp.MustCompile(`(?i)<\s*style`),
- regexp.MustCompile(`(?i)<\s*form`),
- regexp.MustCompile(`(?i)<\s*applet`),
- regexp.MustCompile(`(?i)<\s*meta`),
- }
-
- for _, pattern := range suspiciousPatterns {
- if pattern.MatchString(value) {
- return false
- }
- }
-
- return true
-}
\ No newline at end of file
diff --git a/api_server.go b/api_server.go
new file mode 100644
index 0000000..26a53bb
--- /dev/null
+++ b/api_server.go
@@ -0,0 +1,202 @@
+package main
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+
+ "auth"
+ "config"
+
+ "github.com/gorilla/mux"
+)
+
+// Server represents the API server
+type Server struct {
+ config *config.Config
+ router *mux.Router
+}
+
+type WechatLoginResponse struct {
+ OpenID string `json:"openid"`
+ SessionKey string `json:"session_key"`
+ UnionID string `json:"unionid,omitempty"`
+ ErrCode int `json:"errcode,omitempty"`
+ ErrMsg string `json:"errmsg,omitempty"`
+}
+
+// NewServer creates a new instance of Server
+func NewServer(cfg *config.Config) *Server {
+ s := &Server{
+ config: cfg,
+ router: mux.NewRouter(),
+ }
+ s.setupRoutes()
+ return s
+}
+
+// setupRoutes configures all the routes for the server
+func (s *Server) setupRoutes() {
+ // 公开路由
+ s.router.HandleFunc("/auth/login", s.WechatLoginHandler)
+
+ // 受保护路由(需要JWT)
+ authRouter := s.router.PathPrefix("").Subrouter()
+ authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))
+ authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST")
+ authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST")
+
+ // 添加CORS中间件
+ s.router.Use(s.corsMiddleware)
+}
+
+// corsMiddleware handles CORS
+func (s *Server) corsMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ origin := r.Header.Get("Origin")
+
+ // Check if the origin is allowed
+ allowed := false
+ for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
+ if origin == allowedOrigin {
+ allowed = true
+ break
+ }
+ }
+
+ if allowed {
+ w.Header().Set("Access-Control-Allow-Origin", origin)
+ w.Header().Set("Access-Control-Allow-Methods",
+ joinStrings(s.config.CORS.AllowedMethods))
+ w.Header().Set("Access-Control-Allow-Headers",
+ joinStrings(s.config.CORS.AllowedHeaders))
+ }
+
+ if r.Method == "OPTIONS" {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ next.ServeHTTP(w, r)
+ })
+}
+
+// joinStrings joins string slice with commas
+func joinStrings(slice []string) string {
+ result := ""
+ for i, s := range slice {
+ if i > 0 {
+ result += ", "
+ }
+ result += s
+ }
+ return result
+}
+
+func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // 读取请求体
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, "Bad request", http.StatusBadRequest)
+ return
+ }
+
+ var req struct {
+ Code string `json:"code"`
+ }
+ if err := json.Unmarshal(body, &req); err != nil {
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+ return
+ }
+
+ // 向微信服务器请求session_key
+ url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
+ s.config.Wechat.AppID, s.config.Wechat.AppSecret, req.Code)
+
+ resp, err := http.Get(url)
+ if err != nil {
+ http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable)
+ return
+ }
+ defer resp.Body.Close()
+
+ body, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ http.Error(w, "Wechat service error", http.StatusInternalServerError)
+ return
+ }
+
+ var wechatResp WechatLoginResponse
+ if err := json.Unmarshal(body, &wechatResp); err != nil {
+ http.Error(w, "Wechat response parse error", http.StatusInternalServerError)
+ return
+ }
+
+ if wechatResp.ErrCode != 0 {
+ http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized)
+ return
+ }
+
+ // 生成JWT token
+ token, err := s.generateSessionToken(wechatResp.OpenID)
+ if err != nil {
+ http.Error(w, "Failed to generate token", http.StatusInternalServerError)
+ return
+ }
+
+ // 返回响应
+ response := map[string]interface{}{
+ "token": token,
+ "openid": wechatResp.OpenID,
+ }
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ }
+}
+
+func (s *Server) generateSessionToken(openid string) (string, error) {
+ return auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
+}
+
+// Start starts the HTTP server
+func (s *Server) Start() error {
+ addr := fmt.Sprintf(":%d", s.config.Server.Port)
+ log.Printf("Starting server on %s", addr)
+
+ srv := &http.Server{
+ Handler: s.router,
+ Addr: addr,
+ WriteTimeout: s.config.Server.Timeout,
+ ReadTimeout: s.config.Server.Timeout,
+ }
+
+ return srv.ListenAndServe()
+}
+
+func main() {
+ // 加载配置
+ cfg, err := config.LoadConfig("config")
+ if err != nil {
+ log.Fatalf("Failed to load config: %v", err)
+ }
+
+ // 初始化数据库连接
+ log.Println("Initializing database connection...")
+ if err := InitDB(cfg.Database); err != nil {
+ log.Fatalf("Failed to initialize database: %v", err)
+ }
+ log.Println("Database connection established successfully")
+
+ // 创建并启动服务器
+ server := NewServer(cfg)
+ if err := server.Start(); err != nil {
+ log.Fatalf("Server failed to start: %v", err)
+ }
+}
diff --git a/auth/middleware.go b/auth/middleware.go
new file mode 100644
index 0000000..a5f8429
--- /dev/null
+++ b/auth/middleware.go
@@ -0,0 +1,105 @@
+package auth
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+)
+
+type contextKey string
+
+const (
+ UserIDContextKey contextKey = "userID"
+)
+
+// Claims represents the JWT claims
+type Claims struct {
+ UserID string `json:"user_id"`
+ jwt.RegisteredClaims
+}
+
+// GenerateToken generates a new JWT token
+func GenerateToken(userID string, signingKey string, expiry time.Duration) (string, error) {
+ claims := Claims{
+ UserID: userID,
+ RegisteredClaims: jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
+ IssuedAt: jwt.NewNumericDate(time.Now()),
+ },
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ return token.SignedString([]byte(signingKey))
+}
+
+// NewAuthMiddleware creates a new authentication middleware
+func NewAuthMiddleware(signingKey string) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Extract token from Authorization header
+ authHeader := r.Header.Get("Authorization")
+ if authHeader == "" {
+ http.Error(w, "Authorization header required", http.StatusUnauthorized)
+ return
+ }
+
+ // Check if the header has the correct format
+ parts := strings.Split(authHeader, " ")
+ if len(parts) != 2 || parts[0] != "Bearer" {
+ http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
+ return
+ }
+
+ tokenString := parts[1]
+
+ // Parse and validate the token
+ claims := &Claims{}
+ token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
+ // Validate signing method
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+ }
+ return []byte(signingKey), nil
+ })
+
+ if err != nil {
+ if err == jwt.ErrSignatureInvalid {
+ http.Error(w, "Invalid token signature", http.StatusUnauthorized)
+ return
+ }
+ http.Error(w, "Invalid token", http.StatusUnauthorized)
+ return
+ }
+
+ if !token.Valid {
+ http.Error(w, "Invalid token", http.StatusUnauthorized)
+ return
+ }
+
+ // Add user ID to request context
+ ctx := context.WithValue(r.Context(), UserIDContextKey, claims.UserID)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+ }
+}
+
+// GetUserIDFromContext extracts the user ID from the request context
+func GetUserIDFromContext(ctx context.Context) (string, error) {
+ userID, ok := ctx.Value(UserIDContextKey).(string)
+ if !ok {
+ return "", fmt.Errorf("user ID not found in context")
+ }
+ return userID, nil
+}
+
+// RequireAuth is a middleware that ensures a valid JWT token is present
+func RequireAuth(signingKey string, next http.HandlerFunc) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ authMiddleware := NewAuthMiddleware(signingKey)
+ authMiddleware(http.HandlerFunc(next)).ServeHTTP(w, r)
+ }
+}
diff --git a/cache/cache.go b/cache/cache.go
deleted file mode 100644
index 08d3041..0000000
--- a/cache/cache.go
+++ /dev/null
@@ -1,206 +0,0 @@
-package cache
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "sync"
- "time"
-)
-
-// Item represents a cache item
-type Item struct {
- Value []byte
- Expiration int64
-}
-
-// Expired returns true if the item has expired
-func (item Item) Expired() bool {
- if item.Expiration == 0 {
- return false
- }
- return time.Now().UnixNano() > item.Expiration
-}
-
-// Cache represents an in-memory cache
-type Cache struct {
- items map[string]Item
- mu sync.RWMutex
- defaultExpiration time.Duration
- cleanupInterval time.Duration
- stopCleanup chan bool
-}
-
-// New creates a new cache with the given default expiration and cleanup interval
-func New(defaultExpiration, cleanupInterval time.Duration) *Cache {
- cache := &Cache{
- items: make(map[string]Item),
- defaultExpiration: defaultExpiration,
- cleanupInterval: cleanupInterval,
- stopCleanup: make(chan bool),
- }
-
- // Start cleanup goroutine if cleanup interval > 0
- if cleanupInterval > 0 {
- go cache.startCleanup()
- }
-
- return cache
-}
-
-// startCleanup starts the cleanup process
-func (c *Cache) startCleanup() {
- ticker := time.NewTicker(c.cleanupInterval)
- defer ticker.Stop()
-
- for {
- select {
- case <-ticker.C:
- c.DeleteExpired()
- case <-c.stopCleanup:
- return
- }
- }
-}
-
-// Set adds an item to the cache with the given key and expiration
-func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error {
- // Convert value to bytes
- var valueBytes []byte
- var err error
-
- switch v := value.(type) {
- case []byte:
- valueBytes = v
- case string:
- valueBytes = []byte(v)
- default:
- valueBytes, err = json.Marshal(value)
- if err != nil {
- return fmt.Errorf("failed to marshal value: %w", err)
- }
- }
-
- // Calculate expiration
- var exp int64
- if expiration == 0 {
- if c.defaultExpiration > 0 {
- exp = time.Now().Add(c.defaultExpiration).UnixNano()
- }
- } else if expiration > 0 {
- exp = time.Now().Add(expiration).UnixNano()
- }
-
- c.mu.Lock()
- c.items[key] = Item{
- Value: valueBytes,
- Expiration: exp,
- }
- c.mu.Unlock()
-
- return nil
-}
-
-// Get gets an item from the cache
-func (c *Cache) Get(key string, value interface{}) (bool, error) {
- c.mu.RLock()
- item, found := c.items[key]
- c.mu.RUnlock()
-
- if !found {
- return false, nil
- }
-
- // Check if item has expired
- if item.Expired() {
- c.mu.Lock()
- delete(c.items, key)
- c.mu.Unlock()
- return false, nil
- }
-
- // Unmarshal value
- switch v := value.(type) {
- case *[]byte:
- *v = item.Value
- case *string:
- *v = string(item.Value)
- default:
- if err := json.Unmarshal(item.Value, value); err != nil {
- return true, fmt.Errorf("failed to unmarshal value: %w", err)
- }
- }
-
- return true, nil
-}
-
-// Delete deletes an item from the cache
-func (c *Cache) Delete(key string) {
- c.mu.Lock()
- delete(c.items, key)
- c.mu.Unlock()
-}
-
-// DeleteExpired deletes all expired items from the cache
-func (c *Cache) DeleteExpired() {
- now := time.Now().UnixNano()
-
- c.mu.Lock()
- for k, v := range c.items {
- if v.Expiration > 0 && now > v.Expiration {
- delete(c.items, k)
- }
- }
- c.mu.Unlock()
-}
-
-// Clear deletes all items from the cache
-func (c *Cache) Clear() {
- c.mu.Lock()
- c.items = make(map[string]Item)
- c.mu.Unlock()
-}
-
-// Close stops the cleanup goroutine
-func (c *Cache) Close() {
- if c.cleanupInterval > 0 {
- c.stopCleanup <- true
- }
-}
-
-// CacheService provides caching functionality
-type CacheService struct {
- cache *Cache
-}
-
-// NewCacheService creates a new CacheService
-func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService {
- return &CacheService{
- cache: New(defaultExpiration, cleanupInterval),
- }
-}
-
-// Set adds an item to the cache
-func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
- return s.cache.Set(key, value, expiration)
-}
-
-// Get gets an item from the cache
-func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) {
- return s.cache.Get(key, value)
-}
-
-// Delete deletes an item from the cache
-func (s *CacheService) Delete(ctx context.Context, key string) {
- s.cache.Delete(key)
-}
-
-// Clear deletes all items from the cache
-func (s *CacheService) Clear(ctx context.Context) {
- s.cache.Clear()
-}
-
-// Close closes the cache
-func (s *CacheService) Close() {
- s.cache.Close()
-}
diff --git a/cmd/root.go b/cmd/root.go
deleted file mode 100644
index 8821b91..0000000
--- a/cmd/root.go
+++ /dev/null
@@ -1,144 +0,0 @@
-package cmd
-
-import (
- "context"
- "fmt"
- "log"
- "net/http"
- "os"
- "os/signal"
- "syscall"
-
- "github.com/spf13/viper"
-
- "otpm/api"
- "otpm/config"
- "otpm/database"
- "otpm/handlers"
- "otpm/models"
- "otpm/server"
- "otpm/services"
-)
-
-func init() {
- // Set config file with multi-environment support
- viper.SetConfigName("config")
- viper.SetConfigType("yaml")
- viper.AddConfigPath(".")
-
- // Set environment specific config (e.g. config.production.yaml)
- env := os.Getenv("OTPM_ENV")
- if env != "" {
- viper.SetConfigName(fmt.Sprintf("config.%s", env))
- }
-
- // Set default values
- viper.SetDefault("server.port", "8080")
- viper.SetDefault("server.timeout.read", "15s")
- viper.SetDefault("server.timeout.write", "15s")
- viper.SetDefault("server.timeout.idle", "60s")
- viper.SetDefault("database.max_open_conns", 25)
- viper.SetDefault("database.max_idle_conns", 5)
- viper.SetDefault("database.conn_max_lifetime", "5m")
-
- // Set environment variable prefix
- viper.SetEnvPrefix("OTPM")
- viper.AutomaticEnv()
-
- // Bind environment variables
- viper.BindEnv("database.url", "OTPM_DB_URL")
- viper.BindEnv("database.password", "OTPM_DB_PASSWORD")
-}
-
-// Execute is the entry point for the application
-func Execute() error {
- // Load configuration
- cfg, err := config.LoadConfig()
- if err != nil {
- return fmt.Errorf("failed to load config: %w", err)
- }
-
- // Create context with cancellation
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
-
- // Setup signal handling
- sigChan := make(chan os.Signal, 1)
- signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
- go func() {
- sig := <-sigChan
- log.Printf("Received signal: %v", sig)
- cancel()
- }()
-
- // Initialize database
- db, err := database.New(&cfg.Database)
- if err != nil {
- return fmt.Errorf("failed to initialize database: %w", err)
- }
- defer db.Close()
-
- // Run database migrations
- if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil {
- return fmt.Errorf("failed to run migrations: %w", err)
- }
-
- // Initialize repositories
- userRepo := models.NewUserRepository(db.DB)
- otpRepo := models.NewOTPRepository(db.DB)
-
- // Initialize services
- authService := services.NewAuthService(cfg, userRepo)
- otpService := services.NewOTPService(otpRepo)
-
- // Register custom validations
- api.RegisterCustomValidations()
-
- // Initialize handlers
- authHandler := handlers.NewAuthHandler(authService)
- otpHandler := handlers.NewOTPHandler(otpService)
-
- // Create and configure server
- srv := server.New(cfg)
-
- // Register health check endpoint
- srv.RegisterHealthCheck()
-
- // Register public routes with type conversion
- authRoutes := make(map[string]http.Handler)
- for path, handler := range authHandler.Routes() {
- authRoutes[path] = http.HandlerFunc(handler)
- }
- srv.RegisterRoutes(authRoutes)
-
- // Register authenticated routes with type conversion
- otpRoutes := make(map[string]http.Handler)
- for path, handler := range otpHandler.Routes() {
- otpRoutes[path] = http.HandlerFunc(handler)
- }
- srv.RegisterAuthRoutes(otpRoutes)
-
- // Start server in goroutine
- serverErr := make(chan error, 1)
- go func() {
- log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port)
- if err := srv.Start(); err != nil {
- serverErr <- fmt.Errorf("server error: %w", err)
- }
- }()
-
- // Wait for shutdown signal or server error
- select {
- case err := <-serverErr:
- return err
- case <-ctx.Done():
- // Graceful shutdown with timeout
- log.Println("Shutting down server...")
- if err := srv.Shutdown(); err != nil {
- return fmt.Errorf("server shutdown error: %w", err)
- }
- log.Println("Server stopped gracefully")
- }
-
- return nil
-}
diff --git a/config.yaml b/config.yaml
index da2013f..ead2412 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,23 +1,44 @@
+# Server Configuration
server:
port: 8080
- read_timeout: 15s
- write_timeout: 15s
- shutdown_timeout: 5s
+ timeout: 30s
+# Database Configuration
database:
- driver: sqlite3
- dsn: otpm.sqlite
- max_open_conns: 25
- max_idle_conns: 25
- max_lifetime: 5m
- skip_migration: false
+ driver: "sqlite3" # or "postgres"
+ sqlite:
+ path: "./data.db"
+ postgres:
+ host: "localhost"
+ port: 5432
+ user: "postgres"
+ password: "password"
+ dbname: "otpdb"
+ sslmode: "disable"
-jwt:
- secret: "your-jwt-secret-key-change-this-in-production"
- expire_delta: 24h
- refresh_delta: 168h
- signing_method: HS256
+# Security Configuration
+security:
+ encryption_key: "your-32-byte-encryption-key-here"
+ jwt_signing_key: "your-jwt-signing-key-here"
+ token_expiry: 24h
+ refresh_token_expiry: 168h # 7 days
+# WeChat Configuration
wechat:
- app_id: "your-wechat-app-id"
- app_secret: "your-wechat-app-secret"
\ No newline at end of file
+ app_id: "YOUR_APPID"
+ app_secret: "YOUR_APPSECRET"
+
+# CORS Configuration
+cors:
+ allowed_origins:
+ - "http://localhost:8080"
+ - "https://yourdomain.com"
+ allowed_methods:
+ - "GET"
+ - "POST"
+ - "PUT"
+ - "DELETE"
+ - "OPTIONS"
+ allowed_headers:
+ - "Authorization"
+ - "Content-Type"
\ No newline at end of file
diff --git a/config/config.go b/config/config.go
index 4cad24d..1bc9d36 100644
--- a/config/config.go
+++ b/config/config.go
@@ -7,122 +7,156 @@ import (
"github.com/spf13/viper"
)
-// Config holds all configuration for the application
+// Config holds all configuration for our application
type Config struct {
- Server ServerConfig `mapstructure:"server"`
- Database DatabaseConfig `mapstructure:"database"`
- JWT JWTConfig `mapstructure:"jwt"`
- WeChat WeChatConfig `mapstructure:"wechat"`
+ Server ServerConfig
+ Database DatabaseConfig
+ Security SecurityConfig
+ CORS CORSConfig
+ Wechat WechatConfig
}
// ServerConfig holds all server related configuration
type ServerConfig struct {
- Host string `mapstructure:"host"`
- Port int `mapstructure:"port"`
- ReadTimeout time.Duration `mapstructure:"read_timeout"`
- WriteTimeout time.Duration `mapstructure:"write_timeout"`
- ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
- Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout
+ Port int
+ Timeout time.Duration
}
// DatabaseConfig holds all database related configuration
type DatabaseConfig struct {
- Driver string `mapstructure:"driver"`
- DSN string `mapstructure:"dsn"`
- MaxOpenConns int `mapstructure:"max_open_conns"`
- MaxIdleConns int `mapstructure:"max_idle_conns"`
- MaxLifetime time.Duration `mapstructure:"max_lifetime"`
- SkipMigration bool `mapstructure:"skip_migration"`
+ Driver string
+ SQLite SQLiteConfig
+ Postgres PostgresConfig
}
-// JWTConfig holds all JWT related configuration
-type JWTConfig struct {
- Secret string `mapstructure:"secret"`
- ExpireDelta time.Duration `mapstructure:"expire_delta"`
- RefreshDelta time.Duration `mapstructure:"refresh_delta"`
- SigningMethod string `mapstructure:"signing_method"`
- Issuer string `mapstructure:"issuer"`
- Audience string `mapstructure:"audience"`
+// SQLiteConfig holds SQLite specific configuration
+type SQLiteConfig struct {
+ Path string
}
-// WeChatConfig holds all WeChat related configuration
-type WeChatConfig struct {
+// PostgresConfig holds PostgreSQL specific configuration
+type PostgresConfig struct {
+ Host string
+ Port int
+ User string
+ Password string
+ DBName string
+ SSLMode string
+}
+
+// SecurityConfig holds all security related configuration
+type SecurityConfig struct {
+ EncryptionKey string
+ JWTSigningKey string
+ TokenExpiry time.Duration
+ RefreshTokenExpiry time.Duration
+}
+
+// CORSConfig holds CORS related configuration
+type CORSConfig struct {
+ AllowedOrigins []string
+ AllowedMethods []string
+ AllowedHeaders []string
+}
+
+// WechatConfig holds WeChat related configuration
+type WechatConfig struct {
AppID string `mapstructure:"app_id"`
AppSecret string `mapstructure:"app_secret"`
}
-// LoadConfig loads the configuration from file and environment variables
-func LoadConfig() (*Config, error) {
- // Set default values
- setDefaults()
+// LoadConfig reads configuration from file or environment variables
+func LoadConfig(configPath string) (*Config, error) {
+ v := viper.New()
- // Read config file
- if err := viper.ReadInConfig(); err != nil {
+ v.SetConfigName("config")
+ v.SetConfigType("yaml")
+ v.AddConfigPath(configPath)
+ v.AddConfigPath(".")
+
+ // Read environment variables
+ v.AutomaticEnv()
+
+ // Allow environment variables to override config file
+ v.SetEnvPrefix("OTP")
+ v.BindEnv("security.encryption_key", "OTP_ENCRYPTION_KEY")
+ v.BindEnv("security.jwt_signing_key", "OTP_JWT_SIGNING_KEY")
+ v.BindEnv("database.postgres.password", "OTP_DB_PASSWORD")
+ v.BindEnv("wechat.app_id", "OTP_WECHAT_APPID")
+ v.BindEnv("wechat.app_secret", "OTP_WECHAT_SECRET")
+
+ if err := v.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
var config Config
- if err := viper.Unmarshal(&config); err != nil {
+ if err := v.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
- // Validate config
+ // Validate required configurations
if err := validateConfig(&config); err != nil {
- return nil, fmt.Errorf("invalid configuration: %w", err)
+ return nil, fmt.Errorf("config validation failed: %w", err)
}
return &config, nil
}
-// setDefaults sets default values for configuration
-func setDefaults() {
- // Server defaults
- viper.SetDefault("server.port", 8080)
- viper.SetDefault("server.read_timeout", "15s")
- viper.SetDefault("server.write_timeout", "15s")
- viper.SetDefault("server.shutdown_timeout", "5s")
- viper.SetDefault("server.timeout", "30s") // Default request processing timeout
-
- // Database defaults
- viper.SetDefault("database.driver", "sqlite3")
- viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection
- viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection
- viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling
- viper.SetDefault("database.skip_migration", false)
-
- // JWT defaults
- viper.SetDefault("jwt.expire_delta", "24h")
- viper.SetDefault("jwt.refresh_delta", "168h") // 7 days
- viper.SetDefault("jwt.signing_method", "HS256")
- viper.SetDefault("jwt.issuer", "otpm")
- viper.SetDefault("jwt.audience", "otpm-client")
-}
-
-// validateConfig validates the configuration
+// validateConfig ensures all required configuration values are provided
func validateConfig(config *Config) error {
- if config.Server.Port < 1 || config.Server.Port > 65535 {
- return fmt.Errorf("invalid port number: %d", config.Server.Port)
+ if config.Security.EncryptionKey == "" {
+ return fmt.Errorf("encryption key is required")
+ }
+
+ if config.Security.JWTSigningKey == "" {
+ return fmt.Errorf("JWT signing key is required")
}
if config.Database.Driver == "" {
return fmt.Errorf("database driver is required")
}
- if config.Database.DSN == "" {
- return fmt.Errorf("database DSN is required")
+ switch config.Database.Driver {
+ case "sqlite3":
+ if config.Database.SQLite.Path == "" {
+ return fmt.Errorf("SQLite database path is required")
+ }
+ case "postgres":
+ if config.Database.Postgres.Host == "" ||
+ config.Database.Postgres.User == "" ||
+ config.Database.Postgres.Password == "" ||
+ config.Database.Postgres.DBName == "" {
+ return fmt.Errorf("incomplete PostgreSQL configuration")
+ }
+ default:
+ return fmt.Errorf("unsupported database driver: %s", config.Database.Driver)
}
- if config.JWT.Secret == "" {
- return fmt.Errorf("JWT secret is required")
- }
-
- if config.WeChat.AppID == "" {
+ // Validate WeChat configuration
+ if config.Wechat.AppID == "" {
return fmt.Errorf("WeChat AppID is required")
}
-
- if config.WeChat.AppSecret == "" {
+ if config.Wechat.AppSecret == "" {
return fmt.Errorf("WeChat AppSecret is required")
}
return nil
}
+
+// GetDSN returns the appropriate database connection string based on the driver
+func (c *DatabaseConfig) GetDSN() string {
+ switch c.Driver {
+ case "sqlite3":
+ return c.SQLite.Path
+ case "postgres":
+ return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
+ c.Postgres.Host,
+ c.Postgres.Port,
+ c.Postgres.User,
+ c.Postgres.Password,
+ c.Postgres.DBName,
+ c.Postgres.SSLMode)
+ default:
+ return ""
+ }
+}
diff --git a/database/db.go b/database/db.go
deleted file mode 100644
index e882713..0000000
--- a/database/db.go
+++ /dev/null
@@ -1,212 +0,0 @@
-package database
-
-import (
- "context"
- "database/sql"
- "fmt"
- "log"
- "strings"
- "time"
-
- "otpm/config"
-
- "github.com/jmoiron/sqlx"
-)
-
-// DB wraps sqlx.DB to provide additional functionality
-type DB struct {
- *sqlx.DB
-}
-
-// New creates a new database connection
-func New(cfg *config.DatabaseConfig) (*DB, error) {
- db, err := sqlx.Open(cfg.Driver, cfg.DSN)
- if err != nil {
- return nil, fmt.Errorf("failed to connect to database: %w", err)
- }
-
- // Configure connection pool based on database type
- if cfg.Driver == "sqlite3" {
- // SQLite is a file-based database - simpler connection settings
- db.SetMaxOpenConns(1)
- db.SetMaxIdleConns(1)
- db.SetConnMaxLifetime(0) // Connections don't need to be recycled
- db.SetConnMaxIdleTime(0)
- } else {
- // For other databases (MySQL, PostgreSQL etc.)
- db.SetMaxOpenConns(cfg.MaxOpenConns)
- db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
- db.SetConnMaxLifetime(30 * time.Minute)
- db.SetConnMaxIdleTime(5 * time.Minute)
- }
-
- // Verify connection with timeout
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
-
- if err := db.PingContext(ctx); err != nil {
- return nil, fmt.Errorf("failed to ping database: %w", err)
- }
-
- log.Println("Successfully connected to database")
- return &DB{db}, nil
-}
-
-// WithTx executes a function within a transaction with retry logic
-func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
- var maxRetries int
- var lastErr error
-
- // Adjust retry settings based on database type
- if db.DriverName() == "sqlite3" {
- maxRetries = 5 // SQLite needs more retries due to busy timeouts
- } else {
- maxRetries = 3
- }
-
- // Default transaction options
- opts := &sql.TxOptions{
- Isolation: sql.LevelReadCommitted,
- }
-
- for attempt := 1; attempt <= maxRetries; attempt++ {
- start := time.Now()
-
- tx, err := db.BeginTxx(ctx, opts)
- if err != nil {
- if isRetryableError(err) && attempt < maxRetries {
- lastErr = err
- time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
- continue
- }
- return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
- }
-
- defer func() {
- if p := recover(); p != nil {
- tx.Rollback()
- panic(p)
- }
- }()
-
- if err := fn(tx); err != nil {
- if rbErr := tx.Rollback(); rbErr != nil {
- log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
- }
-
- if isRetryableError(err) && attempt < maxRetries {
- lastErr = err
- time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
- continue
- }
- return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
- }
-
- // Log long-running transactions
- if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
- log.Printf("Transaction completed in %v", elapsed)
- }
-
- if err := tx.Commit(); err != nil {
- if isRetryableError(err) && attempt < maxRetries {
- lastErr = err
- time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
- continue
- }
- return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
- }
-
- return nil
- }
-
- return lastErr
-}
-
-// isRetryableError checks if an error is likely to succeed on retry
-func isRetryableError(err error) bool {
- if err == nil {
- return false
- }
-
- errStr := strings.ToLower(err.Error())
- return strings.Contains(errStr, "deadlock") ||
- strings.Contains(errStr, "timeout") ||
- strings.Contains(errStr, "try again") ||
- strings.Contains(errStr, "connection reset") ||
- strings.Contains(errStr, "busy") ||
- strings.Contains(errStr, "locked")
-}
-
-// ExecContext executes a query with adaptive timeout
-func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
- // Set timeout based on query complexity
- timeout := 5 * time.Second
- if strings.Contains(strings.ToLower(query), "insert") ||
- strings.Contains(strings.ToLower(query), "update") ||
- strings.Contains(strings.ToLower(query), "delete") {
- timeout = 10 * time.Second
- }
-
- ctx, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
-
- start := time.Now()
- _, err := db.DB.ExecContext(ctx, query, args...)
- elapsed := time.Since(start)
-
- // Log slow queries
- if elapsed > timeout/2 {
- log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
- }
-
- if err != nil {
- return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
- }
-
- return nil
-}
-
-// QueryRowContext executes a query that returns a single row with timeout
-func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
- defer cancel()
-
- return db.DB.QueryRowxContext(ctx, query, args...)
-}
-
-// QueryContext executes a query that returns multiple rows with adaptive timeout
-func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
- // Set timeout based on query complexity
- timeout := 5 * time.Second
- if strings.Contains(strings.ToLower(query), "join") ||
- strings.Contains(strings.ToLower(query), "group by") ||
- strings.Contains(strings.ToLower(query), "order by") {
- timeout = 15 * time.Second
- }
-
- ctx, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
-
- start := time.Now()
- rows, err := db.DB.QueryxContext(ctx, query, args...)
- elapsed := time.Since(start)
-
- // Log slow queries
- if elapsed > timeout/2 {
- log.Printf("Slow query detected: %s (took %v)", query, elapsed)
- }
-
- if err != nil {
- return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
- }
-
- return rows, nil
-}
-
-// Close closes the database connection
-func (db *DB) Close() error {
- if err := db.DB.Close(); err != nil {
- return fmt.Errorf("failed to close database connection: %w", err)
- }
- return nil
-}
diff --git a/database/init/otp.sql b/database/init/otp.sql
deleted file mode 100644
index e0e1ba9..0000000
--- a/database/init/otp.sql
+++ /dev/null
@@ -1,26 +0,0 @@
-CREATE TABLE IF NOT EXISTS otp (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- user_id VARCHAR(255) NOT NULL,
- openid VARCHAR(255) NOT NULL,
- name VARCHAR(100) NOT NULL,
- issuer VARCHAR(255),
- secret VARCHAR(255) NOT NULL,
- algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1',
- digits INTEGER NOT NULL DEFAULT 6,
- period INTEGER NOT NULL DEFAULT 30,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- UNIQUE(user_id, name),
- UNIQUE(openid)
-);
-
--- Add index for faster lookups
-CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id);
-CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid);
-
--- Trigger to update the updated_at timestamp
-CREATE TRIGGER IF NOT EXISTS update_otp_timestamp
- AFTER UPDATE ON otp
-BEGIN
- UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
-END;
\ No newline at end of file
diff --git a/database/init/users.sql b/database/init/users.sql
deleted file mode 100644
index ae4532a..0000000
--- a/database/init/users.sql
+++ /dev/null
@@ -1,6 +0,0 @@
-CREATE TABLE IF NOT EXISTS users (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- openid VARCHAR(255) UNIQUE NOT NULL,
- session_key VARCHAR(255) UNIQUE NOT NULL
-);
-CREATE UNIQUE INDEX idx_users_openid ON users(openid);
\ No newline at end of file
diff --git a/database/migration.go b/database/migration.go
deleted file mode 100644
index df15a97..0000000
--- a/database/migration.go
+++ /dev/null
@@ -1,160 +0,0 @@
-package database
-
-import (
- "context"
- "fmt"
- "log"
- "time"
-
- _ "embed"
-
- "github.com/jmoiron/sqlx"
-)
-
-var (
- //go:embed init/users.sql
- userTable string
- //go:embed init/otp.sql
- otpTable string
-)
-
-// Migration represents a database migration
-type Migration struct {
- Name string
- SQL string
- Version int
-}
-
-// Migrations is a list of all migrations
-var Migrations = []Migration{
- {
- Name: "Create users table",
- SQL: userTable,
- Version: 1,
- },
- {
- Name: "Create OTP table",
- SQL: otpTable,
- Version: 2,
- },
-}
-
-// MigrationRecord represents a record in the migrations table
-type MigrationRecord struct {
- ID int `db:"id"`
- Version int `db:"version"`
- Name string `db:"name"`
- AppliedAt time.Time `db:"applied_at"`
-}
-
-// ensureMigrationsTable ensures that the migrations table exists
-func ensureMigrationsTable(db *sqlx.DB) error {
- query := `
- CREATE TABLE IF NOT EXISTS migrations (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- version INTEGER NOT NULL,
- name TEXT NOT NULL,
- applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
- );
- `
-
- _, err := db.Exec(query)
- if err != nil {
- return fmt.Errorf("failed to create migrations table: %w", err)
- }
-
- return nil
-}
-
-// getAppliedMigrations gets all applied migrations
-func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
- var records []MigrationRecord
- if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
- return nil, fmt.Errorf("failed to get applied migrations: %w", err)
- }
-
- result := make(map[int]MigrationRecord)
- for _, record := range records {
- result[record.Version] = record
- }
-
- return result, nil
-}
-
-// Migrate runs all pending migrations
-func Migrate(db *sqlx.DB, skipMigration bool) error {
- if skipMigration {
- log.Println("Skipping database migration as configured")
- return nil
- }
-
- // Ensure migrations table exists
- if err := ensureMigrationsTable(db); err != nil {
- return err
- }
-
- // Get applied migrations
- appliedMigrations, err := getAppliedMigrations(db)
- if err != nil {
- return err
- }
-
- // Apply pending migrations
- for _, migration := range Migrations {
- if _, ok := appliedMigrations[migration.Version]; ok {
- log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
- continue
- }
-
- log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
-
- // Start a transaction for this migration
- tx, err := db.Beginx()
- if err != nil {
- return fmt.Errorf("failed to begin transaction: %w", err)
- }
-
- // Execute migration
- if _, err := tx.Exec(migration.SQL); err != nil {
- tx.Rollback()
- return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
- }
-
- // Record migration
- if _, err := tx.Exec(
- "INSERT INTO migrations (version, name) VALUES (?, ?)",
- migration.Version, migration.Name,
- ); err != nil {
- tx.Rollback()
- return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
- }
-
- // Commit transaction
- if err := tx.Commit(); err != nil {
- return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
- }
-
- log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
- }
-
- return nil
-}
-
-// MigrateWithContext runs all pending migrations with context
-func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
- // Create a channel to signal completion
- done := make(chan error, 1)
-
- // Run migration in a goroutine
- go func() {
- done <- Migrate(db, skipMigration)
- }()
-
- // Wait for migration to complete or context to be canceled
- select {
- case err := <-done:
- return err
- case <-ctx.Done():
- return fmt.Errorf("migration canceled: %w", ctx.Err())
- }
-}
diff --git a/docs/swagger.go b/docs/swagger.go
deleted file mode 100644
index ccf05ac..0000000
--- a/docs/swagger.go
+++ /dev/null
@@ -1,663 +0,0 @@
-// Package docs provides API documentation using Swagger/OpenAPI
-package docs
-
-import (
- "encoding/json"
- "net/http"
-)
-
-// SwaggerInfo holds the API information
-var SwaggerInfo = struct {
- Title string
- Description string
- Version string
- Host string
- BasePath string
- Schemes []string
-}{
- Title: "OTPM API",
- Description: "API for One-Time Password Manager",
- Version: "1.0.0",
- Host: "localhost:8080",
- BasePath: "/",
- Schemes: []string{"http", "https"},
-}
-
-// SwaggerJSON returns the Swagger JSON
-func SwaggerJSON() []byte {
- swagger := map[string]interface{}{
- "swagger": "2.0",
- "info": map[string]interface{}{
- "title": SwaggerInfo.Title,
- "description": SwaggerInfo.Description,
- "version": SwaggerInfo.Version,
- },
- "host": SwaggerInfo.Host,
- "basePath": SwaggerInfo.BasePath,
- "schemes": SwaggerInfo.Schemes,
- "paths": getPaths(),
- "definitions": map[string]interface{}{
- "LoginRequest": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "code": map[string]interface{}{
- "type": "string",
- "description": "WeChat authorization code",
- },
- },
- "required": []string{"code"},
- },
- "LoginResponse": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "token": map[string]interface{}{
- "type": "string",
- "description": "JWT token",
- },
- "openid": map[string]interface{}{
- "type": "string",
- "description": "WeChat OpenID",
- },
- },
- },
- "CreateOTPRequest": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "name": map[string]interface{}{
- "type": "string",
- "description": "OTP name",
- },
- "issuer": map[string]interface{}{
- "type": "string",
- "description": "OTP issuer",
- },
- "secret": map[string]interface{}{
- "type": "string",
- "description": "OTP secret",
- },
- "algorithm": map[string]interface{}{
- "type": "string",
- "description": "OTP algorithm",
- "enum": []string{"SHA1", "SHA256", "SHA512"},
- },
- "digits": map[string]interface{}{
- "type": "integer",
- "description": "OTP digits",
- "enum": []int{6, 8},
- },
- "period": map[string]interface{}{
- "type": "integer",
- "description": "OTP period in seconds",
- "enum": []int{30, 60},
- },
- },
- "required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"},
- },
- "OTP": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "id": map[string]interface{}{
- "type": "string",
- "description": "OTP ID",
- },
- "user_id": map[string]interface{}{
- "type": "string",
- "description": "User ID",
- },
- "name": map[string]interface{}{
- "type": "string",
- "description": "OTP name",
- },
- "issuer": map[string]interface{}{
- "type": "string",
- "description": "OTP issuer",
- },
- "algorithm": map[string]interface{}{
- "type": "string",
- "description": "OTP algorithm",
- },
- "digits": map[string]interface{}{
- "type": "integer",
- "description": "OTP digits",
- },
- "period": map[string]interface{}{
- "type": "integer",
- "description": "OTP period in seconds",
- },
- "created_at": map[string]interface{}{
- "type": "string",
- "format": "date-time",
- "description": "Creation time",
- },
- "updated_at": map[string]interface{}{
- "type": "string",
- "format": "date-time",
- "description": "Last update time",
- },
- },
- },
- "OTPCodeResponse": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "code": map[string]interface{}{
- "type": "string",
- "description": "OTP code",
- },
- "expires_in": map[string]interface{}{
- "type": "integer",
- "description": "Seconds until expiration",
- },
- },
- },
- "VerifyOTPRequest": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "code": map[string]interface{}{
- "type": "string",
- "description": "OTP code to verify",
- },
- },
- "required": []string{"code"},
- },
- "VerifyOTPResponse": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "valid": map[string]interface{}{
- "type": "boolean",
- "description": "Whether the code is valid",
- },
- },
- },
- "UpdateOTPRequest": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "name": map[string]interface{}{
- "type": "string",
- "description": "OTP name",
- },
- "issuer": map[string]interface{}{
- "type": "string",
- "description": "OTP issuer",
- },
- "algorithm": map[string]interface{}{
- "type": "string",
- "description": "OTP algorithm",
- "enum": []string{"SHA1", "SHA256", "SHA512"},
- },
- "digits": map[string]interface{}{
- "type": "integer",
- "description": "OTP digits",
- "enum": []int{6, 8},
- },
- "period": map[string]interface{}{
- "type": "integer",
- "description": "OTP period in seconds",
- "enum": []int{30, 60},
- },
- },
- },
- "ErrorResponse": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "code": map[string]interface{}{
- "type": "integer",
- "description": "Error code",
- },
- "message": map[string]interface{}{
- "type": "string",
- "description": "Error message",
- },
- },
- },
- },
- "securityDefinitions": map[string]interface{}{
- "Bearer": map[string]interface{}{
- "type": "apiKey",
- "name": "Authorization",
- "in": "header",
- "description": "JWT token with Bearer prefix",
- },
- },
- }
-
- data, _ := json.MarshalIndent(swagger, "", " ")
- return data
-}
-
-// getPaths returns the API paths
-func getPaths() map[string]interface{} {
- return map[string]interface{}{
- "/login": map[string]interface{}{
- "post": map[string]interface{}{
- "summary": "Login with WeChat",
- "description": "Login with WeChat authorization code",
- "tags": []string{"auth"},
- "parameters": []map[string]interface{}{
- {
- "name": "body",
- "in": "body",
- "description": "Login request",
- "required": true,
- "schema": map[string]interface{}{
- "$ref": "#/definitions/LoginRequest",
- },
- },
- },
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "Successful login",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/LoginResponse",
- },
- },
- "400": map[string]interface{}{
- "description": "Invalid request",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "500": map[string]interface{}{
- "description": "Internal server error",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- },
- },
- },
- "/verify-token": map[string]interface{}{
- "post": map[string]interface{}{
- "summary": "Verify token",
- "description": "Verify JWT token",
- "tags": []string{"auth"},
- "security": []map[string]interface{}{
- {
- "Bearer": []string{},
- },
- },
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "Token is valid",
- "schema": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "valid": map[string]interface{}{
- "type": "boolean",
- "description": "Whether the token is valid",
- },
- },
- },
- },
- "401": map[string]interface{}{
- "description": "Unauthorized",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- },
- },
- },
- "/otp": map[string]interface{}{
- "get": map[string]interface{}{
- "summary": "List OTPs",
- "description": "List all OTPs for the authenticated user",
- "tags": []string{"otp"},
- "security": []map[string]interface{}{
- {
- "Bearer": []string{},
- },
- },
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "List of OTPs",
- "schema": map[string]interface{}{
- "type": "array",
- "items": map[string]interface{}{
- "$ref": "#/definitions/OTP",
- },
- },
- },
- "401": map[string]interface{}{
- "description": "Unauthorized",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "500": map[string]interface{}{
- "description": "Internal server error",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- },
- },
- "post": map[string]interface{}{
- "summary": "Create OTP",
- "description": "Create a new OTP",
- "tags": []string{"otp"},
- "security": []map[string]interface{}{
- {
- "Bearer": []string{},
- },
- },
- "parameters": []map[string]interface{}{
- {
- "name": "body",
- "in": "body",
- "description": "OTP creation request",
- "required": true,
- "schema": map[string]interface{}{
- "$ref": "#/definitions/CreateOTPRequest",
- },
- },
- },
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "OTP created",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/OTP",
- },
- },
- "400": map[string]interface{}{
- "description": "Invalid request",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "401": map[string]interface{}{
- "description": "Unauthorized",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "500": map[string]interface{}{
- "description": "Internal server error",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- },
- },
- },
- "/otp/{id}": map[string]interface{}{
- "put": map[string]interface{}{
- "summary": "Update OTP",
- "description": "Update an existing OTP",
- "tags": []string{"otp"},
- "security": []map[string]interface{}{
- {
- "Bearer": []string{},
- },
- },
- "parameters": []map[string]interface{}{
- {
- "name": "id",
- "in": "path",
- "description": "OTP ID",
- "required": true,
- "type": "string",
- },
- {
- "name": "body",
- "in": "body",
- "description": "OTP update request",
- "required": true,
- "schema": map[string]interface{}{
- "$ref": "#/definitions/UpdateOTPRequest",
- },
- },
- },
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "OTP updated",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/OTP",
- },
- },
- "400": map[string]interface{}{
- "description": "Invalid request",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "401": map[string]interface{}{
- "description": "Unauthorized",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "404": map[string]interface{}{
- "description": "OTP not found",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "500": map[string]interface{}{
- "description": "Internal server error",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- },
- },
- "delete": map[string]interface{}{
- "summary": "Delete OTP",
- "description": "Delete an existing OTP",
- "tags": []string{"otp"},
- "security": []map[string]interface{}{
- {
- "Bearer": []string{},
- },
- },
- "parameters": []map[string]interface{}{
- {
- "name": "id",
- "in": "path",
- "description": "OTP ID",
- "required": true,
- "type": "string",
- },
- },
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "OTP deleted",
- "schema": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "message": map[string]interface{}{
- "type": "string",
- "description": "Success message",
- },
- },
- },
- },
- "401": map[string]interface{}{
- "description": "Unauthorized",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "404": map[string]interface{}{
- "description": "OTP not found",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "500": map[string]interface{}{
- "description": "Internal server error",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- },
- },
- },
- "/otp/{id}/code": map[string]interface{}{
- "get": map[string]interface{}{
- "summary": "Get OTP code",
- "description": "Get the current OTP code",
- "tags": []string{"otp"},
- "security": []map[string]interface{}{
- {
- "Bearer": []string{},
- },
- },
- "parameters": []map[string]interface{}{
- {
- "name": "id",
- "in": "path",
- "description": "OTP ID",
- "required": true,
- "type": "string",
- },
- },
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "OTP code",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/OTPCodeResponse",
- },
- },
- "401": map[string]interface{}{
- "description": "Unauthorized",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "404": map[string]interface{}{
- "description": "OTP not found",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "500": map[string]interface{}{
- "description": "Internal server error",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- },
- },
- },
- "/otp/{id}/verify": map[string]interface{}{
- "post": map[string]interface{}{
- "summary": "Verify OTP code",
- "description": "Verify an OTP code",
- "tags": []string{"otp"},
- "security": []map[string]interface{}{
- {
- "Bearer": []string{},
- },
- },
- "parameters": []map[string]interface{}{
- {
- "name": "id",
- "in": "path",
- "description": "OTP ID",
- "required": true,
- "type": "string",
- },
- {
- "name": "body",
- "in": "body",
- "description": "OTP verification request",
- "required": true,
- "schema": map[string]interface{}{
- "$ref": "#/definitions/VerifyOTPRequest",
- },
- },
- },
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "OTP verification result",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/VerifyOTPResponse",
- },
- },
- "400": map[string]interface{}{
- "description": "Invalid request",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "401": map[string]interface{}{
- "description": "Unauthorized",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "404": map[string]interface{}{
- "description": "OTP not found",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- "500": map[string]interface{}{
- "description": "Internal server error",
- "schema": map[string]interface{}{
- "$ref": "#/definitions/ErrorResponse",
- },
- },
- },
- },
- },
- "/health": map[string]interface{}{
- "get": map[string]interface{}{
- "summary": "Health check",
- "description": "Check if the API is healthy",
- "tags": []string{"system"},
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "API is healthy",
- "schema": map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "status": map[string]interface{}{
- "type": "string",
- "description": "Health status",
- },
- "time": map[string]interface{}{
- "type": "string",
- "format": "date-time",
- "description": "Current time",
- },
- },
- },
- },
- },
- },
- },
- "/metrics": map[string]interface{}{
- "get": map[string]interface{}{
- "summary": "Metrics",
- "description": "Get application metrics",
- "tags": []string{"system"},
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "Application metrics",
- },
- },
- },
- },
- "/swagger.json": map[string]interface{}{
- "get": map[string]interface{}{
- "summary": "Swagger JSON",
- "description": "Get Swagger JSON",
- "tags": []string{"system"},
- "responses": map[string]interface{}{
- "200": map[string]interface{}{
- "description": "Swagger JSON",
- },
- },
- },
- },
- }
-}
-
-// Handler returns an HTTP handler for Swagger JSON
-func Handler() http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- w.Write(SwaggerJSON())
- }
-}
diff --git a/go.mod b/go.mod
deleted file mode 100644
index a55608f..0000000
--- a/go.mod
+++ /dev/null
@@ -1,50 +0,0 @@
-module otpm
-
-go 1.23.0
-
-toolchain go1.23.9
-
-require (
- github.com/go-playground/validator/v10 v10.26.0
- github.com/golang-jwt/jwt v3.2.2+incompatible
- github.com/google/uuid v1.6.0
- github.com/jmoiron/sqlx v1.4.0
- github.com/julienschmidt/httprouter v1.3.0
- github.com/prometheus/client_golang v1.22.0
- github.com/spf13/viper v1.19.0
- golang.org/x/crypto v0.38.0
-)
-
-require (
- github.com/beorn7/perks v1.0.1 // indirect
- github.com/cespare/xxhash/v2 v2.3.0 // indirect
- github.com/fsnotify/fsnotify v1.7.0 // indirect
- github.com/gabriel-vasile/mimetype v1.4.8 // indirect
- github.com/go-playground/locales v0.14.1 // indirect
- github.com/go-playground/universal-translator v0.18.1 // indirect
- github.com/hashicorp/hcl v1.0.0 // indirect
- github.com/leodido/go-urn v1.4.0 // indirect
- github.com/magiconair/properties v1.8.7 // indirect
- github.com/mitchellh/mapstructure v1.5.0 // indirect
- github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
- github.com/pelletier/go-toml/v2 v2.2.2 // indirect
- github.com/prometheus/client_model v0.6.1 // indirect
- github.com/prometheus/common v0.62.0 // indirect
- github.com/prometheus/procfs v0.15.1 // indirect
- github.com/sagikazarmark/locafero v0.4.0 // indirect
- github.com/sagikazarmark/slog-shim v0.1.0 // indirect
- github.com/sourcegraph/conc v0.3.0 // indirect
- github.com/spf13/afero v1.11.0 // indirect
- github.com/spf13/cast v1.6.0 // indirect
- github.com/spf13/pflag v1.0.5 // indirect
- github.com/subosito/gotenv v1.6.0 // indirect
- go.uber.org/atomic v1.9.0 // indirect
- go.uber.org/multierr v1.9.0 // indirect
- golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
- golang.org/x/net v0.34.0 // indirect
- golang.org/x/sys v0.33.0 // indirect
- golang.org/x/text v0.25.0 // indirect
- google.golang.org/protobuf v1.36.5 // indirect
- gopkg.in/ini.v1 v1.67.0 // indirect
- gopkg.in/yaml.v3 v3.0.1 // indirect
-)
diff --git a/go.sum b/go.sum
deleted file mode 100644
index c30de80..0000000
--- a/go.sum
+++ /dev/null
@@ -1,124 +0,0 @@
-filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
-filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
-github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
-github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
-github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
-github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
-github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
-github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
-github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
-github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
-github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
-github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
-github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
-github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
-github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
-github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
-github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
-github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
-github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
-github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
-github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
-github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
-github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
-github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
-github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
-github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
-github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
-github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
-github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
-github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
-github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
-github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
-github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
-github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
-github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
-github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
-github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
-github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
-github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
-github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
-github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
-github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
-github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
-github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
-github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
-github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
-github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
-github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
-github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
-github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
-github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
-github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
-github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
-github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
-github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
-github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
-github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
-github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
-github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
-github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
-github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
-github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
-github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
-github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
-github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
-github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
-github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
-github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
-github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
-github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
-github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
-github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
-github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
-github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
-github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
-github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
-github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
-github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
-github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
-go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
-go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
-go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
-go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
-golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
-golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
-golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
-golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
-golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
-golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
-golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
-golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
-golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
-golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
-google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
-google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
-gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
-gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/handlers/auth_handler.go b/handlers/auth_handler.go
deleted file mode 100644
index f79979f..0000000
--- a/handlers/auth_handler.go
+++ /dev/null
@@ -1,156 +0,0 @@
-package handlers
-
-import (
- "encoding/json"
- "fmt"
- "log"
- "net/http"
- "time"
-
- "otpm/api"
- "otpm/services"
-
- "github.com/golang-jwt/jwt"
- "github.com/julienschmidt/httprouter"
-)
-
-// AuthHandler handles authentication related requests
-type AuthHandler struct {
- authService *services.AuthService
-}
-
-// NewAuthHandler creates a new AuthHandler
-func NewAuthHandler(authService *services.AuthService) *AuthHandler {
- return &AuthHandler{
- authService: authService,
- }
-}
-
-// LoginRequest represents a login request
-type LoginRequest struct {
- Code string `json:"code" validate:"required,min=32,max=128"`
-}
-
-// LoginResponse represents a login response
-type LoginResponse struct {
- Token string `json:"token"`
- OpenID string `json:"openid"`
-}
-
-// TokenRequest represents a token verification request
-type TokenRequest struct {
- Token string `validate:"required,min=32"`
-}
-
-// Login handles WeChat login
-func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
- start := time.Now()
-
- // Limit request body size to prevent DOS
- r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
-
- // Parse and validate request
- var req LoginRequest
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- fmt.Sprintf("Invalid request body: %v", err))
- log.Printf("Login request parse error: %v", err)
- return
- }
-
- // Validate using validator
- if err := api.Validate.Struct(req); err != nil {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- fmt.Sprintf("Invalid request parameters: %v", err))
- log.Printf("Login request validation failed: %v", err)
- return
- }
-
- // Login with WeChat code
- token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.InternalError(err))
- log.Printf("Login failed for code %s: %v", req.Code, err)
- return
- }
-
- // Log successful login
- log.Printf("Login successful for code %s (took %v)",
- req.Code, time.Since(start))
-
- // Return token
- api.NewResponseWriter(w).WriteSuccess(LoginResponse{
- Token: token,
- })
-}
-
-// VerifyToken handles token verification
-func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
- start := time.Now()
-
- // Get token from Authorization header
- authHeader := r.Header.Get("Authorization")
- if authHeader == "" {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- "Authorization header is required")
- log.Printf("Token verification failed: missing Authorization header")
- return
- }
-
- // Validate token format
- if len(authHeader) < 7 || authHeader[:7] != "Bearer " {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- "Invalid token format. Expected 'Bearer '")
- log.Printf("Token verification failed: invalid token format")
- return
- }
-
- token := authHeader[7:]
-
- // Validate token using validator
- tokenReq := TokenRequest{Token: token}
- if err := api.Validate.Struct(tokenReq); err != nil {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- "Invalid token format")
- log.Printf("Token verification failed: %v", err)
- return
- }
-
- // Validate token
- claims, err := h.authService.ValidateToken(token)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
- log.Printf("Token verification failed for token %s: %v",
- maskToken(token), err) // Mask token in logs
- return
- }
-
- // Log successful verification
- userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string)
- if !ok {
- log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
- } else {
- log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
- }
-
- // Token is valid
- api.NewResponseWriter(w).WriteSuccess(map[string]bool{
- "valid": true,
- })
-}
-
-// maskToken masks sensitive parts of token for logging
-func maskToken(token string) string {
- if len(token) < 8 {
- return "****"
- }
- return token[:4] + "****" + token[len(token)-4:]
-}
-
-// Routes returns all routes for the auth handler
-func (h *AuthHandler) Routes() map[string]httprouter.Handle {
- return map[string]httprouter.Handle{
- "/api/login": h.Login,
- "/api/verify-token": h.VerifyToken,
- }
-}
diff --git a/handlers/otp_handler.go b/handlers/otp_handler.go
deleted file mode 100644
index c1d99bb..0000000
--- a/handlers/otp_handler.go
+++ /dev/null
@@ -1,114 +0,0 @@
-package handlers
-
-import (
- "encoding/json"
- "net/http"
-
- "github.com/julienschmidt/httprouter"
-
- "otpm/api"
- "otpm/middleware"
- "otpm/models"
- "otpm/services"
-)
-
-// OTPHandler handles OTP-related HTTP requests
-type OTPHandler struct {
- otpService *services.OTPService
-}
-
-// NewOTPHandler creates a new OTPHandler
-func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
- return &OTPHandler{
- otpService: otpService,
- }
-}
-
-// Routes returns the routes for OTP operations
-func (h *OTPHandler) Routes() map[string]httprouter.Handle {
- return map[string]httprouter.Handle{
- "POST /api/otp": h.CreateOTP,
- "GET /api/otps": h.ListOTPs,
- "GET /api/otp/:id": h.GetOTP,
- }
-}
-
-// CreateOTP handles the creation of a new OTP
-func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
- // Get user ID from context
- userID, ok := r.Context().Value(middleware.UserIDKey).(string)
- if !ok {
- api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
- return
- }
-
- // Parse request body
- var params models.OTPParams
- if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
- api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body"))
- return
- }
-
- // Validate request
- if err := api.Validate.Struct(params); err != nil {
- api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
- return
- }
-
- // Create OTP
- otp, err := h.otpService.CreateOTP(r.Context(), userID, params)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.InternalError(err))
- return
- }
-
- // Return response
- api.NewResponseWriter(w).WriteSuccess(otp)
-}
-
-// ListOTPs handles listing all OTPs for a user
-func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
- // Get user ID from context
- userID, ok := r.Context().Value(middleware.UserIDKey).(string)
- if !ok {
- api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
- return
- }
-
- // Get OTPs
- otps, err := h.otpService.ListOTPs(r.Context(), userID)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.InternalError(err))
- return
- }
-
- // Return response
- api.NewResponseWriter(w).WriteSuccess(otps)
-}
-
-// GetOTP handles getting a specific OTP
-func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
- // Get user ID from context
- userID, ok := r.Context().Value(middleware.UserIDKey).(string)
- if !ok {
- api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
- return
- }
-
- // Get OTP ID from URL
- otpID := ps.ByName("id")
- if otpID == "" {
- api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID"))
- return
- }
-
- // Get OTP
- otp, err := h.otpService.GetOTP(r.Context(), otpID, userID)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.InternalError(err))
- return
- }
-
- // Return response
- api.NewResponseWriter(w).WriteSuccess(otp)
-}
diff --git a/init/postgresql/init.sql b/init/postgresql/init.sql
new file mode 100644
index 0000000..f734604
--- /dev/null
+++ b/init/postgresql/init.sql
@@ -0,0 +1,51 @@
+-- 创建tokens表
+CREATE TABLE IF NOT EXISTS tokens (
+ id VARCHAR(255) NOT NULL, -- token的唯一标识符
+ user_id VARCHAR(255) NOT NULL, -- 用户ID
+ issuer VARCHAR(255) NOT NULL, -- 令牌发行者
+ account VARCHAR(255) NOT NULL, -- 账户名称
+ secret TEXT NOT NULL, -- 密钥
+ type VARCHAR(10) NOT NULL, -- 令牌类型(totp/hotp)
+ counter INTEGER, -- HOTP计数器(可选)
+ period INTEGER NOT NULL, -- TOTP周期(秒)
+ digits INTEGER NOT NULL, -- 验证码位数
+ algo VARCHAR(10) NOT NULL, -- 使用的哈希算法
+ timestamp BIGINT NOT NULL, -- 最后更新时间戳
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+ PRIMARY KEY (id, user_id)
+);
+
+-- 创建更新时间戳的触发器
+CREATE OR REPLACE FUNCTION update_updated_at_column()
+RETURNS TRIGGER AS $$
+BEGIN
+ NEW.updated_at = CURRENT_TIMESTAMP;
+ RETURN NEW;
+END;
+$$ language 'plpgsql';
+
+CREATE TRIGGER update_tokens_updated_at
+ BEFORE UPDATE ON tokens
+ FOR EACH ROW
+ EXECUTE FUNCTION update_updated_at_column();
+
+-- 创建索引
+CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);
+CREATE INDEX IF NOT EXISTS idx_tokens_timestamp ON tokens(timestamp);
+
+-- 添加注释
+COMMENT ON TABLE tokens IS 'OTP令牌数据表';
+COMMENT ON COLUMN tokens.id IS '令牌的唯一标识符';
+COMMENT ON COLUMN tokens.user_id IS '用户ID';
+COMMENT ON COLUMN tokens.issuer IS '令牌发行者';
+COMMENT ON COLUMN tokens.account IS '账户名称';
+COMMENT ON COLUMN tokens.secret IS '密钥';
+COMMENT ON COLUMN tokens.type IS '令牌类型(totp/hotp)';
+COMMENT ON COLUMN tokens.counter IS 'HOTP计数器(可选)';
+COMMENT ON COLUMN tokens.period IS 'TOTP周期(秒)';
+COMMENT ON COLUMN tokens.digits IS '验证码位数';
+COMMENT ON COLUMN tokens.algo IS '使用的哈希算法';
+COMMENT ON COLUMN tokens.timestamp IS '最后更新时间戳';
+COMMENT ON COLUMN tokens.created_at IS '创建时间';
+COMMENT ON COLUMN tokens.updated_at IS '最后更新时间';
\ No newline at end of file
diff --git a/init/sqlite3/init.sql b/init/sqlite3/init.sql
new file mode 100644
index 0000000..6f9f3fa
--- /dev/null
+++ b/init/sqlite3/init.sql
@@ -0,0 +1,50 @@
+-- SQLite3 initialization SQL
+
+-- Enable WAL mode for better concurrency (simple performance boost)
+PRAGMA journal_mode = WAL;
+PRAGMA synchronous = NORMAL;
+
+-- Enable foreign key support
+PRAGMA foreign_keys = ON;
+
+-- 创建tokens表
+CREATE TABLE IF NOT EXISTS tokens (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id TEXT NOT NULL,
+ issuer TEXT NOT NULL,
+ account TEXT NOT NULL,
+ secret TEXT NOT NULL CHECK (length(secret) >= 16 AND secret REGEXP '^[A-Z2-7]+=*$'),
+ type TEXT NOT NULL CHECK (type IN ('HOTP', 'TOTP')),
+ counter INTEGER CHECK (
+ (type = 'HOTP' AND counter >= 0) OR
+ (type = 'TOTP' AND counter IS NULL)
+ ),
+ period INTEGER DEFAULT 30 CHECK (
+ (type = 'TOTP' AND period >= 30) OR
+ (type = 'HOTP' AND period IS NULL)
+ ),
+ digits INTEGER NOT NULL DEFAULT 6 CHECK (digits IN (6, 8)),
+ algo TEXT NOT NULL DEFAULT 'SHA1' CHECK (algo IN ('SHA1', 'SHA256', 'SHA512')),
+ created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
+ updated_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
+ UNIQUE(user_id, issuer, account)
+);
+
+-- 基本索引
+CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);
+CREATE INDEX IF NOT EXISTS idx_tokens_lookup ON tokens(user_id, issuer, account);
+CREATE INDEX IF NOT EXISTS idx_tokens_hotp ON tokens(user_id) WHERE type = 'HOTP';
+CREATE INDEX IF NOT EXISTS idx_tokens_totp ON tokens(user_id) WHERE type = 'TOTP';
+
+-- 简化统计视图
+CREATE VIEW IF NOT EXISTS v_token_stats AS
+SELECT
+ user_id,
+ COUNT(*) as total_tokens,
+ SUM(type = 'HOTP') as hotp_count,
+ SUM(type = 'TOTP') as totp_count
+FROM tokens
+GROUP BY user_id;
+
+-- 设置版本号
+PRAGMA user_version = 1;
\ No newline at end of file
diff --git a/logger/logger.go b/logger/logger.go
deleted file mode 100644
index d7d76e4..0000000
--- a/logger/logger.go
+++ /dev/null
@@ -1,204 +0,0 @@
-package logger
-
-import (
- "context"
- "fmt"
- "io"
- "os"
- "runtime"
- "strings"
- "time"
-
- "github.com/google/uuid"
-)
-
-// Level represents a log level
-type Level int
-
-const (
- // DEBUG level
- DEBUG Level = iota
- // INFO level
- INFO
- // WARN level
- WARN
- // ERROR level
- ERROR
- // FATAL level
- FATAL
-)
-
-// String returns the string representation of the log level
-func (l Level) String() string {
- switch l {
- case DEBUG:
- return "DEBUG"
- case INFO:
- return "INFO"
- case WARN:
- return "WARN"
- case ERROR:
- return "ERROR"
- case FATAL:
- return "FATAL"
- default:
- return "UNKNOWN"
- }
-}
-
-// Logger represents a logger
-type Logger struct {
- level Level
- output io.Writer
-}
-
-// contextKey is a type for context keys
-type contextKey string
-
-// requestIDKey is the key for request ID in context
-const requestIDKey = contextKey("request_id")
-
-// New creates a new logger
-func New(level Level, output io.Writer) *Logger {
- if output == nil {
- output = os.Stdout
- }
- return &Logger{
- level: level,
- output: output,
- }
-}
-
-// WithLevel creates a new logger with the specified level
-func (l *Logger) WithLevel(level Level) *Logger {
- return &Logger{
- level: level,
- output: l.output,
- }
-}
-
-// WithOutput creates a new logger with the specified output
-func (l *Logger) WithOutput(output io.Writer) *Logger {
- return &Logger{
- level: l.level,
- output: output,
- }
-}
-
-// log logs a message with the specified level
-func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) {
- if level < l.level {
- return
- }
-
- // Get request ID from context
- requestID := getRequestID(ctx)
-
- // Get caller information
- _, file, line, ok := runtime.Caller(2)
- if !ok {
- file = "unknown"
- line = 0
- }
- // Extract just the filename
- if idx := strings.LastIndex(file, "/"); idx >= 0 {
- file = file[idx+1:]
- }
-
- // Format message
- message := fmt.Sprintf(format, args...)
-
- // Format log entry
- timestamp := time.Now().Format(time.RFC3339)
- logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n",
- timestamp, level.String(), file, line, requestID, message)
-
- // Write log entry
- _, _ = l.output.Write([]byte(logEntry))
-
- // Exit if fatal
- if level == FATAL {
- os.Exit(1)
- }
-}
-
-// Debug logs a debug message
-func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) {
- l.log(ctx, DEBUG, format, args...)
-}
-
-// Info logs an info message
-func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) {
- l.log(ctx, INFO, format, args...)
-}
-
-// Warn logs a warning message
-func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) {
- l.log(ctx, WARN, format, args...)
-}
-
-// Error logs an error message
-func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) {
- l.log(ctx, ERROR, format, args...)
-}
-
-// Fatal logs a fatal message and exits
-func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) {
- l.log(ctx, FATAL, format, args...)
-}
-
-// WithRequestID adds a request ID to the context
-func WithRequestID(ctx context.Context) context.Context {
- requestID := uuid.New().String()
- return context.WithValue(ctx, requestIDKey, requestID)
-}
-
-// GetRequestID gets the request ID from the context
-func GetRequestID(ctx context.Context) string {
- return getRequestID(ctx)
-}
-
-// getRequestID gets the request ID from the context
-func getRequestID(ctx context.Context) string {
- if ctx == nil {
- return "-"
- }
- requestID, ok := ctx.Value(requestIDKey).(string)
- if !ok {
- return "-"
- }
- return requestID
-}
-
-// Default logger
-var defaultLogger = New(INFO, os.Stdout)
-
-// SetDefaultLogger sets the default logger
-func SetDefaultLogger(logger *Logger) {
- defaultLogger = logger
-}
-
-// Debug logs a debug message using the default logger
-func Debug(ctx context.Context, format string, args ...interface{}) {
- defaultLogger.Debug(ctx, format, args...)
-}
-
-// Info logs an info message using the default logger
-func Info(ctx context.Context, format string, args ...interface{}) {
- defaultLogger.Info(ctx, format, args...)
-}
-
-// Warn logs a warning message using the default logger
-func Warn(ctx context.Context, format string, args ...interface{}) {
- defaultLogger.Warn(ctx, format, args...)
-}
-
-// Error logs an error message using the default logger
-func Error(ctx context.Context, format string, args ...interface{}) {
- defaultLogger.Error(ctx, format, args...)
-}
-
-// Fatal logs a fatal message and exits using the default logger
-func Fatal(ctx context.Context, format string, args ...interface{}) {
- defaultLogger.Fatal(ctx, format, args...)
-}
diff --git a/main.go b/main.go
deleted file mode 100644
index 1b34398..0000000
--- a/main.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package main
-
-import "otpm/cmd"
-
-func main() {
- cmd.Execute()
-}
diff --git a/metrics/metrics.go b/metrics/metrics.go
deleted file mode 100644
index 8e1d18a..0000000
--- a/metrics/metrics.go
+++ /dev/null
@@ -1,193 +0,0 @@
-package metrics
-
-import (
- "fmt"
- "net/http"
- "sync"
- "time"
-
- "github.com/prometheus/client_golang/prometheus"
- "github.com/prometheus/client_golang/prometheus/promhttp"
-)
-
-var (
- // Default metrics
- requestDuration = prometheus.NewHistogramVec(
- prometheus.HistogramOpts{
- Name: "http_request_duration_seconds",
- Help: "Duration of HTTP requests in seconds",
- Buckets: prometheus.DefBuckets,
- },
- []string{"method", "path", "status"},
- )
-
- requestTotal = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Name: "http_requests_total",
- Help: "Total number of HTTP requests",
- },
- []string{"method", "path", "status"},
- )
-
- otpGenerationTotal = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Name: "otp_generation_total",
- Help: "Total number of OTP generations",
- },
- []string{"user_id", "otp_id"},
- )
-
- otpVerificationTotal = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Name: "otp_verification_total",
- Help: "Total number of OTP verifications",
- },
- []string{"user_id", "otp_id", "success"},
- )
-
- activeUsers = prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "active_users",
- Help: "Number of active users",
- },
- )
-
- cacheHits = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Name: "cache_hits_total",
- Help: "Total number of cache hits",
- },
- []string{"cache"},
- )
-
- cacheMisses = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Name: "cache_misses_total",
- Help: "Total number of cache misses",
- },
- []string{"cache"},
- )
-)
-
-func init() {
- // Register metrics with prometheus
- prometheus.MustRegister(
- requestDuration,
- requestTotal,
- otpGenerationTotal,
- otpVerificationTotal,
- activeUsers,
- cacheHits,
- cacheMisses,
- )
-}
-
-// MetricsService provides metrics functionality
-type MetricsService struct {
- activeUsersMutex sync.RWMutex
- activeUserIDs map[string]bool
-}
-
-// NewMetricsService creates a new MetricsService
-func NewMetricsService() *MetricsService {
- return &MetricsService{
- activeUserIDs: make(map[string]bool),
- }
-}
-
-// Handler returns an HTTP handler for metrics
-func (s *MetricsService) Handler() http.Handler {
- return promhttp.Handler()
-}
-
-// RecordRequest records metrics for an HTTP request
-func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) {
- labels := prometheus.Labels{
- "method": method,
- "path": path,
- "status": fmt.Sprintf("%d", status),
- }
-
- requestDuration.With(labels).Observe(duration.Seconds())
- requestTotal.With(labels).Inc()
-}
-
-// RecordOTPGeneration records metrics for OTP generation
-func (s *MetricsService) RecordOTPGeneration(userID, otpID string) {
- otpGenerationTotal.With(prometheus.Labels{
- "user_id": userID,
- "otp_id": otpID,
- }).Inc()
-}
-
-// RecordOTPVerification records metrics for OTP verification
-func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) {
- otpVerificationTotal.With(prometheus.Labels{
- "user_id": userID,
- "otp_id": otpID,
- "success": fmt.Sprintf("%t", success),
- }).Inc()
-}
-
-// RecordUserActivity records user activity
-func (s *MetricsService) RecordUserActivity(userID string) {
- s.activeUsersMutex.Lock()
- defer s.activeUsersMutex.Unlock()
-
- if !s.activeUserIDs[userID] {
- s.activeUserIDs[userID] = true
- activeUsers.Inc()
- }
-}
-
-// RecordUserInactivity records user inactivity
-func (s *MetricsService) RecordUserInactivity(userID string) {
- s.activeUsersMutex.Lock()
- defer s.activeUsersMutex.Unlock()
-
- if s.activeUserIDs[userID] {
- delete(s.activeUserIDs, userID)
- activeUsers.Dec()
- }
-}
-
-// RecordCacheHit records a cache hit
-func (s *MetricsService) RecordCacheHit(cache string) {
- cacheHits.With(prometheus.Labels{
- "cache": cache,
- }).Inc()
-}
-
-// RecordCacheMiss records a cache miss
-func (s *MetricsService) RecordCacheMiss(cache string) {
- cacheMisses.With(prometheus.Labels{
- "cache": cache,
- }).Inc()
-}
-
-// Middleware creates a middleware that records request metrics
-func (s *MetricsService) Middleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- start := time.Now()
-
- // Create response writer that captures status code
- rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
-
- // Call next handler
- next.ServeHTTP(rw, r)
-
- // Record metrics
- s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start))
- })
-}
-
-// responseWriter wraps http.ResponseWriter to capture status code
-type responseWriter struct {
- http.ResponseWriter
- status int
-}
-
-func (rw *responseWriter) WriteHeader(code int) {
- rw.status = code
- rw.ResponseWriter.WriteHeader(code)
-}
diff --git a/middleware/middleware.go b/middleware/middleware.go
deleted file mode 100644
index 6afce13..0000000
--- a/middleware/middleware.go
+++ /dev/null
@@ -1,353 +0,0 @@
-package middleware
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "log"
- "math/rand"
- "net/http"
- "runtime/debug"
- "strings"
- "time"
-
- "github.com/golang-jwt/jwt"
-)
-
-// Response represents a standard API response
-type Response struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Data interface{} `json:"data,omitempty"`
-}
-
-// ErrorResponse sends a JSON error response
-func ErrorResponse(w http.ResponseWriter, code int, message string) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(code)
- json.NewEncoder(w).Encode(Response{
- Code: code,
- Message: message,
- })
-}
-
-// SuccessResponse sends a JSON success response
-func SuccessResponse(w http.ResponseWriter, data interface{}) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- json.NewEncoder(w).Encode(Response{
- Code: http.StatusOK,
- Message: "success",
- Data: data,
- })
-}
-
-// Logger is a middleware that logs request details with structured format
-func Logger(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- start := time.Now()
- requestID := r.Header.Get("X-Request-ID")
- if requestID == "" {
- requestID = generateRequestID()
- r.Header.Set("X-Request-ID", requestID)
- }
-
- // Create a custom response writer to capture status code
- rw := &responseWriter{
- ResponseWriter: w,
- status: http.StatusOK,
- }
-
- // Process request
- next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID)))
-
- // Log structured request details
- log.Printf(
- "method=%s path=%s status=%d duration=%s ip=%s request_id=%s",
- r.Method,
- r.URL.Path,
- rw.status,
- time.Since(start).String(),
- r.RemoteAddr,
- requestID,
- )
- })
-}
-
-// generateRequestID creates a unique request identifier
-func generateRequestID() string {
- return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
-}
-
-// randomString generates a random string of given length
-func randomString(length int) string {
- const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
- b := make([]byte, length)
- for i := range b {
- b[i] = charset[rand.Intn(len(charset))]
- }
- return string(b)
-}
-
-// Recover is a middleware that recovers from panics with detailed logging
-func Recover(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- defer func() {
- if err := recover(); err != nil {
- // Get request ID from context
- requestID := ""
- if ctx := r.Context(); ctx != nil {
- if id, ok := ctx.Value("request_id").(string); ok {
- requestID = id
- }
- }
-
- // Log error with stack trace and request context
- log.Printf(
- "panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s",
- err,
- requestID,
- r.Method,
- r.URL.Path,
- r.RemoteAddr,
- debug.Stack(),
- )
-
- // Determine error type
- var message string
- var status int
-
- switch e := err.(type) {
- case error:
- message = e.Error()
- if isClientError(e) {
- status = http.StatusBadRequest
- } else {
- status = http.StatusInternalServerError
- }
- case string:
- message = e
- status = http.StatusInternalServerError
- default:
- message = "Internal Server Error"
- status = http.StatusInternalServerError
- }
-
- ErrorResponse(w, status, message)
- }
- }()
- next.ServeHTTP(w, r)
- })
-}
-
-// isClientError checks if error should be treated as client error
-func isClientError(err error) bool {
- // Add more client error types as needed
- return strings.Contains(err.Error(), "validation") ||
- strings.Contains(err.Error(), "invalid") ||
- strings.Contains(err.Error(), "missing")
-}
-
-// CORS is a middleware that handles CORS
-func CORS(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
- w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
-
- if r.Method == "OPTIONS" {
- w.WriteHeader(http.StatusOK)
- return
- }
-
- next.ServeHTTP(w, r)
- })
-}
-
-// Timeout is a middleware that safely handles request timeouts
-func Timeout(duration time.Duration) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ctx, cancel := context.WithTimeout(r.Context(), duration)
- defer cancel()
-
- // Use buffered channels to prevent goroutine leaks
- done := make(chan struct{}, 1)
- panicChan := make(chan interface{}, 1)
-
- // Track request processing in goroutine
- go func() {
- defer func() {
- if p := recover(); p != nil {
- panicChan <- p
- }
- }()
- next.ServeHTTP(w, r.WithContext(ctx))
- done <- struct{}{}
- }()
-
- // Wait for completion, timeout or panic
- select {
- case <-done:
- return
- case p := <-panicChan:
- panic(p) // Re-throw panic to be caught by Recover middleware
- case <-ctx.Done():
- // Get request context for logging
- requestID := ""
- if ctx := r.Context(); ctx != nil {
- if id, ok := ctx.Value("request_id").(string); ok {
- requestID = id
- }
- }
-
- // Log timeout details
- log.Printf(
- "request_timeout: request_id=%s method=%s path=%s timeout=%s",
- requestID,
- r.Method,
- r.URL.Path,
- duration.String(),
- )
-
- // Send timeout response
- ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf(
- "Request timed out after %s", duration.String(),
- ))
- }
- })
- }
-}
-
-// Auth is a middleware that validates JWT tokens with enhanced security
-func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Get request ID for logging
- requestID := ""
- if ctx := r.Context(); ctx != nil {
- if id, ok := ctx.Value("request_id").(string); ok {
- requestID = id
- }
- }
-
- // Get token from Authorization header
- authHeader := r.Header.Get("Authorization")
- if authHeader == "" {
- log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID)
- ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required")
- return
- }
-
- // Validate header format
- parts := strings.Split(authHeader, " ")
- if len(parts) != 2 || parts[0] != "Bearer" {
- log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID)
- ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer '")
- return
- }
-
- tokenString := parts[1]
-
- // Parse and validate token
- token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
- // Validate signing method
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
- }
- return []byte(jwtSecret), nil
- })
-
- if err != nil {
- log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err)
- ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
- return
- }
-
- if !token.Valid {
- log.Printf("auth_failed: request_id=%s error=invalid_token", requestID)
- ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
- return
- }
-
- // Validate claims
- claims, ok := token.Claims.(jwt.MapClaims)
- if !ok {
- log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID)
- ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims")
- return
- }
-
- // Check required claims
- userID, ok := claims["user_id"].(string)
- if !ok || userID == "" {
- log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID)
- ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token")
- return
- }
-
- // Check token expiration
- if exp, ok := claims["exp"].(float64); ok {
- if time.Now().Unix() > int64(exp) {
- log.Printf("auth_failed: request_id=%s error=token_expired", requestID)
- ErrorResponse(w, http.StatusUnauthorized, "Token has expired")
- return
- }
- }
-
- // Check required roles if specified
- if len(requiredRoles) > 0 {
- roles, ok := claims["roles"].([]interface{})
- if !ok {
- log.Printf("auth_failed: request_id=%s error=missing_roles", requestID)
- ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles")
- return
- }
-
- hasRequiredRole := false
- for _, requiredRole := range requiredRoles {
- for _, role := range roles {
- if r, ok := role.(string); ok && r == requiredRole {
- hasRequiredRole = true
- break
- }
- }
- }
-
- if !hasRequiredRole {
- log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID)
- ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions")
- return
- }
- }
-
- // Add claims to context
- ctx := r.Context()
- ctx = context.WithValue(ctx, "user_id", userID)
- ctx = context.WithValue(ctx, "claims", claims)
-
- log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
- }
-}
-
-// responseWriter is a custom response writer that captures the status code
-type responseWriter struct {
- http.ResponseWriter
- status int
-}
-
-func (rw *responseWriter) WriteHeader(code int) {
- rw.status = code
- rw.ResponseWriter.WriteHeader(code)
-}
-
-// GetUserID gets the user ID from the request context
-func GetUserID(r *http.Request) (string, error) {
- userID, ok := r.Context().Value("user_id").(string)
- if !ok {
- return "", fmt.Errorf("user ID not found in context")
- }
- return userID, nil
-}
diff --git a/miniprogram-example/app.js b/miniprogram-example/app.js
deleted file mode 100644
index 34eba08..0000000
--- a/miniprogram-example/app.js
+++ /dev/null
@@ -1,50 +0,0 @@
-// app.js
-App({
- onLaunch() {
- // 检查更新
- if (wx.canIUse('getUpdateManager')) {
- const updateManager = wx.getUpdateManager();
- updateManager.onCheckForUpdate(function (res) {
- if (res.hasUpdate) {
- updateManager.onUpdateReady(function () {
- wx.showModal({
- title: '更新提示',
- content: '新版本已经准备好,是否重启应用?',
- success: function (res) {
- if (res.confirm) {
- updateManager.applyUpdate();
- }
- }
- });
- });
-
- updateManager.onUpdateFailed(function () {
- wx.showModal({
- title: '更新提示',
- content: '新版本下载失败,请检查网络后重试',
- showCancel: false
- });
- });
- }
- });
- }
-
- // 获取系统信息
- try {
- const systemInfo = wx.getSystemInfoSync();
- this.globalData.systemInfo = systemInfo;
-
- // 计算安全区域
- const { screenHeight, safeArea } = systemInfo;
- this.globalData.safeAreaBottom = screenHeight - safeArea.bottom;
- } catch (e) {
- console.error('获取系统信息失败', e);
- }
- },
-
- globalData: {
- userInfo: null,
- systemInfo: {},
- safeAreaBottom: 0
- }
-});
\ No newline at end of file
diff --git a/miniprogram-example/app.json b/miniprogram-example/app.json
deleted file mode 100644
index 89f79ae..0000000
--- a/miniprogram-example/app.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "pages": [
- "pages/login/login",
- "pages/otp-list/index",
- "pages/otp-add/index"
- ],
- "window": {
- "backgroundTextStyle": "light",
- "navigationBarBackgroundColor": "#fff",
- "navigationBarTitleText": "OTPM",
- "navigationBarTextStyle": "black",
- "backgroundColor": "#F8F8F8"
- },
- "permission": {
- "scope.camera": {
- "desc": "需要使用相机扫描二维码"
- }
- },
- "usingComponents": {},
- "style": "v2",
- "sitemapLocation": "sitemap.json",
- "lazyCodeLoading": "requiredComponents"
-}
\ No newline at end of file
diff --git a/miniprogram-example/app.wxss b/miniprogram-example/app.wxss
deleted file mode 100644
index fb73b4a..0000000
--- a/miniprogram-example/app.wxss
+++ /dev/null
@@ -1,238 +0,0 @@
-/**app.wxss**/
-page {
- --primary-color: #1890ff;
- --danger-color: #ff4d4f;
- --success-color: #52c41a;
- --warning-color: #faad14;
- --text-color: #333333;
- --text-color-secondary: #666666;
- --text-color-light: #999999;
- --border-color: #e8e8e8;
- --background-color: #f8f8f8;
- --border-radius: 8rpx;
- --safe-area-bottom: env(safe-area-inset-bottom);
-
- font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica,
- Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei',
- sans-serif;
- font-size: 28rpx;
- line-height: 1.5;
- color: var(--text-color);
- background-color: var(--background-color);
-}
-
-/* 清除默认样式 */
-button {
- padding: 0;
- margin: 0;
- background: none;
- border: none;
- text-align: left;
- line-height: inherit;
- overflow: visible;
-}
-
-button::after {
- border: none;
-}
-
-/* 通用样式类 */
-.container {
- min-height: 100vh;
- box-sizing: border-box;
-}
-
-.safe-area-bottom {
- padding-bottom: var(--safe-area-bottom);
-}
-
-.flex-center {
- display: flex;
- align-items: center;
- justify-content: center;
-}
-
-.flex-between {
- display: flex;
- align-items: center;
- justify-content: space-between;
-}
-
-.flex-column {
- display: flex;
- flex-direction: column;
-}
-
-.text-primary {
- color: var(--primary-color);
-}
-
-.text-danger {
- color: var(--danger-color);
-}
-
-.text-success {
- color: var(--success-color);
-}
-
-.text-warning {
- color: var(--warning-color);
-}
-
-.text-secondary {
- color: var(--text-color-secondary);
-}
-
-.text-light {
- color: var(--text-color-light);
-}
-
-.text-center {
- text-align: center;
-}
-
-.text-left {
- text-align: left;
-}
-
-.text-right {
- text-align: right;
-}
-
-.text-bold {
- font-weight: bold;
-}
-
-.text-ellipsis {
- overflow: hidden;
- text-overflow: ellipsis;
- white-space: nowrap;
-}
-
-.bg-white {
- background-color: #ffffff;
-}
-
-.bg-primary {
- background-color: var(--primary-color);
-}
-
-.bg-danger {
- background-color: var(--danger-color);
-}
-
-.bg-success {
- background-color: var(--success-color);
-}
-
-.bg-warning {
- background-color: var(--warning-color);
-}
-
-.shadow {
- box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
-}
-
-.rounded {
- border-radius: var(--border-radius);
-}
-
-.border {
- border: 2rpx solid var(--border-color);
-}
-
-.border-top {
- border-top: 2rpx solid var(--border-color);
-}
-
-.border-bottom {
- border-bottom: 2rpx solid var(--border-color);
-}
-
-/* 动画类 */
-.fade-in {
- animation: fadeIn 0.3s ease-in-out;
-}
-
-.fade-out {
- animation: fadeOut 0.3s ease-in-out;
-}
-
-@keyframes fadeIn {
- from {
- opacity: 0;
- }
- to {
- opacity: 1;
- }
-}
-
-@keyframes fadeOut {
- from {
- opacity: 1;
- }
- to {
- opacity: 0;
- }
-}
-
-/* 间距类 */
-.m-0 { margin: 0; }
-.m-1 { margin: 10rpx; }
-.m-2 { margin: 20rpx; }
-.m-3 { margin: 30rpx; }
-.m-4 { margin: 40rpx; }
-
-.mt-0 { margin-top: 0; }
-.mt-1 { margin-top: 10rpx; }
-.mt-2 { margin-top: 20rpx; }
-.mt-3 { margin-top: 30rpx; }
-.mt-4 { margin-top: 40rpx; }
-
-.mb-0 { margin-bottom: 0; }
-.mb-1 { margin-bottom: 10rpx; }
-.mb-2 { margin-bottom: 20rpx; }
-.mb-3 { margin-bottom: 30rpx; }
-.mb-4 { margin-bottom: 40rpx; }
-
-.ml-0 { margin-left: 0; }
-.ml-1 { margin-left: 10rpx; }
-.ml-2 { margin-left: 20rpx; }
-.ml-3 { margin-left: 30rpx; }
-.ml-4 { margin-left: 40rpx; }
-
-.mr-0 { margin-right: 0; }
-.mr-1 { margin-right: 10rpx; }
-.mr-2 { margin-right: 20rpx; }
-.mr-3 { margin-right: 30rpx; }
-.mr-4 { margin-right: 40rpx; }
-
-.p-0 { padding: 0; }
-.p-1 { padding: 10rpx; }
-.p-2 { padding: 20rpx; }
-.p-3 { padding: 30rpx; }
-.p-4 { padding: 40rpx; }
-
-.pt-0 { padding-top: 0; }
-.pt-1 { padding-top: 10rpx; }
-.pt-2 { padding-top: 20rpx; }
-.pt-3 { padding-top: 30rpx; }
-.pt-4 { padding-top: 40rpx; }
-
-.pb-0 { padding-bottom: 0; }
-.pb-1 { padding-bottom: 10rpx; }
-.pb-2 { padding-bottom: 20rpx; }
-.pb-3 { padding-bottom: 30rpx; }
-.pb-4 { padding-bottom: 40rpx; }
-
-.pl-0 { padding-left: 0; }
-.pl-1 { padding-left: 10rpx; }
-.pl-2 { padding-left: 20rpx; }
-.pl-3 { padding-left: 30rpx; }
-.pl-4 { padding-left: 40rpx; }
-
-.pr-0 { padding-right: 0; }
-.pr-1 { padding-right: 10rpx; }
-.pr-2 { padding-right: 20rpx; }
-.pr-3 { padding-right: 30rpx; }
-.pr-4 { padding-right: 40rpx; }
\ No newline at end of file
diff --git a/miniprogram-example/pages/login/login.js b/miniprogram-example/pages/login/login.js
deleted file mode 100644
index e4ad173..0000000
--- a/miniprogram-example/pages/login/login.js
+++ /dev/null
@@ -1,48 +0,0 @@
-// login.js
-import { wxLogin } from '../../services/auth';
-
-Page({
- data: {
- loading: false
- },
-
- onLoad() {
- // 页面加载时检查是否已经登录
- const token = wx.getStorageSync('token');
- if (token) {
- this.redirectToHome();
- }
- },
-
- // 处理登录按钮点击
- handleLogin() {
- if (this.data.loading) return;
-
- this.setData({ loading: true });
-
- wxLogin()
- .then(() => {
- wx.showToast({
- title: '登录成功',
- icon: 'success'
- });
- this.redirectToHome();
- })
- .catch(err => {
- wx.showToast({
- title: err.message || '登录失败',
- icon: 'none'
- });
- })
- .finally(() => {
- this.setData({ loading: false });
- });
- },
-
- // 跳转到首页
- redirectToHome() {
- wx.reLaunch({
- url: '/pages/otp-list/index'
- });
- }
-});
\ No newline at end of file
diff --git a/miniprogram-example/pages/login/login.json b/miniprogram-example/pages/login/login.json
deleted file mode 100644
index 8835af0..0000000
--- a/miniprogram-example/pages/login/login.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "usingComponents": {}
-}
\ No newline at end of file
diff --git a/miniprogram-example/pages/login/login.wxml b/miniprogram-example/pages/login/login.wxml
deleted file mode 100644
index d73a711..0000000
--- a/miniprogram-example/pages/login/login.wxml
+++ /dev/null
@@ -1,30 +0,0 @@
-
-
-
-
- OTPM 小程序
-
-
-
- 欢迎使用 OTPM
- 一次性密码管理工具
-
-
-
-
- 登录即表示您同意
- 《隐私政策》
-
-
-
\ No newline at end of file
diff --git a/miniprogram-example/pages/login/login.wxss b/miniprogram-example/pages/login/login.wxss
deleted file mode 100644
index 26cece6..0000000
--- a/miniprogram-example/pages/login/login.wxss
+++ /dev/null
@@ -1,97 +0,0 @@
-/* login.wxss */
-.container {
- display: flex;
- flex-direction: column;
- align-items: center;
- justify-content: space-between;
- height: 100vh;
- padding: 60rpx 40rpx;
- box-sizing: border-box;
- background-color: #f8f8f8;
-}
-
-.logo-container {
- display: flex;
- flex-direction: column;
- align-items: center;
- margin-top: 80rpx;
-}
-
-.logo {
- width: 180rpx;
- height: 180rpx;
- margin-bottom: 20rpx;
-}
-
-.app-name {
- font-size: 36rpx;
- font-weight: bold;
- color: #333;
-}
-
-.login-container {
- width: 100%;
- display: flex;
- flex-direction: column;
- align-items: center;
- margin-bottom: 100rpx;
-}
-
-.login-title {
- font-size: 48rpx;
- font-weight: bold;
- color: #333;
- margin-bottom: 20rpx;
-}
-
-.login-subtitle {
- font-size: 28rpx;
- color: #666;
- margin-bottom: 80rpx;
-}
-
-.login-button {
- width: 80%;
- height: 88rpx;
- border-radius: 44rpx;
- font-size: 32rpx;
- display: flex;
- align-items: center;
- justify-content: center;
-}
-
-.login-button.loading {
- background-color: #8cc4ff;
-}
-
-.loading-container {
- display: flex;
- align-items: center;
- justify-content: center;
-}
-
-.loading-icon {
- width: 36rpx;
- height: 36rpx;
- margin-right: 10rpx;
- border: 4rpx solid #ffffff;
- border-radius: 50%;
- border-top-color: transparent;
- animation: spin 1s linear infinite;
-}
-
-@keyframes spin {
- 0% { transform: rotate(0deg); }
- 100% { transform: rotate(360deg); }
-}
-
-.privacy-policy {
- margin-top: 40rpx;
- font-size: 24rpx;
- color: #999;
-}
-
-.policy-link {
- color: #1890ff;
- display: inline;
-}
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-add/index.js b/miniprogram-example/pages/otp-add/index.js
deleted file mode 100644
index e89c623..0000000
--- a/miniprogram-example/pages/otp-add/index.js
+++ /dev/null
@@ -1,169 +0,0 @@
-// otp-add/index.js
-import { createOTP } from '../../services/otp';
-
-Page({
- data: {
- form: {
- name: '',
- issuer: '',
- secret: '',
- algorithm: 'SHA1',
- digits: 6,
- period: 30
- },
- algorithms: ['SHA1', 'SHA256', 'SHA512'],
- digitOptions: [6, 8],
- periodOptions: [30, 60],
- submitting: false,
- scanMode: false
- },
-
- // 处理输入变化
- handleInputChange(e) {
- const { field } = e.currentTarget.dataset;
- const { value } = e.detail;
-
- this.setData({
- [`form.${field}`]: value
- });
- },
-
- // 处理选择器变化
- handlePickerChange(e) {
- const { field } = e.currentTarget.dataset;
- const { value } = e.detail;
-
- const options = this.data[`${field}Options`] || this.data[field];
- const selectedValue = options[value];
-
- this.setData({
- [`form.${field}`]: selectedValue
- });
- },
-
- // 扫描二维码
- handleScanQRCode() {
- this.setData({ scanMode: true });
-
- wx.scanCode({
- scanType: ['qrCode'],
- success: (res) => {
- try {
- // 解析otpauth://协议的URL
- const url = res.result;
- if (url.startsWith('otpauth://totp/')) {
- const parsedUrl = new URL(url);
- const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠
-
- // 解析路径中的issuer和name
- let issuer = '';
- let name = path;
-
- if (path.includes(':')) {
- const parts = path.split(':');
- issuer = parts[0];
- name = parts[1];
- }
-
- // 从查询参数中获取其他信息
- const secret = parsedUrl.searchParams.get('secret') || '';
- const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1';
- const digits = parseInt(parsedUrl.searchParams.get('digits') || '6');
- const period = parseInt(parsedUrl.searchParams.get('period') || '30');
-
- // 如果查询参数中有issuer,优先使用
- if (parsedUrl.searchParams.get('issuer')) {
- issuer = parsedUrl.searchParams.get('issuer');
- }
-
- this.setData({
- form: {
- name,
- issuer,
- secret,
- algorithm,
- digits,
- period
- }
- });
-
- wx.showToast({
- title: '二维码解析成功',
- icon: 'success'
- });
- } else {
- wx.showToast({
- title: '不支持的二维码格式',
- icon: 'none'
- });
- }
- } catch (err) {
- wx.showToast({
- title: '二维码解析失败',
- icon: 'none'
- });
- }
- },
- fail: () => {
- wx.showToast({
- title: '扫描取消',
- icon: 'none'
- });
- },
- complete: () => {
- this.setData({ scanMode: false });
- }
- });
- },
-
- // 提交表单
- handleSubmit() {
- const { form } = this.data;
-
- // 表单验证
- if (!form.name) {
- wx.showToast({
- title: '请输入名称',
- icon: 'none'
- });
- return;
- }
-
- if (!form.secret) {
- wx.showToast({
- title: '请输入密钥',
- icon: 'none'
- });
- return;
- }
-
- this.setData({ submitting: true });
-
- createOTP(form)
- .then(() => {
- wx.showToast({
- title: '添加成功',
- icon: 'success'
- });
-
- // 返回上一页
- setTimeout(() => {
- wx.navigateBack();
- }, 1500);
- })
- .catch(err => {
- wx.showToast({
- title: err.message || '添加失败',
- icon: 'none'
- });
- })
- .finally(() => {
- this.setData({ submitting: false });
- });
- },
-
- // 取消
- handleCancel() {
- wx.navigateBack();
- }
-});
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-add/index.json b/miniprogram-example/pages/otp-add/index.json
deleted file mode 100644
index 8835af0..0000000
--- a/miniprogram-example/pages/otp-add/index.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "usingComponents": {}
-}
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-add/index.wxml b/miniprogram-example/pages/otp-add/index.wxml
deleted file mode 100644
index e5a2b01..0000000
--- a/miniprogram-example/pages/otp-add/index.wxml
+++ /dev/null
@@ -1,119 +0,0 @@
-
-
-
-
-
-
- 名称 *
-
-
-
-
- 发行方
-
-
-
-
- 密钥 *
-
-
-
- 🔍
-
-
-
-
-
-
-
-
- 算法
-
-
- {{form.algorithm}}
- ▼
-
-
-
-
-
-
- 位数
-
-
- {{form.digits}}
- ▼
-
-
-
-
-
- 周期(秒)
-
-
- {{form.period}}
- ▼
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-add/index.wxss b/miniprogram-example/pages/otp-add/index.wxss
deleted file mode 100644
index 3906a97..0000000
--- a/miniprogram-example/pages/otp-add/index.wxss
+++ /dev/null
@@ -1,176 +0,0 @@
-/* otp-add/index.wxss */
-.container {
- min-height: 100vh;
- background-color: #f8f8f8;
- padding: 0 0 40rpx 0;
-}
-
-.header {
- display: flex;
- justify-content: space-between;
- align-items: center;
- padding: 40rpx 32rpx;
- background-color: #ffffff;
- position: sticky;
- top: 0;
- z-index: 100;
- box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
-}
-
-.title {
- font-size: 36rpx;
- font-weight: bold;
- color: #333333;
-}
-
-.form-container {
- background-color: #ffffff;
- padding: 32rpx;
- margin: 32rpx;
- border-radius: 16rpx;
- box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
-}
-
-.form-group {
- margin-bottom: 32rpx;
-}
-
-.form-row {
- display: flex;
- justify-content: space-between;
-}
-
-.form-group.half {
- width: 48%;
-}
-
-.form-label {
- display: block;
- font-size: 28rpx;
- color: #666666;
- margin-bottom: 12rpx;
-}
-
-.required {
- color: #ff4d4f;
-}
-
-.form-input {
- width: 100%;
- height: 80rpx;
- background-color: #f5f5f5;
- border-radius: 8rpx;
- padding: 0 24rpx;
- font-size: 28rpx;
- color: #333333;
- box-sizing: border-box;
-}
-
-.secret-input-container {
- position: relative;
-}
-
-.scan-button {
- position: absolute;
- right: 20rpx;
- top: 50%;
- transform: translateY(-50%);
- width: 60rpx;
- height: 60rpx;
- display: flex;
- align-items: center;
- justify-content: center;
-}
-
-.scan-icon {
- font-size: 40rpx;
- color: #1890ff;
-}
-
-.scanning-indicator {
- position: absolute;
- right: 20rpx;
- top: 50%;
- transform: translateY(-50%);
- width: 40rpx;
- height: 40rpx;
-}
-
-.scanning-spinner {
- width: 40rpx;
- height: 40rpx;
- border: 4rpx solid #f3f3f3;
- border-top: 4rpx solid #1890ff;
- border-radius: 50%;
- animation: spin 1s linear infinite;
-}
-
-@keyframes spin {
- 0% { transform: rotate(0deg); }
- 100% { transform: rotate(360deg); }
-}
-
-.picker-view {
- width: 100%;
- height: 80rpx;
- background-color: #f5f5f5;
- border-radius: 8rpx;
- padding: 0 24rpx;
- font-size: 28rpx;
- color: #333333;
- display: flex;
- align-items: center;
- justify-content: space-between;
- box-sizing: border-box;
-}
-
-.picker-arrow {
- font-size: 24rpx;
- color: #999999;
-}
-
-.button-group {
- display: flex;
- justify-content: space-between;
- padding: 32rpx;
-}
-
-.cancel-button, .submit-button {
- width: 48%;
- height: 88rpx;
- border-radius: 44rpx;
- font-size: 32rpx;
- display: flex;
- align-items: center;
- justify-content: center;
-}
-
-.cancel-button {
- background-color: #f5f5f5;
- color: #666666;
-}
-
-.submit-button {
- background-color: #1890ff;
- color: #ffffff;
-}
-
-.submit-button.loading {
- background-color: #8cc4ff;
-}
-
-.loading-container {
- display: flex;
- align-items: center;
- justify-content: center;
-}
-
-.loading-icon {
- width: 36rpx;
- height: 36rpx;
- margin-right: 10rpx;
- border: 4rpx solid #ffffff;
- border-radius: 50%;
- border-top-color: transparent;
- animation: spin 1s linear infinite;
-}
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-list/index.js b/miniprogram-example/pages/otp-list/index.js
deleted file mode 100644
index 9ed6fa1..0000000
--- a/miniprogram-example/pages/otp-list/index.js
+++ /dev/null
@@ -1,213 +0,0 @@
-// otp-list/index.js
-import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp';
-import { checkLoginStatus } from '../../services/auth';
-
-Page({
- data: {
- otpList: [],
- loading: true,
- refreshing: false
- },
-
- onLoad() {
- this.checkLogin();
- },
-
- onShow() {
- // 每次页面显示时刷新OTP列表
- if (!this.data.loading) {
- this.fetchOTPList();
- }
- },
-
- // 下拉刷新
- onPullDownRefresh() {
- this.setData({ refreshing: true });
- this.fetchOTPList().finally(() => {
- wx.stopPullDownRefresh();
- this.setData({ refreshing: false });
- });
- },
-
- // 检查登录状态
- checkLogin() {
- checkLoginStatus().then(isLoggedIn => {
- if (isLoggedIn) {
- this.fetchOTPList();
- } else {
- wx.redirectTo({
- url: '/pages/login/login'
- });
- }
- });
- },
-
- // 获取OTP列表
- fetchOTPList() {
- this.setData({ loading: true });
- return getOTPList()
- .then(res => {
- if (res.data && Array.isArray(res.data)) {
- this.setData({
- otpList: res.data,
- loading: false
- });
-
- // 获取每个OTP的当前验证码
- this.refreshOTPCodes();
- }
- })
- .catch(err => {
- wx.showToast({
- title: '获取OTP列表失败',
- icon: 'none'
- });
- this.setData({ loading: false });
- });
- },
-
- // 刷新所有OTP的验证码
- refreshOTPCodes() {
- const { otpList } = this.data;
-
- // 为每个OTP获取当前验证码
- const promises = otpList.map(otp => {
- return getOTPCode(otp.id)
- .then(res => {
- if (res.data && res.data.code) {
- return {
- id: otp.id,
- code: res.data.code,
- expiresIn: res.data.expires_in || 30
- };
- }
- return null;
- })
- .catch(() => null);
- });
-
- Promise.all(promises).then(results => {
- const updatedList = [...this.data.otpList];
-
- results.forEach(result => {
- if (result) {
- const index = updatedList.findIndex(otp => otp.id === result.id);
- if (index !== -1) {
- updatedList[index] = {
- ...updatedList[index],
- currentCode: result.code,
- expiresIn: result.expiresIn
- };
- }
- }
- });
-
- this.setData({ otpList: updatedList });
-
- // 设置定时器,每秒更新倒计时
- this.startCountdown();
- });
- },
-
- // 开始倒计时
- startCountdown() {
- // 清除之前的定时器
- if (this.countdownTimer) {
- clearInterval(this.countdownTimer);
- }
-
- // 创建新的定时器,每秒更新一次
- this.countdownTimer = setInterval(() => {
- const { otpList } = this.data;
- let needRefresh = false;
-
- const updatedList = otpList.map(otp => {
- if (!otp.countdown) {
- otp.countdown = otp.expiresIn || 30;
- }
-
- otp.countdown -= 1;
-
- // 如果倒计时结束,标记需要刷新
- if (otp.countdown <= 0) {
- needRefresh = true;
- }
-
- return otp;
- });
-
- this.setData({ otpList: updatedList });
-
- // 如果有OTP需要刷新,重新获取验证码
- if (needRefresh) {
- this.refreshOTPCodes();
- }
- }, 1000);
- },
-
- // 添加新的OTP
- handleAddOTP() {
- wx.navigateTo({
- url: '/pages/otp-add/index'
- });
- },
-
- // 编辑OTP
- handleEditOTP(e) {
- const { id } = e.currentTarget.dataset;
- wx.navigateTo({
- url: `/pages/otp-edit/index?id=${id}`
- });
- },
-
- // 删除OTP
- handleDeleteOTP(e) {
- const { id, name } = e.currentTarget.dataset;
-
- wx.showModal({
- title: '确认删除',
- content: `确定要删除 ${name} 吗?`,
- confirmColor: '#ff4d4f',
- success: (res) => {
- if (res.confirm) {
- deleteOTP(id)
- .then(() => {
- wx.showToast({
- title: '删除成功',
- icon: 'success'
- });
- this.fetchOTPList();
- })
- .catch(err => {
- wx.showToast({
- title: '删除失败',
- icon: 'none'
- });
- });
- }
- }
- });
- },
-
- // 复制验证码
- handleCopyCode(e) {
- const { code } = e.currentTarget.dataset;
-
- wx.setClipboardData({
- data: code,
- success: () => {
- wx.showToast({
- title: '验证码已复制',
- icon: 'success'
- });
- }
- });
- },
-
- onUnload() {
- // 页面卸载时清除定时器
- if (this.countdownTimer) {
- clearInterval(this.countdownTimer);
- }
- }
-});
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-list/index.json b/miniprogram-example/pages/otp-list/index.json
deleted file mode 100644
index 8835af0..0000000
--- a/miniprogram-example/pages/otp-list/index.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "usingComponents": {}
-}
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-list/index.wxml b/miniprogram-example/pages/otp-list/index.wxml
deleted file mode 100644
index 27cb5ea..0000000
--- a/miniprogram-example/pages/otp-list/index.wxml
+++ /dev/null
@@ -1,59 +0,0 @@
-
-
-
-
-
-
-
- 加载中...
-
-
-
-
-
-
-
-
- {{item.name}}
- {{item.issuer}}
-
-
-
- {{item.currentCode || '******'}}
- 点击复制
-
-
-
-
- {{item.countdown || 0}}s
-
-
-
-
-
- ✎
-
-
- ✕
-
-
-
-
-
-
-
-
- 暂无OTP,点击右上角添加
-
-
-
\ No newline at end of file
diff --git a/miniprogram-example/pages/otp-list/index.wxss b/miniprogram-example/pages/otp-list/index.wxss
deleted file mode 100644
index 40436d5..0000000
--- a/miniprogram-example/pages/otp-list/index.wxss
+++ /dev/null
@@ -1,201 +0,0 @@
-/* otp-list/index.wxss */
-.container {
- min-height: 100vh;
- background-color: #f8f8f8;
- padding: 0 0 40rpx 0;
-}
-
-.header {
- display: flex;
- justify-content: space-between;
- align-items: center;
- padding: 40rpx 32rpx;
- background-color: #ffffff;
- position: sticky;
- top: 0;
- z-index: 100;
- box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
-}
-
-.title {
- font-size: 36rpx;
- font-weight: bold;
- color: #333333;
-}
-
-.add-button {
- width: 64rpx;
- height: 64rpx;
- border-radius: 32rpx;
- background-color: #1890ff;
- display: flex;
- align-items: center;
- justify-content: center;
- box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3);
-}
-
-.add-icon {
- color: #ffffff;
- font-size: 40rpx;
- line-height: 1;
-}
-
-/* 加载状态 */
-.loading-container {
- display: flex;
- flex-direction: column;
- align-items: center;
- justify-content: center;
- padding: 120rpx 0;
-}
-
-.loading-spinner {
- width: 64rpx;
- height: 64rpx;
- border: 6rpx solid #f3f3f3;
- border-top: 6rpx solid #1890ff;
- border-radius: 50%;
- animation: spin 1s linear infinite;
-}
-
-.loading-text {
- margin-top: 20rpx;
- font-size: 28rpx;
- color: #999999;
-}
-
-@keyframes spin {
- 0% { transform: rotate(0deg); }
- 100% { transform: rotate(360deg); }
-}
-
-/* OTP列表 */
-.otp-list {
- padding: 20rpx 32rpx;
-}
-
-.otp-item {
- background-color: #ffffff;
- border-radius: 16rpx;
- padding: 32rpx;
- margin-bottom: 20rpx;
- display: flex;
- justify-content: space-between;
- box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
-}
-
-.otp-info {
- flex: 1;
- margin-right: 20rpx;
-}
-
-.otp-name-row {
- display: flex;
- align-items: center;
- margin-bottom: 16rpx;
-}
-
-.otp-name {
- font-size: 32rpx;
- font-weight: bold;
- color: #333333;
- margin-right: 16rpx;
-}
-
-.otp-issuer {
- font-size: 24rpx;
- color: #666666;
- background-color: #f5f5f5;
- padding: 4rpx 12rpx;
- border-radius: 8rpx;
-}
-
-.otp-code-row {
- display: flex;
- align-items: center;
- margin-bottom: 20rpx;
-}
-
-.otp-code {
- font-size: 44rpx;
- font-family: monospace;
- font-weight: bold;
- color: #1890ff;
- letter-spacing: 4rpx;
- margin-right: 16rpx;
-}
-
-.copy-hint {
- font-size: 24rpx;
- color: #999999;
-}
-
-.otp-countdown {
- position: relative;
- width: 100%;
-}
-
-.countdown-text {
- position: absolute;
- right: 0;
- top: -30rpx;
- font-size: 24rpx;
- color: #999999;
-}
-
-/* OTP操作按钮 */
-.otp-actions {
- display: flex;
- flex-direction: column;
- justify-content: space-between;
-}
-
-.action-button {
- width: 56rpx;
- height: 56rpx;
- border-radius: 28rpx;
- display: flex;
- align-items: center;
- justify-content: center;
- margin-bottom: 16rpx;
-}
-
-.action-button.edit {
- background-color: #f0f7ff;
-}
-
-.action-button.delete {
- background-color: #fff1f0;
-}
-
-.action-icon {
- font-size: 32rpx;
-}
-
-.edit .action-icon {
- color: #1890ff;
-}
-
-.delete .action-icon {
- color: #ff4d4f;
-}
-
-/* 空状态 */
-.empty-state {
- display: flex;
- flex-direction: column;
- align-items: center;
- justify-content: center;
- padding: 120rpx 0;
-}
-
-.empty-image {
- width: 240rpx;
- height: 240rpx;
- margin-bottom: 32rpx;
-}
-
-.empty-text {
- font-size: 28rpx;
- color: #999999;
-}
\ No newline at end of file
diff --git a/miniprogram-example/project.config.json b/miniprogram-example/project.config.json
deleted file mode 100644
index 3fb79ad..0000000
--- a/miniprogram-example/project.config.json
+++ /dev/null
@@ -1,47 +0,0 @@
-{
- "description": "项目配置文件",
- "packOptions": {
- "ignore": [],
- "include": []
- },
- "miniprogramRoot": "",
- "compileType": "miniprogram",
- "projectname": "OTPM",
- "setting": {
- "useCompilerPlugins": [
- "sass"
- ],
- "babelSetting": {
- "ignore": [],
- "disablePlugins": [],
- "outputPath": ""
- },
- "es6": true,
- "enhance": true,
- "minified": true,
- "postcss": true,
- "minifyWXSS": true,
- "minifyWXML": true,
- "uglifyFileName": true,
- "packNpmManually": false,
- "packNpmRelationList": [],
- "ignoreUploadUnusedFiles": true,
- "compileWorklet": false,
- "uploadWithSourceMap": true,
- "localPlugins": false,
- "disableUseStrict": false,
- "condition": false,
- "swc": false,
- "disableSWC": true
- },
- "simulatorType": "wechat",
- "simulatorPluginLibVersion": {},
- "condition": {},
- "srcMiniprogramRoot": "",
- "appid": "wxb6599459668b6b55",
- "libVersion": "2.30.2",
- "editorSetting": {
- "tabIndent": "insertSpaces",
- "tabSize": 2
- }
-}
\ No newline at end of file
diff --git a/miniprogram-example/project.private.config.json b/miniprogram-example/project.private.config.json
deleted file mode 100644
index 6b2a738..0000000
--- a/miniprogram-example/project.private.config.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "libVersion": "3.8.5",
- "projectname": "OTPM",
- "condition": {},
- "setting": {
- "urlCheck": true,
- "coverView": false,
- "lazyloadPlaceholderEnable": false,
- "skylineRenderEnable": false,
- "preloadBackgroundData": false,
- "autoAudits": false,
- "useApiHook": true,
- "useApiHostProcess": true,
- "showShadowRootInWxmlPanel": false,
- "useStaticServer": false,
- "useLanDebug": false,
- "showES6CompileOption": false,
- "compileHotReLoad": true,
- "checkInvalidKey": true,
- "ignoreDevUnusedFiles": true,
- "bigPackageSizeSupport": false
- }
-}
\ No newline at end of file
diff --git a/miniprogram-example/services/auth.js b/miniprogram-example/services/auth.js
deleted file mode 100644
index 4c95fa4..0000000
--- a/miniprogram-example/services/auth.js
+++ /dev/null
@@ -1,84 +0,0 @@
-// auth.js - 认证相关服务
-
-import request from '../utils/request';
-
-/**
- * 微信登录
- * 1. 调用wx.login获取code
- * 2. 发送code到服务端换取token和openid
- * 3. 保存token和openid到本地存储
- */
-export const wxLogin = () => {
- return new Promise((resolve, reject) => {
- wx.login({
- success: (res) => {
- if (res.code) {
- // 发送code到服务端
- request({
- url: '/login',
- method: 'POST',
- data: {
- code: res.code
- }
- }).then(response => {
- // 保存token和openid
- if (response.data && response.data.token && response.data.openid) {
- wx.setStorageSync('token', response.data.token);
- wx.setStorageSync('openid', response.data.openid);
- resolve(response.data);
- } else {
- reject(new Error('登录失败,服务器返回数据格式错误'));
- }
- }).catch(err => {
- reject(err);
- });
- } else {
- reject(new Error('登录失败,获取code失败: ' + res.errMsg));
- }
- },
- fail: (err) => {
- reject(new Error('微信登录失败: ' + err.errMsg));
- }
- });
- });
-};
-
-/**
- * 检查登录状态
- * 1. 检查本地是否有token和openid
- * 2. 如果有,验证token是否有效
- * 3. 如果无效,清除本地存储并返回false
- */
-export const checkLoginStatus = () => {
- return new Promise((resolve, reject) => {
- const token = wx.getStorageSync('token');
- const openid = wx.getStorageSync('openid');
-
- if (!token || !openid) {
- resolve(false);
- return;
- }
-
- // 验证token有效性
- request({
- url: '/verify-token',
- method: 'POST'
- }).then(() => {
- resolve(true);
- }).catch(() => {
- // token无效,清除本地存储
- wx.removeStorageSync('token');
- wx.removeStorageSync('openid');
- resolve(false);
- });
- });
-};
-
-/**
- * 退出登录
- */
-export const logout = () => {
- wx.removeStorageSync('token');
- wx.removeStorageSync('openid');
- return Promise.resolve();
-};
\ No newline at end of file
diff --git a/miniprogram-example/services/otp.js b/miniprogram-example/services/otp.js
deleted file mode 100644
index 9da9925..0000000
--- a/miniprogram-example/services/otp.js
+++ /dev/null
@@ -1,119 +0,0 @@
-// otp.js - OTP相关服务
-
-import request from '../utils/request';
-
-/**
- * 创建新的OTP
- * @param {Object} params - 创建OTP的参数
- * @param {string} params.name - OTP名称
- * @param {string} params.issuer - 发行方
- * @param {string} params.secret - 密钥
- * @param {string} params.algorithm - 算法,默认为SHA1
- * @param {number} params.digits - 位数,默认为6
- * @param {number} params.period - 周期,默认为30秒
- * @returns {Promise} - 返回创建结果
- */
-export const createOTP = (params) => {
- if (!params || !params.secret) {
- return Promise.reject(new Error('缺少必要的参数: secret'));
- }
- return request({
- url: '/otp',
- method: 'POST',
- data: {
- name: params.name || '',
- issuer: params.issuer || '',
- secret: params.secret,
- algorithm: params.algorithm || 'SHA1',
- digits: params.digits || 6,
- period: params.period || 30
- }
- }).catch(err => {
- console.error('创建OTP失败:', err);
- throw new Error('创建OTP失败: ' + (err.message || '未知错误'));
- });
-};
-
-/**
- * 获取用户所有OTP列表
- * @returns {Promise} - 返回OTP列表
- */
-export const getOTPList = () => {
- return request({
- url: '/otp',
- method: 'GET'
- }).catch(err => {
- console.error('获取OTP列表失败:', err);
- throw new Error('获取OTP列表失败: ' + (err.message || '未知错误'));
- });
-};
-
-/**
- * 获取指定OTP的当前验证码
- * @param {string} id - OTP的ID
- * @returns {Promise} - 返回当前验证码
- */
-export const getOTPCode = (id) => {
- if (!id) {
- return Promise.reject(new Error('缺少必要的参数: id'));
- }
- return request({
- url: `/otp/${id}/code`,
- method: 'GET'
- }).catch(err => {
- console.error('获取OTP代码失败:', err);
- throw new Error('获取OTP代码失败: ' + (err.message || '未知错误'));
- });
-};
-
-/**
- * 验证OTP
- * @param {string} id - OTP的ID
- * @param {string} code - 用户输入的验证码
- * @returns {Promise} - 返回验证结果
- */
-export const verifyOTP = (id, code) => {
- if (!id || !code) {
- return Promise.reject(new Error('缺少必要的参数: id或code'));
- }
- return request({
- url: `/otp/${id}/verify`,
- method: 'POST',
- data: { code }
- }).catch(err => {
- console.error('验证OTP失败:', err);
- throw new Error('验证OTP失败: ' + (err.message || '未知错误'));
- });
-};
-
-/**
- * 更新OTP信息
- * @param {string} id - OTP的ID
- * @param {Object} params - 更新的参数
- * @returns {Promise} - 返回更新结果
- */
-export const updateOTP = (id, params) => {
- if (!id || !params) {
- return Promise.reject(new Error('缺少必要的参数: id或params'));
- }
- return request({
- url: `/otp/${id}`,
- method: 'PUT',
- data: params
- }).catch(err => {
- console.error('更新OTP失败:', err);
- throw new Error('更新OTP失败: ' + (err.message || '未知错误'));
- });
-};
-
-/**
- * 删除OTP
- * @param {string} id - OTP的ID
- * @returns {Promise} - 返回删除结果
- */
-export const deleteOTP = (id) => {
- return request({
- url: `/otp/${id}`,
- method: 'DELETE'
- });
-};
\ No newline at end of file
diff --git a/miniprogram-example/utils/request.js b/miniprogram-example/utils/request.js
deleted file mode 100644
index 9ea7e2e..0000000
--- a/miniprogram-example/utils/request.js
+++ /dev/null
@@ -1,58 +0,0 @@
-// request.js - 网络请求工具类
-
-const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名
-
-// 请求拦截器
-const request = (options) => {
- return new Promise((resolve, reject) => {
- const token = wx.getStorageSync('token');
- const header = {
- 'Content-Type': 'application/json',
- ...options.header
- };
-
- // 如果有token,添加到请求头
- if (token) {
- header['Authorization'] = `Bearer ${token}`;
- }
-
- wx.request({
- url: `${BASE_URL}${options.url}`,
- method: options.method || 'GET',
- data: options.data,
- header: header,
- success: (res) => {
- // 处理业务错误
- if (res.data.code !== 0) {
- // token过期,直接清除并跳转登录
- if (res.statusCode === 401) {
- wx.removeStorageSync('token');
- wx.removeStorageSync('openid');
- reject(new Error('登录已过期,请重新登录'));
- return;
- }
- reject(new Error(res.data.message || '请求失败'));
- return;
- }
- resolve(res.data);
- },
- fail: reject
- });
- });
-};
-
-// 刷新token
-const refreshToken = () => {
- return request({
- url: '/refresh-token',
- method: 'POST'
- }).then(res => {
- if (res.data && res.data.token) {
- wx.setStorageSync('token', res.data.token);
- return res.data.token;
- }
- throw new Error('Failed to refresh token');
- });
-};
-
-export default request;
\ No newline at end of file
diff --git a/models/otp.go b/models/otp.go
deleted file mode 100644
index 8eaab88..0000000
--- a/models/otp.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package models
-
-import (
- "context"
- "time"
-)
-
-// OTP represents a TOTP configuration
-type OTP struct {
- ID int64 `json:"id" db:"id"`
- UserID string `json:"user_id" db:"user_id" validate:"required"`
- OpenID string `json:"openid" db:"openid" validate:"required"`
- Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"`
- Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"`
- Secret string `json:"secret" db:"secret" validate:"required,otpsecret"`
- Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
- Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"`
- Period int `json:"period" db:"period" validate:"required,min=30,max=60"`
- CreatedAt time.Time `json:"created_at" db:"created_at"`
- UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
-}
-
-// OTPParams represents common OTP parameters used in creation and update
-type OTPParams struct {
- Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
- Issuer string `json:"issuer" validate:"omitempty,issuer"`
- Secret string `json:"secret" validate:"required,otpsecret"`
- Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
- Digits int `json:"digits" validate:"omitempty,min=6,max=8"`
- Period int `json:"period" validate:"omitempty,min=30,max=60"`
-}
-
-// OTPRepository handles OTP data storage
-type OTPRepository struct {
- // Add your database connection or ORM here
-}
-
-// Create creates a new OTP record
-func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
- // Implement database creation logic
- return nil
-}
-
-// FindByID finds an OTP by ID and user ID
-func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
- // Implement database lookup logic
- return nil, nil
-}
-
-// FindAllByUserID finds all OTPs for a user
-func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
- // Implement database query logic
- return nil, nil
-}
-
-// Update updates an existing OTP record
-func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
- // Implement database update logic
- return nil
-}
-
-// Delete deletes an OTP record
-func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
- // Implement database deletion logic
- return nil
-}
diff --git a/models/user.go b/models/user.go
deleted file mode 100644
index f12399b..0000000
--- a/models/user.go
+++ /dev/null
@@ -1,114 +0,0 @@
-package models
-
-import (
- "context"
- "database/sql"
- "fmt"
- "time"
-
- "github.com/jmoiron/sqlx"
-)
-
-// User represents a user in the system
-type User struct {
- ID string `db:"id" json:"id"`
- OpenID string `db:"openid" json:"openid"`
- SessionKey string `db:"session_key" json:"-"`
- CreatedAt time.Time `db:"created_at" json:"created_at"`
- UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
-}
-
-// UserRepository handles user data operations
-type UserRepository struct {
- db *sqlx.DB
-}
-
-// NewUserRepository creates a new UserRepository
-func NewUserRepository(db *sqlx.DB) *UserRepository {
- return &UserRepository{db: db}
-}
-
-// FindByID finds a user by ID
-func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) {
- var user User
- query := `SELECT * FROM users WHERE id = ?`
- err := r.db.GetContext(ctx, &user, query, id)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, fmt.Errorf("user not found: %w", err)
- }
- return nil, fmt.Errorf("failed to find user: %w", err)
- }
- return &user, nil
-}
-
-// FindByOpenID finds a user by OpenID
-func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) {
- var user User
- query := `SELECT * FROM users WHERE openid = ?`
- err := r.db.GetContext(ctx, &user, query, openID)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil // User not found, but not an error
- }
- return nil, fmt.Errorf("failed to find user: %w", err)
- }
- return &user, nil
-}
-
-// Create creates a new user
-func (r *UserRepository) Create(ctx context.Context, user *User) error {
- query := `
- INSERT INTO users (id, openid, session_key, created_at, updated_at)
- VALUES (?, ?, ?, ?, ?)
- `
- now := time.Now()
- user.CreatedAt = now
- user.UpdatedAt = now
-
- _, err := r.db.ExecContext(
- ctx,
- query,
- user.ID,
- user.OpenID,
- user.SessionKey,
- user.CreatedAt,
- user.UpdatedAt,
- )
- if err != nil {
- return fmt.Errorf("failed to create user: %w", err)
- }
- return nil
-}
-
-// Update updates an existing user
-func (r *UserRepository) Update(ctx context.Context, user *User) error {
- query := `
- UPDATE users
- SET session_key = ?, updated_at = ?
- WHERE id = ?
- `
- user.UpdatedAt = time.Now()
-
- _, err := r.db.ExecContext(
- ctx,
- query,
- user.SessionKey,
- user.UpdatedAt,
- user.ID,
- )
- if err != nil {
- return fmt.Errorf("failed to update user: %w", err)
- }
- return nil
-}
-
-// Delete deletes a user
-func (r *UserRepository) Delete(ctx context.Context, id string) error {
- query := `DELETE FROM users WHERE id = ?`
- _, err := r.db.ExecContext(ctx, query, id)
- if err != nil {
- return fmt.Errorf("failed to delete user: %w", err)
- }
- return nil
-}
diff --git a/otp_api.go b/otp_api.go
new file mode 100644
index 0000000..573ad69
--- /dev/null
+++ b/otp_api.go
@@ -0,0 +1,417 @@
+package main
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "database/sql"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+
+ _ "github.com/lib/pq"
+)
+
+// 加密密钥(32字节AES-256)
+var encryptionKey = []byte("example-key-32-bytes-long!1234") // 实际应用中应从安全配置获取
+
+// encryptTokenSecret 加密令牌密钥
+func encryptTokenSecret(secret string) (string, error) {
+ block, err := aes.NewCipher(encryptionKey)
+ if err != nil {
+ return "", err
+ }
+
+ ciphertext := make([]byte, aes.BlockSize+len(secret))
+ iv := ciphertext[:aes.BlockSize]
+ if _, err := io.ReadFull(rand.Reader, iv); err != nil {
+ return "", err
+ }
+
+ stream := cipher.NewCFBEncrypter(block, iv)
+ stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(secret))
+
+ return base64.StdEncoding.EncodeToString(ciphertext), nil
+}
+
+// decryptTokenSecret 解密令牌密钥
+func decryptTokenSecret(encrypted string) (string, error) {
+ ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
+ if err != nil {
+ return "", err
+ }
+
+ block, err := aes.NewCipher(encryptionKey)
+ if err != nil {
+ return "", err
+ }
+
+ if len(ciphertext) < aes.BlockSize {
+ return "", fmt.Errorf("ciphertext too short")
+ }
+
+ iv := ciphertext[:aes.BlockSize]
+ ciphertext = ciphertext[aes.BlockSize:]
+
+ stream := cipher.NewCFBDecrypter(block, iv)
+ stream.XORKeyStream(ciphertext, ciphertext)
+
+ return string(ciphertext), nil
+}
+
+// SaveRequest 保存请求的数据结构
+type SaveRequest struct {
+ Tokens []TokenData `json:"tokens"`
+ UserID string `json:"userId"`
+ Timestamp int64 `json:"timestamp"`
+}
+
+// TokenData token数据结构
+type TokenData struct {
+ ID string `json:"id"`
+ Issuer string `json:"issuer"`
+ Account string `json:"account"`
+ Secret string `json:"secret"`
+ Type string `json:"type"`
+ Counter int `json:"counter,omitempty"`
+ Period int `json:"period"`
+ Digits int `json:"digits"`
+ Algo string `json:"algo"`
+}
+
+// SaveResponse 保存响应的数据结构
+type SaveResponse struct {
+ Success bool `json:"success"`
+ Message string `json:"message"`
+ Data struct {
+ ID string `json:"id"`
+ } `json:"data"`
+}
+
+// RecoverResponse 恢复响应的数据结构
+type RecoverResponse struct {
+ Success bool `json:"success"`
+ Message string `json:"message"`
+ Data struct {
+ Tokens []TokenData `json:"tokens"`
+ Timestamp int64 `json:"timestamp"`
+ } `json:"data"`
+}
+
+var db *sql.DB
+
+// InitDB 初始化数据库连接
+func InitDB() error {
+ connStr := "postgres://postgres:postgres@localhost/otp_db?sslmode=disable"
+ var err error
+ db, err = sql.Open("postgres", connStr)
+ if err != nil {
+ return fmt.Errorf("error opening database: %v", err)
+ }
+
+ if err = db.Ping(); err != nil {
+ return fmt.Errorf("error connecting to the database: %v", err)
+ }
+
+ return nil
+}
+
+// SaveHandler 保存token的接口处理函数
+func SaveHandler(w http.ResponseWriter, r *http.Request) {
+ // 设置CORS头
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
+
+ // 处理OPTIONS请求
+ if r.Method == "OPTIONS" {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ // 检查请求方法
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // 解析请求
+ var req SaveRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ log.Printf("Error decoding request: %v", err)
+ sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
+ return
+ }
+
+ // 验证请求数据
+ if req.UserID == "" {
+ sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
+ return
+ }
+
+ if len(req.Tokens) == 0 {
+ sendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
+ return
+ }
+
+ // 开始数据库事务
+ tx, err := db.Begin()
+ if err != nil {
+ log.Printf("Error starting transaction: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+ defer tx.Rollback()
+
+ // 删除用户现有的tokens
+ _, err = tx.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID)
+ if err != nil {
+ log.Printf("Error deleting existing tokens: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+
+ // 插入新的tokens
+ stmt, err := tx.Prepare(`
+ INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
+ `)
+ if err != nil {
+ log.Printf("Error preparing statement: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+ defer stmt.Close()
+
+ for _, token := range req.Tokens {
+ // 加密secret
+ encryptedSecret, err := encryptTokenSecret(token.Secret)
+ if err != nil {
+ log.Printf("Error encrypting token secret: %v", err)
+ sendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
+ return
+ }
+
+ _, err = stmt.Exec(
+ token.ID,
+ req.UserID,
+ token.Issuer,
+ token.Account,
+ encryptedSecret,
+ token.Type,
+ token.Counter,
+ token.Period,
+ token.Digits,
+ token.Algo,
+ req.Timestamp,
+ )
+ if err != nil {
+ log.Printf("Error inserting token: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+ }
+
+ // 提交事务
+ if err = tx.Commit(); err != nil {
+ log.Printf("Error committing transaction: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+
+ // 返回成功响应
+ resp := SaveResponse{
+ Success: true,
+ Message: "Tokens saved successfully",
+ }
+ resp.Data.ID = req.UserID
+
+ sendJSONResponse(w, resp, http.StatusOK)
+}
+
+// RecoverHandler 恢复token的接口处理函数
+func RecoverHandler(w http.ResponseWriter, r *http.Request) {
+ // 设置CORS头
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
+
+ // 处理OPTIONS请求
+ if r.Method == "OPTIONS" {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ // 检查请求方法
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // 解析请求体
+ var req struct {
+ UserID string `json:"userId"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ log.Printf("Error decoding request: %v", err)
+ sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
+ return
+ }
+
+ // 验证用户ID
+ if req.UserID == "" {
+ sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
+ return
+ }
+
+ // 查询数据库
+ rows, err := db.Query(`
+ SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp
+ FROM tokens
+ WHERE user_id = $1
+ ORDER BY timestamp DESC
+ `, req.UserID)
+ if err != nil {
+ log.Printf("Error querying database: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+ defer rows.Close()
+
+ // 读取查询结果
+ var tokens []TokenData
+ var timestamp int64
+ for rows.Next() {
+ var token TokenData
+ var encryptedSecret string
+ err := rows.Scan(
+ &token.ID,
+ &token.Issuer,
+ &token.Account,
+ &encryptedSecret,
+ &token.Type,
+ &token.Counter,
+ &token.Period,
+ &token.Digits,
+ &token.Algo,
+ ×tamp,
+ )
+ if err != nil {
+ log.Printf("Error scanning row: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+
+ // 解密secret
+ token.Secret, err = decryptTokenSecret(encryptedSecret)
+ if err != nil {
+ log.Printf("Error decrypting token secret: %v", err)
+ sendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
+ return
+ }
+ tokens = append(tokens, token)
+ }
+
+ if err = rows.Err(); err != nil {
+ log.Printf("Error iterating rows: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+
+ // 返回响应
+ resp := RecoverResponse{
+ Success: true,
+ Message: "Tokens recovered successfully",
+ }
+ resp.Data.Tokens = tokens
+ resp.Data.Timestamp = timestamp
+
+ sendJSONResponse(w, resp, http.StatusOK)
+}
+
+// sendErrorResponse 发送错误响应
+func sendErrorResponse(w http.ResponseWriter, message string, status int) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(status)
+ json.NewEncoder(w).Encode(map[string]interface{}{
+ "success": false,
+ "message": message,
+ })
+}
+
+// DeleteTokenHandler 删除单个token的接口处理函数
+func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
+ // 设置CORS头
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
+
+ // 处理OPTIONS请求
+ if r.Method == "OPTIONS" {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ // 检查请求方法
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // 解析请求
+ var req struct {
+ UserID string `json:"userId"`
+ TokenID string `json:"tokenId"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ log.Printf("Error decoding request: %v", err)
+ sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
+ return
+ }
+
+ // 验证请求数据
+ if req.UserID == "" || req.TokenID == "" {
+ sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
+ return
+ }
+
+ // 执行删除操作
+ result, err := db.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
+ if err != nil {
+ log.Printf("Error deleting token: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+
+ // 检查是否真的删除了记录
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ log.Printf("Error getting rows affected: %v", err)
+ sendErrorResponse(w, "Database error", http.StatusInternalServerError)
+ return
+ }
+
+ if rowsAffected == 0 {
+ sendErrorResponse(w, "Token not found", http.StatusNotFound)
+ return
+ }
+
+ // 返回成功响应
+ sendJSONResponse(w, map[string]interface{}{
+ "success": true,
+ "message": "Token deleted successfully",
+ }, http.StatusOK)
+}
+
+// sendJSONResponse 发送JSON响应
+func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(status)
+ if err := json.NewEncoder(w).Encode(data); err != nil {
+ log.Printf("Error encoding response: %v", err)
+ sendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
diff --git a/otp_api_test.go b/otp_api_test.go
new file mode 100644
index 0000000..7c366d9
--- /dev/null
+++ b/otp_api_test.go
@@ -0,0 +1,111 @@
+package main
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestSaveHandler(t *testing.T) {
+ // 创建测试服务器
+ srv := httptest.NewServer(http.HandlerFunc(SaveHandler))
+ defer srv.Close()
+
+ // 准备测试数据
+ testData := SaveRequest{
+ UserID: "test_user_123",
+ Tokens: []TokenData{
+ {
+ Issuer: "TestOrg",
+ Account: "user@test.com",
+ Secret: "JBSWY3DPEHPK3PXP",
+ Type: "totp",
+ Period: 30,
+ Digits: 6,
+ Algo: "SHA1",
+ },
+ },
+ }
+
+ // 序列化请求体
+ body, _ := json.Marshal(testData)
+
+ // 发送请求
+ resp, err := http.Post(srv.URL, "application/json", bytes.NewBuffer(body))
+ if err != nil {
+ t.Fatalf("Error making request to server: %v\n", err)
+ }
+ defer resp.Body.Close()
+
+ // 检查响应状态码
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("Expected status: %d, got: %d\n", http.StatusOK, resp.StatusCode)
+ }
+
+ // 解析响应
+ var saveResp SaveResponse
+ err = json.NewDecoder(resp.Body).Decode(&saveResp)
+ if err != nil {
+ t.Errorf("Error decoding response: %v\n", err)
+ }
+
+ // 验证响应数据
+ if !saveResp.Success {
+ t.Errorf("Expected success to be true, got false\n")
+ }
+ if saveResp.Message != "Tokens saved successfully" {
+ t.Errorf("Expected message to be 'Tokens saved successfully', got '%s'\n", saveResp.Message)
+ }
+}
+
+func TestRecoverHandler(t *testing.T) {
+ // 创建测试服务器
+ srv := httptest.NewServer(http.HandlerFunc(RecoverHandler))
+ defer srv.Close()
+
+ // 发送请求(没有user_id参数)
+ resp, err := http.Get(srv.URL)
+ if err != nil {
+ t.Fatalf("Error making request to server: %v\n", err)
+ }
+ defer resp.Body.Close()
+
+ // 检查响应状态码(应该返回错误)
+ if resp.StatusCode != http.StatusBadRequest {
+ t.Errorf("Expected status: %d, got: %d\n", http.StatusBadRequest, resp.StatusCode)
+ }
+
+ // 发送带user_id的请求
+ urlWithID := fmt.Sprintf("%s?user_id=test_user_123", srv.URL)
+ respWithID, err := http.Get(urlWithID)
+ if err != nil {
+ t.Fatalf("Error making request to server: %v\n", err)
+ }
+ defer respWithID.Body.Close()
+
+ // 检查响应状态码
+ if respWithID.StatusCode != http.StatusOK {
+ t.Errorf("Expected status: %d, got: %d\n", http.StatusOK, respWithID.StatusCode)
+ }
+
+ // 解析响应
+ var recoverResp RecoverResponse
+ err = json.NewDecoder(respWithID.Body).Decode(&recoverResp)
+ if err != nil {
+ t.Errorf("Error decoding response: %v\n", err)
+ }
+
+ // 验证响应数据
+ if !recoverResp.Success {
+ t.Errorf("Expected success to be true, got false\n")
+ }
+ if recoverResp.Message != "Tokens recovered successfully" {
+ t.Errorf("Expected message to be 'Tokens recovered successfully', got '%s'\n", recoverResp.Message)
+ }
+ if len(recoverResp.Tokens) != 1 {
+ t.Errorf("Expected 1 token, got %d\n", len(recoverResp.Tokens))
+ }
+}
diff --git a/security/security.go b/security/security.go
deleted file mode 100644
index 7b69b81..0000000
--- a/security/security.go
+++ /dev/null
@@ -1,332 +0,0 @@
-package security
-
-import (
- "context"
- "crypto/rand"
- "crypto/subtle"
- "encoding/base64"
- "fmt"
- "net/http"
- "strings"
- "time"
-
- "golang.org/x/crypto/argon2"
-)
-
-// SecurityService provides security functionality
-type SecurityService struct {
- config *Config
-}
-
-// Config represents security configuration
-type Config struct {
- // CSRF protection
- CSRFTokenLength int
- CSRFTokenExpiry time.Duration
- CSRFCookieName string
- CSRFHeaderName string
- CSRFCookieSecure bool
- CSRFCookieHTTPOnly bool
- CSRFCookieSameSite http.SameSite
-
- // Rate limiting
- RateLimitRequests int
- RateLimitWindow time.Duration
-
- // Password hashing
- Argon2Time uint32
- Argon2Memory uint32
- Argon2Threads uint8
- Argon2KeyLen uint32
- Argon2SaltLen uint32
-}
-
-// DefaultConfig returns the default security configuration
-func DefaultConfig() *Config {
- return &Config{
- // CSRF protection
- CSRFTokenLength: 32,
- CSRFTokenExpiry: 24 * time.Hour,
- CSRFCookieName: "csrf_token",
- CSRFHeaderName: "X-CSRF-Token",
- CSRFCookieSecure: true,
- CSRFCookieHTTPOnly: true,
- CSRFCookieSameSite: http.SameSiteStrictMode,
-
- // Rate limiting
- RateLimitRequests: 100,
- RateLimitWindow: time.Minute,
-
- // Password hashing
- Argon2Time: 1,
- Argon2Memory: 64 * 1024,
- Argon2Threads: 4,
- Argon2KeyLen: 32,
- Argon2SaltLen: 16,
- }
-}
-
-// NewSecurityService creates a new SecurityService
-func NewSecurityService(config *Config) *SecurityService {
- if config == nil {
- config = DefaultConfig()
- }
- return &SecurityService{
- config: config,
- }
-}
-
-// GenerateCSRFToken generates a CSRF token
-func (s *SecurityService) GenerateCSRFToken() (string, error) {
- // Generate random bytes
- bytes := make([]byte, s.config.CSRFTokenLength)
- if _, err := rand.Read(bytes); err != nil {
- return "", fmt.Errorf("failed to generate random bytes: %w", err)
- }
-
- // Encode as base64
- token := base64.StdEncoding.EncodeToString(bytes)
- return token, nil
-}
-
-// SetCSRFCookie sets a CSRF cookie
-func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
- http.SetCookie(w, &http.Cookie{
- Name: s.config.CSRFCookieName,
- Value: token,
- Path: "/",
- Expires: time.Now().Add(s.config.CSRFTokenExpiry),
- Secure: s.config.CSRFCookieSecure,
- HttpOnly: s.config.CSRFCookieHTTPOnly,
- SameSite: s.config.CSRFCookieSameSite,
- })
-}
-
-// ValidateCSRFToken validates a CSRF token
-func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
- // Get token from cookie
- cookie, err := r.Cookie(s.config.CSRFCookieName)
- if err != nil {
- return false
- }
- cookieToken := cookie.Value
-
- // Get token from header
- headerToken := r.Header.Get(s.config.CSRFHeaderName)
-
- // Compare tokens
- return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1
-}
-
-// CSRFMiddleware creates a middleware that validates CSRF tokens
-func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Skip for GET, HEAD, OPTIONS, TRACE
- if r.Method == http.MethodGet ||
- r.Method == http.MethodHead ||
- r.Method == http.MethodOptions ||
- r.Method == http.MethodTrace {
- next.ServeHTTP(w, r)
- return
- }
-
- // Validate CSRF token
- if !s.ValidateCSRFToken(r) {
- http.Error(w, "Invalid CSRF token", http.StatusForbidden)
- return
- }
-
- next.ServeHTTP(w, r)
- })
-}
-
-// RateLimiter represents a rate limiter
-type RateLimiter struct {
- requests map[string][]time.Time
- config *Config
-}
-
-// NewRateLimiter creates a new RateLimiter
-func NewRateLimiter(config *Config) *RateLimiter {
- return &RateLimiter{
- requests: make(map[string][]time.Time),
- config: config,
- }
-}
-
-// Allow checks if a request is allowed
-func (r *RateLimiter) Allow(key string) bool {
- now := time.Now()
- windowStart := now.Add(-r.config.RateLimitWindow)
-
- // Get requests for key
- requests := r.requests[key]
-
- // Filter out old requests
- var newRequests []time.Time
- for _, t := range requests {
- if t.After(windowStart) {
- newRequests = append(newRequests, t)
- }
- }
-
- // Check if rate limit is exceeded
- if len(newRequests) >= r.config.RateLimitRequests {
- return false
- }
-
- // Add current request
- newRequests = append(newRequests, now)
- r.requests[key] = newRequests
-
- return true
-}
-
-// RateLimitMiddleware creates a middleware that limits request rate
-func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Get client IP
- ip := getClientIP(r)
-
- // Check if request is allowed
- if !limiter.Allow(ip) {
- http.Error(w, "Too many requests", http.StatusTooManyRequests)
- return
- }
-
- next.ServeHTTP(w, r)
- })
- }
-}
-
-// getClientIP gets the client IP address
-func getClientIP(r *http.Request) string {
- // Check X-Forwarded-For header
- xForwardedFor := r.Header.Get("X-Forwarded-For")
- if xForwardedFor != "" {
- // X-Forwarded-For can contain multiple IPs, use the first one
- ips := strings.Split(xForwardedFor, ",")
- return strings.TrimSpace(ips[0])
- }
-
- // Check X-Real-IP header
- xRealIP := r.Header.Get("X-Real-IP")
- if xRealIP != "" {
- return xRealIP
- }
-
- // Use RemoteAddr
- return r.RemoteAddr
-}
-
-// HashPassword hashes a password using Argon2
-func (s *SecurityService) HashPassword(password string) (string, error) {
- // Generate salt
- salt := make([]byte, s.config.Argon2SaltLen)
- if _, err := rand.Read(salt); err != nil {
- return "", fmt.Errorf("failed to generate salt: %w", err)
- }
-
- // Hash password
- hash := argon2.IDKey(
- []byte(password),
- salt,
- s.config.Argon2Time,
- s.config.Argon2Memory,
- s.config.Argon2Threads,
- s.config.Argon2KeyLen,
- )
-
- // Encode as base64
- saltBase64 := base64.StdEncoding.EncodeToString(salt)
- hashBase64 := base64.StdEncoding.EncodeToString(hash)
-
- // Format as $argon2id$v=19$m=65536,t=1,p=4$$
- return fmt.Sprintf(
- "$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
- s.config.Argon2Memory,
- s.config.Argon2Time,
- s.config.Argon2Threads,
- saltBase64,
- hashBase64,
- ), nil
-}
-
-// VerifyPassword verifies a password against a hash
-func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
- // Parse encoded hash
- parts := strings.Split(encodedHash, "$")
- if len(parts) != 6 {
- return false, fmt.Errorf("invalid hash format")
- }
-
- // Extract parameters
- if parts[1] != "argon2id" {
- return false, fmt.Errorf("unsupported hash algorithm")
- }
-
- var memory uint32
- var time uint32
- var threads uint8
- _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
- if err != nil {
- return false, fmt.Errorf("failed to parse hash parameters: %w", err)
- }
-
- // Decode salt and hash
- salt, err := base64.StdEncoding.DecodeString(parts[4])
- if err != nil {
- return false, fmt.Errorf("failed to decode salt: %w", err)
- }
-
- hash, err := base64.StdEncoding.DecodeString(parts[5])
- if err != nil {
- return false, fmt.Errorf("failed to decode hash: %w", err)
- }
-
- // Hash password with same parameters
- newHash := argon2.IDKey(
- []byte(password),
- salt,
- time,
- memory,
- threads,
- uint32(len(hash)),
- )
-
- // Compare hashes
- return subtle.ConstantTimeCompare(hash, newHash) == 1, nil
-}
-
-// SecureHeadersMiddleware adds security headers to responses
-func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Add security headers
- w.Header().Set("X-Content-Type-Options", "nosniff")
- w.Header().Set("X-Frame-Options", "DENY")
- w.Header().Set("X-XSS-Protection", "1; mode=block")
- w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
- w.Header().Set("Content-Security-Policy", "default-src 'self'")
- w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
-
- next.ServeHTTP(w, r)
- })
-}
-
-// contextKey is a type for context keys
-type contextKey string
-
-// userIDKey is the key for user ID in context
-const userIDKey = contextKey("user_id")
-
-// WithUserID adds a user ID to the context
-func WithUserID(ctx context.Context, userID string) context.Context {
- return context.WithValue(ctx, userIDKey, userID)
-}
-
-// GetUserID gets the user ID from the context
-func GetUserID(ctx context.Context) (string, bool) {
- userID, ok := ctx.Value(userIDKey).(string)
- return userID, ok
-}
diff --git a/server/server.go b/server/server.go
deleted file mode 100644
index a10e990..0000000
--- a/server/server.go
+++ /dev/null
@@ -1,190 +0,0 @@
-package server
-
-import (
- "context"
- "fmt"
- "log"
- "net/http"
- "os"
- "os/signal"
- "runtime"
- "syscall"
- "time"
-
- "otpm/config"
- "otpm/middleware"
-
- "github.com/julienschmidt/httprouter"
-)
-
-// Server represents the HTTP server
-type Server struct {
- server *http.Server
- router *httprouter.Router
- config *config.Config
-}
-
-// New creates a new server
-func New(cfg *config.Config) *Server {
- router := httprouter.New()
-
- server := &http.Server{
- Addr: fmt.Sprintf(":%d", cfg.Server.Port),
- Handler: router,
- ReadTimeout: cfg.Server.ReadTimeout,
- WriteTimeout: cfg.Server.WriteTimeout,
- IdleTimeout: 120 * time.Second,
- }
-
- return &Server{
- server: server,
- router: router,
- config: cfg,
- }
-}
-
-// Start starts the server
-func (s *Server) Start() error {
- // Apply global middleware in correct order with enhanced error handling
- var handler http.Handler = s.router
-
- // Logger should be first to capture all request details
- handler = middleware.Logger(handler)
-
- // CORS next to handle pre-flight requests
- handler = middleware.CORS(handler)
-
- // Then Timeout to enforce request deadlines
- handler = middleware.Timeout(s.config.Server.Timeout)(handler)
-
- // Recover should be outermost to catch any panics
- handler = middleware.Recover(handler)
-
- s.server.Handler = handler
-
- // Log server configuration at startup
- log.Printf("Server configuration:\n"+
- "Address: %s\n"+
- "Read Timeout: %v\n"+
- "Write Timeout: %v\n"+
- "Idle Timeout: %v\n"+
- "Request Timeout: %v",
- s.server.Addr,
- s.server.ReadTimeout,
- s.server.WriteTimeout,
- s.server.IdleTimeout,
- s.config.Server.Timeout,
- )
-
- // Start server in a goroutine
- serverErr := make(chan error, 1)
- go func() {
- log.Printf("Server starting on %s", s.server.Addr)
- if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- serverErr <- fmt.Errorf("server error: %w", err)
- }
- }()
-
- // Wait for interrupt signal or server error
- quit := make(chan os.Signal, 1)
- signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
-
- select {
- case err := <-serverErr:
- return err
- case <-quit:
- return s.Shutdown()
- }
-}
-
-// Shutdown gracefully stops the server
-func (s *Server) Shutdown() error {
- log.Println("Shutting down server...")
-
- ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout)
- defer cancel()
-
- if err := s.server.Shutdown(ctx); err != nil {
- return fmt.Errorf("graceful shutdown failed: %w", err)
- }
-
- log.Println("Server stopped gracefully")
- return nil
-}
-
-// Router returns the router
-func (s *Server) Router() *httprouter.Router {
- return s.router
-}
-
-// RegisterRoutes registers all routes
-func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) {
- for pattern, handler := range routes {
- s.router.Handle("GET", pattern, handler)
- s.router.Handle("POST", pattern, handler)
- s.router.Handle("PUT", pattern, handler)
- s.router.Handle("DELETE", pattern, handler)
- }
-}
-
-// RegisterAuthRoutes registers routes that require authentication
-func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) {
- for pattern, handler := range routes {
- // Apply authentication middleware
- authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
- // Convert httprouter.Handle to http.HandlerFunc for middleware
- wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Store params in request context
- ctx := context.WithValue(r.Context(), "params", ps)
- handler(w, r.WithContext(ctx), ps)
- })
-
- // Apply auth middleware
- middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r)
- }
-
- s.router.Handle("GET", pattern, authHandler)
- s.router.Handle("POST", pattern, authHandler)
- s.router.Handle("PUT", pattern, authHandler)
- s.router.Handle("DELETE", pattern, authHandler)
- }
-}
-
-// RegisterHealthCheck registers an enhanced health check endpoint
-func (s *Server) RegisterHealthCheck() {
- s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
- response := map[string]interface{}{
- "status": "ok",
- "timestamp": time.Now().Format(time.RFC3339),
- "version": "1.0.0", // Hardcoded version instead of from config
- "system": map[string]interface{}{
- "goroutines": runtime.NumGoroutine(),
- "memory": getMemoryUsage(),
- },
- }
-
- // Add database status if configured
- if s.config.Database.DSN != "" {
- dbStatus := "ok"
- response["database"] = dbStatus
- }
-
- middleware.SuccessResponse(w, response)
- })
-}
-
-// getMemoryUsage returns current memory usage in MB
-func getMemoryUsage() map[string]interface{} {
- var m runtime.MemStats
- runtime.ReadMemStats(&m)
- return map[string]interface{}{
- "alloc_mb": bToMb(m.Alloc),
- "total_alloc_mb": bToMb(m.TotalAlloc),
- "sys_mb": bToMb(m.Sys),
- "num_gc": m.NumGC,
- }
-}
-
-func bToMb(b uint64) float64 {
- return float64(b) / 1024 / 1024
-}
diff --git a/services/auth.go b/services/auth.go
deleted file mode 100644
index 9ad2a85..0000000
--- a/services/auth.go
+++ /dev/null
@@ -1,230 +0,0 @@
-package services
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "log"
- "net/http"
- "time"
-
- "github.com/golang-jwt/jwt"
- "github.com/google/uuid"
-
- "otpm/config"
- "otpm/models"
-)
-
-// WeChatCode2SessionResponse represents the response from WeChat code2session API
-type WeChatCode2SessionResponse struct {
- OpenID string `json:"openid"`
- SessionKey string `json:"session_key"`
- UnionID string `json:"unionid"`
- ErrCode int `json:"errcode"`
- ErrMsg string `json:"errmsg"`
-}
-
-// AuthService handles authentication related operations
-type AuthService struct {
- config *config.Config
- userRepo *models.UserRepository
- httpClient *http.Client
-}
-
-// NewAuthService creates a new AuthService
-func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
- return &AuthService{
- config: cfg,
- userRepo: userRepo,
- httpClient: &http.Client{
- Timeout: 10 * time.Second,
- },
- }
-}
-
-// LoginWithWeChatCode handles WeChat login
-func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
- start := time.Now()
-
- // Get OpenID and SessionKey from WeChat
- sessionInfo, err := s.getWeChatSession(code)
- if err != nil {
- log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
- return "", fmt.Errorf("failed to get WeChat session: %w", err)
- }
- log.Printf("WeChat session obtained for code %s (took %v)",
- maskCode(code), time.Since(start))
-
- // Find or create user
- user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
- if err != nil {
- log.Printf("User lookup failed for OpenID %s: %v",
- maskOpenID(sessionInfo.OpenID), err)
- return "", fmt.Errorf("failed to find user: %w", err)
- }
-
- if user == nil {
- // Create new user
- user = &models.User{
- ID: uuid.New().String(),
- OpenID: sessionInfo.OpenID,
- SessionKey: sessionInfo.SessionKey,
- }
- if err := s.userRepo.Create(ctx, user); err != nil {
- log.Printf("User creation failed for OpenID %s: %v",
- maskOpenID(sessionInfo.OpenID), err)
- return "", fmt.Errorf("failed to create user: %w", err)
- }
- log.Printf("New user created with ID %s for OpenID %s",
- user.ID, maskOpenID(sessionInfo.OpenID))
- } else {
- // Update session key
- user.SessionKey = sessionInfo.SessionKey
- if err := s.userRepo.Update(ctx, user); err != nil {
- log.Printf("User update failed for ID %s: %v", user.ID, err)
- return "", fmt.Errorf("failed to update user: %w", err)
- }
- log.Printf("User %s session key updated", user.ID)
- }
-
- // Generate JWT token
- token, err := s.generateToken(user)
- if err != nil {
- log.Printf("Token generation failed for user %s: %v", user.ID, err)
- return "", fmt.Errorf("failed to generate token: %w", err)
- }
-
- log.Printf("WeChat login completed for user %s (total time %v)",
- user.ID, time.Since(start))
- return token, nil
-}
-
-// maskCode masks sensitive parts of WeChat code for logging
-func maskCode(code string) string {
- if len(code) < 8 {
- return "****"
- }
- return code[:2] + "****" + code[len(code)-2:]
-}
-
-// maskOpenID masks sensitive parts of OpenID for logging
-func maskOpenID(openID string) string {
- if len(openID) < 8 {
- return "****"
- }
- return openID[:2] + "****" + openID[len(openID)-2:]
-}
-
-// getWeChatSession calls WeChat's code2session API
-func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
- url := fmt.Sprintf(
- "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
- s.config.WeChat.AppID,
- s.config.WeChat.AppSecret,
- code,
- )
-
- resp, err := s.httpClient.Get(url)
- if err != nil {
- return nil, fmt.Errorf("failed to call WeChat API: %w", err)
- }
- defer resp.Body.Close()
-
- var result WeChatCode2SessionResponse
- if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
- return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
- }
-
- if result.ErrCode != 0 {
- return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
- }
-
- return &result, nil
-}
-
-// generateToken generates a JWT token for a user
-func (s *AuthService) generateToken(user *models.User) (string, error) {
- now := time.Now()
- claims := jwt.MapClaims{
- "user_id": user.ID,
- "exp": now.Add(s.config.JWT.ExpireDelta).Unix(),
- "iat": now.Unix(),
- "iss": s.config.JWT.Issuer,
- "aud": s.config.JWT.Audience,
- "token_id": uuid.New().String(), // Unique token ID for tracking
- }
-
- // Use stronger signing method
- token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
- signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
- if err != nil {
- return "", fmt.Errorf("failed to sign token: %w", err)
- }
-
- log.Printf("Token generated for user %s (expires at %v)",
- user.ID, now.Add(s.config.JWT.ExpireDelta))
- return signedToken, nil
-}
-
-// ValidateToken validates a JWT token with additional checks
-func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
- token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
- // Verify signing method
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
- }
- return []byte(s.config.JWT.Secret), nil
- })
-
- if err != nil {
- if ve, ok := err.(*jwt.ValidationError); ok {
- switch {
- case ve.Errors&jwt.ValidationErrorMalformed != 0:
- return nil, fmt.Errorf("malformed token")
- case ve.Errors&jwt.ValidationErrorExpired != 0:
- return nil, fmt.Errorf("token expired")
- case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
- return nil, fmt.Errorf("token not active yet")
- default:
- return nil, fmt.Errorf("token validation error: %w", err)
- }
- }
- return nil, fmt.Errorf("failed to parse token: %w", err)
- }
-
- // Additional claims validation
- if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
- // Check issuer
- if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
- return nil, fmt.Errorf("invalid token issuer")
- }
- // Check audience
- if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
- return nil, fmt.Errorf("invalid token audience")
- }
- } else {
- return nil, fmt.Errorf("invalid token claims")
- }
-
- return token, nil
-}
-
-// GetUserFromToken gets user information from a JWT token
-func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
- claims, ok := token.Claims.(jwt.MapClaims)
- if !ok {
- return nil, fmt.Errorf("invalid token claims")
- }
-
- userID, ok := claims["user_id"].(string)
- if !ok {
- return nil, fmt.Errorf("user_id not found in token")
- }
-
- user, err := s.userRepo.FindByID(ctx, userID)
- if err != nil {
- return nil, fmt.Errorf("failed to find user: %w", err)
- }
-
- return user, nil
-}
diff --git a/services/otp.go b/services/otp.go
deleted file mode 100644
index c345bea..0000000
--- a/services/otp.go
+++ /dev/null
@@ -1,358 +0,0 @@
-package services
-
-import (
- "context"
- "crypto/hmac"
- "crypto/sha1"
- "crypto/sha256"
- "crypto/sha512"
- "encoding/base32"
- "encoding/binary"
- "fmt"
- "hash"
- "log"
- "strings"
- "time"
-
- "otpm/models"
-
- "github.com/google/uuid"
-)
-
-// OTPService handles OTP related operations
-type OTPService struct {
- otpRepo *models.OTPRepository
-}
-
-// NewOTPService creates a new OTPService
-func NewOTPService(otpRepo *models.OTPRepository) *OTPService {
- return &OTPService{
- otpRepo: otpRepo,
- }
-}
-
-// CreateOTP creates a new OTP with performance monitoring and logging
-func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) {
- start := time.Now()
-
- // Validate input
- if err := s.validateOTPInput(input); err != nil {
- log.Printf("OTP validation failed for user %s: %v", userID, err)
- return nil, err
- }
-
- // Clean and standardize secret
- secret := cleanSecret(input.Secret)
-
- // Set defaults for optional fields
- algorithm := strings.ToUpper(input.Algorithm)
- if algorithm == "" {
- algorithm = "SHA1"
- }
-
- digits := input.Digits
- if digits == 0 {
- digits = 6
- }
-
- period := input.Period
- if period == 0 {
- period = 30
- }
-
- // Create OTP
- otp := &models.OTP{
- ID: uuid.New().String(),
- UserID: userID,
- Name: input.Name,
- Issuer: input.Issuer,
- Secret: secret,
- Algorithm: algorithm,
- Digits: digits,
- Period: period,
- }
-
- if err := s.otpRepo.Create(ctx, otp); err != nil {
- log.Printf("Failed to create OTP for user %s: %v", userID, err)
- return nil, fmt.Errorf("failed to create OTP: %w", err)
- }
-
- // Log successful creation (without exposing secret)
- log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)",
- otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period)
-
- return otp, nil
-}
-
-// GetOTP gets an OTP by ID
-func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) {
- otp, err := s.otpRepo.FindByID(ctx, id, userID)
- if err != nil {
- return nil, fmt.Errorf("failed to get OTP: %w", err)
- }
- return otp, nil
-}
-
-// ListOTPs lists all OTPs for a user
-func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) {
- otps, err := s.otpRepo.FindAllByUserID(ctx, userID)
- if err != nil {
- return nil, fmt.Errorf("failed to list OTPs: %w", err)
- }
- return otps, nil
-}
-
-// UpdateOTP updates an OTP
-func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) {
- // Get existing OTP
- otp, err := s.otpRepo.FindByID(ctx, id, userID)
- if err != nil {
- return nil, fmt.Errorf("failed to get OTP: %w", err)
- }
-
- // Update fields
- if input.Name != "" {
- otp.Name = input.Name
- }
- if input.Issuer != "" {
- otp.Issuer = input.Issuer
- }
- if input.Algorithm != "" {
- otp.Algorithm = strings.ToUpper(input.Algorithm)
- }
- if input.Digits > 0 {
- otp.Digits = input.Digits
- }
- if input.Period > 0 {
- otp.Period = input.Period
- }
-
- // Validate updated OTP
- if err := s.validateOTPInput(models.OTPParams{
- Name: otp.Name,
- Issuer: otp.Issuer,
- Secret: otp.Secret,
- Algorithm: otp.Algorithm,
- Digits: otp.Digits,
- Period: otp.Period,
- }); err != nil {
- return nil, err
- }
-
- if err := s.otpRepo.Update(ctx, otp); err != nil {
- return nil, fmt.Errorf("failed to update OTP: %w", err)
- }
-
- return otp, nil
-}
-
-// DeleteOTP deletes an OTP
-func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error {
- if err := s.otpRepo.Delete(ctx, id, userID); err != nil {
- return fmt.Errorf("failed to delete OTP: %w", err)
- }
- return nil
-}
-
-// GenerateCode generates a TOTP code with enhanced logging and error handling
-func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) {
- start := time.Now()
-
- otp, err := s.otpRepo.FindByID(ctx, id, userID)
- if err != nil {
- log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err)
- return "", 0, fmt.Errorf("failed to get OTP: %w", err)
- }
-
- // Get current time step
- now := time.Now().Unix()
- timeStep := now / int64(otp.Period)
-
- // Generate code
- code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits)
- if err != nil {
- log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err)
- return "", 0, fmt.Errorf("failed to generate code: %w", err)
- }
-
- // Calculate remaining seconds
- remainingSeconds := otp.Period - int(now%int64(otp.Period))
-
- // Log successful generation (without actual code)
- log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)",
- id, userID, time.Since(start), remainingSeconds)
-
- return code, remainingSeconds, nil
-}
-
-// VerifyCode verifies a TOTP code with enhanced security and logging
-func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) {
- start := time.Now()
-
- // Basic input validation
- if len(code) == 0 {
- log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID)
- return false, fmt.Errorf("code is required")
- }
-
- otp, err := s.otpRepo.FindByID(ctx, id, userID)
- if err != nil {
- log.Printf("Failed to find OTP %s for user %s during verification: %v",
- id, userID, err)
- return false, fmt.Errorf("failed to get OTP: %w", err)
- }
-
- // Get current and adjacent time steps
- now := time.Now().Unix()
- timeSteps := []int64{
- (now - int64(otp.Period)) / int64(otp.Period),
- now / int64(otp.Period),
- (now + int64(otp.Period)) / int64(otp.Period),
- }
-
- // Check code against all time steps
- for _, ts := range timeSteps {
- expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits)
- if err != nil {
- log.Printf("Code generation failed for time step %d: %v", ts, err)
- continue
- }
- if expectedCode == code {
- // Log successful verification
- log.Printf("Code verified successfully for OTP %s (user %s) in %v",
- id, userID, time.Since(start))
- return true, nil
- }
- }
-
- // Log failed verification attempt
- log.Printf("Invalid code provided for OTP %s (user %s) in %v",
- id, userID, time.Since(start))
-
- return false, nil
-}
-
-// validateOTPInput validates OTP input with detailed error messages
-func (s *OTPService) validateOTPInput(input models.OTPParams) error {
- if input.Name == "" {
- return fmt.Errorf("name is required")
- }
-
- if len(input.Name) > 100 {
- return fmt.Errorf("name is too long (maximum 100 characters)")
- }
-
- if input.Secret == "" {
- return fmt.Errorf("secret is required")
- }
-
- if !isValidBase32(input.Secret) {
- return fmt.Errorf("invalid secret format: must be a valid base32 string")
- }
-
- // Secret length check (after base32 decoding)
- secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "="))
- if len(secretBytes) < 10 {
- return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)")
- }
-
- if input.Algorithm != "" {
- if !isValidAlgorithm(input.Algorithm) {
- return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm)
- }
- }
-
- if input.Digits != 0 {
- if input.Digits < 6 || input.Digits > 8 {
- return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits)
- }
- }
-
- if input.Period != 0 {
- if input.Period < 30 || input.Period > 60 {
- return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period)
- }
- }
-
- return nil
-}
-
-// Helper functions
-
-func cleanSecret(secret string) string {
- // Remove spaces and convert to upper case
- secret = strings.TrimSpace(strings.ToUpper(secret))
- // Remove any padding characters
- return strings.TrimRight(secret, "=")
-}
-
-func isValidBase32(s string) bool {
- // Try to decode the secret
- _, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "="))
- return err == nil
-}
-
-func isValidAlgorithm(algorithm string) bool {
- switch strings.ToUpper(algorithm) {
- case "SHA1", "SHA256", "SHA512":
- return true
- default:
- return false
- }
-}
-
-func getHasher(algorithm string, key []byte) (hash.Hash, error) {
- switch strings.ToUpper(algorithm) {
- case "SHA1":
- return hmac.New(sha1.New, key), nil
- case "SHA256":
- return hmac.New(sha256.New, key), nil
- case "SHA512":
- return hmac.New(sha512.New, key), nil
- default:
- return nil, fmt.Errorf("unsupported algorithm: %s", algorithm)
- }
-}
-
-func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) {
- // Decode secret
- secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "="))
- if err != nil {
- return "", fmt.Errorf("invalid secret: %w", err)
- }
-
- // Get initialized HMAC hasher with secret
- hasher, err := getHasher(algorithm, secretBytes)
- if err != nil {
- return "", err
- }
-
- // Convert time step to bytes
- timeBytes := make([]byte, 8)
- binary.BigEndian.PutUint64(timeBytes, uint64(timeStep))
-
- // Calculate HMAC
- hasher.Write(timeBytes)
- hash := hasher.Sum(nil)
-
- // Get offset
- offset := hash[len(hash)-1] & 0xf
-
- // Generate 4-byte code
- code := binary.BigEndian.Uint32(hash[offset : offset+4])
- code = code & 0x7fffffff
-
- // Get the specified number of digits
- code = code % uint32(pow10(digits))
-
- // Format code with leading zeros
- return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil
-}
-
-func pow10(n int) uint32 {
- result := uint32(1)
- for i := 0; i < n; i++ {
- result *= 10
- }
- return result
-}
diff --git a/utils/utils.go b/utils/utils.go
deleted file mode 100644
index efd6713..0000000
--- a/utils/utils.go
+++ /dev/null
@@ -1,105 +0,0 @@
-package utils
-
-import (
- "context"
- "crypto/aes"
- "crypto/cipher"
- "encoding/base64"
- "fmt"
- "net/http"
- "strings"
-
- "github.com/golang-jwt/jwt"
- "github.com/julienschmidt/httprouter"
- "github.com/spf13/viper"
-)
-
-// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
-func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
- // 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数
- return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
- // 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递
- h(w, r)
- }
-}
-
-func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- authHeader := r.Header.Get("Authorization")
- if authHeader == "" {
- http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
- return
- }
-
- tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
- secret := viper.GetString("auth.secret")
-
- token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method")
- }
- return []byte(secret), nil
- })
-
- if err != nil || !token.Valid {
- http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
- return
- }
-
- claims, ok := token.Claims.(jwt.MapClaims)
- if !ok {
- http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
- return
- }
-
- type contextKey string
- // 将 openid 存入上下文
- ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
- next.ServeHTTP(w, r.WithContext(ctx))
- }
-}
-
-// AesDecrypt 函数用于AES解密
-func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
- //Base64解码
- keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
- if err != nil {
- return nil, err
- }
- ivBytes, err := base64.StdEncoding.DecodeString(iv)
- if err != nil {
- return nil, err
- }
- cryptData, err := base64.StdEncoding.DecodeString(encryptedData)
- if err != nil {
- return nil, err
- }
- origData := make([]byte, len(cryptData))
- //AES
- block, err := aes.NewCipher(keyBytes)
- if err != nil {
- return nil, err
- }
- //CBC
- mode := cipher.NewCBCDecrypter(block, ivBytes)
- //解密
- mode.CryptBlocks(origData, cryptData)
- //去除填充位
- origData = PKCS7UnPadding(origData)
- return origData, nil
-}
-
-// PKCS7UnPadding 函数用于去除PKCS7填充的密文
-func PKCS7UnPadding(plantText []byte) []byte {
- // 获取密文的长度
- length := len(plantText)
- // 如果密文长度大于0
- if length > 0 {
- // 获取最后一个字节的值,即填充的位数
- unPadding := int(plantText[length-1])
- // 返回去除填充后的密文
- return plantText[:(length - unPadding)]
- }
- // 如果密文长度为0,则返回原密文
- return plantText
-}
diff --git a/validator/validator.go b/validator/validator.go
deleted file mode 100644
index 13cff58..0000000
--- a/validator/validator.go
+++ /dev/null
@@ -1,316 +0,0 @@
-package validator
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "reflect"
- "regexp"
- "strings"
-
- "github.com/go-playground/validator/v10"
-)
-
-var (
- validate *validator.Validate
-
- // 自定义验证规则
- customValidations = map[string]validator.Func{
- "otpsecret": validateOTPSecret,
- "password": validatePassword,
- "issuer": validateIssuer,
- "otpauth_uri": validateOTPAuthURI,
- "no_xss": validateNoXSS,
- }
-
- // 常见的弱密码列表(实际使用时应该使用更完整的列表)
- commonPasswords = map[string]bool{
- "password123": true,
- "12345678": true,
- "qwerty123": true,
- "admin123": true,
- "letmein": true,
- "welcome": true,
- "password": true,
- "admin": true,
- }
-
- // 预编译的XSS检测正则表达式
- xssPatterns = []*regexp.Regexp{
- regexp.MustCompile(`(?i)`),
- regexp.MustCompile(`(?i)javascript:`),
- regexp.MustCompile(`(?i)data:text/html`),
- regexp.MustCompile(`(?i)on\w+\s*=`),
- regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
- regexp.MustCompile(`(?i)<\s*iframe`),
- regexp.MustCompile(`(?i)<\s*object`),
- regexp.MustCompile(`(?i)<\s*embed`),
- regexp.MustCompile(`(?i)<\s*style`),
- regexp.MustCompile(`(?i)<\s*form`),
- regexp.MustCompile(`(?i)<\s*applet`),
- regexp.MustCompile(`(?i)<\s*meta`),
- regexp.MustCompile(`(?i)expression\s*\(`),
- regexp.MustCompile(`(?i)url\s*\(`),
- }
-
- // 预编译的正则表达式
- base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
- issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
- otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`)
- upperRegex = regexp.MustCompile(`[A-Z]`)
- lowerRegex = regexp.MustCompile(`[a-z]`)
- numberRegex = regexp.MustCompile(`[0-9]`)
- specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
-)
-
-func init() {
- validate = validator.New()
-
- // 注册自定义验证规则
- for tag, fn := range customValidations {
- if err := validate.RegisterValidation(tag, fn); err != nil {
- panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
- }
- }
-
- // 使用json tag作为字段名
- validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
- name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
- if name == "-" {
- return ""
- }
- return name
- })
-}
-
-// ValidateRequest validates a request body against a struct
-func ValidateRequest(r *http.Request, v interface{}) error {
- if err := json.NewDecoder(r.Body).Decode(v); err != nil {
- return fmt.Errorf("invalid request body: %w", err)
- }
-
- if err := validate.Struct(v); err != nil {
- if validationErrors, ok := err.(validator.ValidationErrors); ok {
- return NewValidationError(validationErrors)
- }
- return fmt.Errorf("validation error: %w", err)
- }
-
- return nil
-}
-
-// ValidationError represents a validation error
-type ValidationError struct {
- Fields map[string]string `json:"fields"`
-}
-
-// Error implements the error interface
-func (e *ValidationError) Error() string {
- var errors []string
- for field, msg := range e.Fields {
- errors = append(errors, fmt.Sprintf("%s: %s", field, msg))
- }
- return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; "))
-}
-
-// NewValidationError creates a new ValidationError from validator.ValidationErrors
-func NewValidationError(errors validator.ValidationErrors) *ValidationError {
- fields := make(map[string]string)
- for _, err := range errors {
- fields[err.Field()] = getErrorMessage(err)
- }
- return &ValidationError{Fields: fields}
-}
-
-// getErrorMessage returns a human-readable error message for a validation error
-func getErrorMessage(err validator.FieldError) string {
- switch err.Tag() {
- case "required":
- return "此字段为必填项"
- case "email":
- return "请输入有效的电子邮件地址"
- case "min":
- if err.Type().Kind() == reflect.String {
- return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param())
- }
- return fmt.Sprintf("必须大于或等于 %s", err.Param())
- case "max":
- if err.Type().Kind() == reflect.String {
- return fmt.Sprintf("长度不能超过 %s 个字符", err.Param())
- }
- return fmt.Sprintf("必须小于或等于 %s", err.Param())
- case "len":
- return fmt.Sprintf("长度必须为 %s 个字符", err.Param())
- case "oneof":
- return fmt.Sprintf("必须是以下值之一: %s", err.Param())
- case "otpsecret":
- return "OTP密钥格式无效,必须是有效的Base32编码"
- case "password":
- return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符"
- case "issuer":
- return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
- case "otpauth_uri":
- return "OTP认证URI格式无效"
- case "no_xss":
- return "输入包含潜在的不安全内容"
- case "numeric":
- return "必须是数字"
- default:
- return fmt.Sprintf("验证失败: %s", err.Tag())
- }
-}
-
-// Custom validation functions
-
-// validateOTPSecret validates an OTP secret
-func validateOTPSecret(fl validator.FieldLevel) bool {
- secret := fl.Field().String()
-
- if secret == "" {
- return false
- }
-
- // OTP secret should be base32 encoded
- if !base32Regex.MatchString(secret) {
- return false
- }
-
- // Check length (typical OTP secrets are 16-64 characters)
- validLength := len(secret) >= 16 && len(secret) <= 128
-
- return validLength
-}
-
-// validatePassword validates a password
-func validatePassword(fl validator.FieldLevel) bool {
- password := fl.Field().String()
-
- // At least 10 characters long
- if len(password) < 10 {
- return false
- }
-
- // Check if it's a common password
- if commonPasswords[strings.ToLower(password)] {
- return false
- }
-
- // Check character types
- hasUpper := upperRegex.MatchString(password)
- hasLower := lowerRegex.MatchString(password)
- hasNumber := numberRegex.MatchString(password)
- hasSpecial := specialRegex.MatchString(password)
-
- // Ensure password has enough complexity
- complexity := 0
- if hasUpper {
- complexity++
- }
- if hasLower {
- complexity++
- }
- if hasNumber {
- complexity++
- }
- if hasSpecial {
- complexity++
- }
-
- return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial)
-}
-
-// validateIssuer validates an issuer name
-func validateIssuer(fl validator.FieldLevel) bool {
- issuer := fl.Field().String()
-
- if issuer == "" {
- return false
- }
-
- // Issuer should not contain special characters that could cause problems in URLs
- if !issuerRegex.MatchString(issuer) {
- return false
- }
-
- // Check length
- validLength := len(issuer) >= 1 && len(issuer) <= 100
-
- return validLength
-}
-
-// validateOTPAuthURI validates an otpauth URI
-func validateOTPAuthURI(fl validator.FieldLevel) bool {
- uri := fl.Field().String()
-
- if uri == "" {
- return false
- }
-
- // Basic format check for otpauth URI
- // Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
- return otpauthRegex.MatchString(uri)
-}
-
-// validateNoXSS checks if a string contains potential XSS payloads
-func validateNoXSS(fl validator.FieldLevel) bool {
- value := fl.Field().String()
-
- // 检查基本的HTML编码
- if strings.Contains(value, "") ||
- strings.Contains(value, "<") ||
- strings.Contains(value, ">") {
- return false
- }
-
- // 检查十六进制编码
- if strings.Contains(strings.ToLower(value), "\\x3c") || // <
- strings.Contains(strings.ToLower(value), "\\x3e") { // >
- return false
- }
-
- // 检查Unicode编码
- if strings.Contains(strings.ToLower(value), "\\u003c") || // <
- strings.Contains(strings.ToLower(value), "\\u003e") { // >
- return false
- }
-
- // 使用预编译的正则表达式检查XSS模式
- for _, pattern := range xssPatterns {
- if pattern.MatchString(value) {
- return false
- }
- }
-
- return true
-}
-
-// Request validation structs
-
-// LoginRequest represents a login request
-type LoginRequest struct {
- Code string `json:"code" validate:"required,len=6|len=8,numeric"`
-}
-
-// CreateOTPRequest represents a request to create an OTP
-type CreateOTPRequest struct {
- Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
- Issuer string `json:"issuer" validate:"required,issuer,no_xss"`
- Secret string `json:"secret" validate:"required,otpsecret"`
- Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
- Digits int `json:"digits" validate:"required,oneof=6 8"`
- Period int `json:"period" validate:"required,oneof=30 60"`
-}
-
-// UpdateOTPRequest represents a request to update an OTP
-type UpdateOTPRequest struct {
- Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`
- Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`
- Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
- Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
- Period int `json:"period" validate:"omitempty,oneof=30 60"`
-}
-
-// VerifyOTPRequest represents a request to verify an OTP code
-type VerifyOTPRequest struct {
- Code string `json:"code" validate:"required,len=6|len=8,numeric"`
-}