From 01b8951dd5e2055404fb50e2fcf4508e7f6c9c58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CxHuPo=E2=80=9D?= <7513325+vrocwang@users.noreply.github.com> Date: Mon, 9 Jun 2025 11:20:07 +0800 Subject: [PATCH] add branch v1 --- api/response.go | 149 ---- api/validator.go | 152 ---- api_server.go | 202 ++++++ auth/middleware.go | 105 +++ cache/cache.go | 206 ------ cmd/root.go | 144 ---- config.yaml | 53 +- config/config.go | 178 +++-- database/db.go | 212 ------ database/init/otp.sql | 26 - database/init/users.sql | 6 - database/migration.go | 160 ----- docs/swagger.go | 663 ------------------ go.mod | 50 -- go.sum | 124 ---- handlers/auth_handler.go | 156 ----- handlers/otp_handler.go | 114 --- init/postgresql/init.sql | 51 ++ init/sqlite3/init.sql | 50 ++ logger/logger.go | 204 ------ main.go | 7 - metrics/metrics.go | 193 ----- middleware/middleware.go | 353 ---------- miniprogram-example/app.js | 50 -- miniprogram-example/app.json | 23 - miniprogram-example/app.wxss | 238 ------- miniprogram-example/pages/login/login.js | 48 -- miniprogram-example/pages/login/login.json | 3 - miniprogram-example/pages/login/login.wxml | 30 - miniprogram-example/pages/login/login.wxss | 97 --- miniprogram-example/pages/otp-add/index.js | 169 ----- miniprogram-example/pages/otp-add/index.json | 3 - miniprogram-example/pages/otp-add/index.wxml | 119 ---- miniprogram-example/pages/otp-add/index.wxss | 176 ----- miniprogram-example/pages/otp-list/index.js | 213 ------ miniprogram-example/pages/otp-list/index.json | 3 - miniprogram-example/pages/otp-list/index.wxml | 59 -- miniprogram-example/pages/otp-list/index.wxss | 201 ------ miniprogram-example/project.config.json | 47 -- .../project.private.config.json | 23 - miniprogram-example/services/auth.js | 84 --- miniprogram-example/services/otp.js | 119 ---- miniprogram-example/utils/request.js | 58 -- models/otp.go | 66 -- models/user.go | 114 --- otp_api.go | 417 +++++++++++ otp_api_test.go | 111 +++ security/security.go | 332 --------- server/server.go | 190 ----- services/auth.go | 230 ------ services/otp.go | 358 ---------- utils/utils.go | 105 --- validator/validator.go | 316 --------- 53 files changed, 1079 insertions(+), 6481 deletions(-) delete mode 100644 api/response.go delete mode 100644 api/validator.go create mode 100644 api_server.go create mode 100644 auth/middleware.go delete mode 100644 cache/cache.go delete mode 100644 cmd/root.go delete mode 100644 database/db.go delete mode 100644 database/init/otp.sql delete mode 100644 database/init/users.sql delete mode 100644 database/migration.go delete mode 100644 docs/swagger.go delete mode 100644 go.mod delete mode 100644 go.sum delete mode 100644 handlers/auth_handler.go delete mode 100644 handlers/otp_handler.go create mode 100644 init/postgresql/init.sql create mode 100644 init/sqlite3/init.sql delete mode 100644 logger/logger.go delete mode 100644 main.go delete mode 100644 metrics/metrics.go delete mode 100644 middleware/middleware.go delete mode 100644 miniprogram-example/app.js delete mode 100644 miniprogram-example/app.json delete mode 100644 miniprogram-example/app.wxss delete mode 100644 miniprogram-example/pages/login/login.js delete mode 100644 miniprogram-example/pages/login/login.json delete mode 100644 miniprogram-example/pages/login/login.wxml delete mode 100644 miniprogram-example/pages/login/login.wxss delete mode 100644 miniprogram-example/pages/otp-add/index.js delete mode 100644 miniprogram-example/pages/otp-add/index.json delete mode 100644 miniprogram-example/pages/otp-add/index.wxml delete mode 100644 miniprogram-example/pages/otp-add/index.wxss delete mode 100644 miniprogram-example/pages/otp-list/index.js delete mode 100644 miniprogram-example/pages/otp-list/index.json delete mode 100644 miniprogram-example/pages/otp-list/index.wxml delete mode 100644 miniprogram-example/pages/otp-list/index.wxss delete mode 100644 miniprogram-example/project.config.json delete mode 100644 miniprogram-example/project.private.config.json delete mode 100644 miniprogram-example/services/auth.js delete mode 100644 miniprogram-example/services/otp.js delete mode 100644 miniprogram-example/utils/request.js delete mode 100644 models/otp.go delete mode 100644 models/user.go create mode 100644 otp_api.go create mode 100644 otp_api_test.go delete mode 100644 security/security.go delete mode 100644 server/server.go delete mode 100644 services/auth.go delete mode 100644 services/otp.go delete mode 100644 utils/utils.go delete mode 100644 validator/validator.go diff --git a/api/response.go b/api/response.go deleted file mode 100644 index 8e24daf..0000000 --- a/api/response.go +++ /dev/null @@ -1,149 +0,0 @@ -package api - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" -) - -// Common error codes -const ( - CodeSuccess = 0 - CodeInvalidParams = 400 - CodeUnauthorized = 401 - CodeForbidden = 403 - CodeNotFound = 404 - CodeInternalError = 500 - CodeServiceUnavail = 503 -) - -// Error represents an API error -type Error struct { - Code int `json:"code"` - Message string `json:"message"` -} - -// Error implements the error interface -func (e *Error) Error() string { - return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message) -} - -// NewError creates a new API error -func NewError(code int, message string) *Error { - return &Error{ - Code: code, - Message: message, - } -} - -// Response represents a standard API response -type Response struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` -} - -// ResponseWriter wraps common response writing functions -type ResponseWriter struct { - http.ResponseWriter -} - -// NewResponseWriter creates a new ResponseWriter -func NewResponseWriter(w http.ResponseWriter) *ResponseWriter { - return &ResponseWriter{w} -} - -// WriteJSON writes a JSON response -func (w *ResponseWriter) WriteJSON(code int, data interface{}) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - return json.NewEncoder(w).Encode(data) -} - -// WriteSuccess writes a success response -func (w *ResponseWriter) WriteSuccess(data interface{}) error { - return w.WriteJSON(http.StatusOK, Response{ - Code: CodeSuccess, - Message: "success", - Data: data, - }) -} - -// WriteError writes an error response -func (w *ResponseWriter) WriteError(err error) error { - var apiErr *Error - if errors.As(err, &apiErr) { - return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{ - Code: apiErr.Code, - Message: apiErr.Message, - }) - } - - // Handle unknown errors - return w.WriteJSON(http.StatusInternalServerError, Response{ - Code: CodeInternalError, - Message: "Internal Server Error", - }) -} - -// WriteErrorWithCode writes an error response with a specific code -func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error { - return w.WriteJSON(getHTTPStatus(code), Response{ - Code: code, - Message: message, - }) -} - -// getHTTPStatus maps API error codes to HTTP status codes -func getHTTPStatus(code int) int { - switch code { - case CodeSuccess: - return http.StatusOK - case CodeInvalidParams: - return http.StatusBadRequest - case CodeUnauthorized: - return http.StatusUnauthorized - case CodeForbidden: - return http.StatusForbidden - case CodeNotFound: - return http.StatusNotFound - case CodeServiceUnavail: - return http.StatusServiceUnavailable - default: - return http.StatusInternalServerError - } -} - -// Common errors -var ( - ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters") - ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized") - ErrForbidden = NewError(CodeForbidden, "Forbidden") - ErrNotFound = NewError(CodeNotFound, "Resource not found") - ErrInternalError = NewError(CodeInternalError, "Internal server error") - ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable") -) - -// ValidationError creates an error for invalid parameters -func ValidationError(message string) *Error { - return NewError(CodeInvalidParams, message) -} - -// NotFoundError creates an error for not found resources -func NotFoundError(resource string) *Error { - return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource)) -} - -// ForbiddenError creates an error for forbidden actions -func ForbiddenError(message string) *Error { - return NewError(CodeForbidden, message) -} - -// InternalError creates an error for internal server errors -func InternalError(err error) *Error { - if err == nil { - return ErrInternalError - } - return NewError(CodeInternalError, err.Error()) -} diff --git a/api/validator.go b/api/validator.go deleted file mode 100644 index a18a552..0000000 --- a/api/validator.go +++ /dev/null @@ -1,152 +0,0 @@ -package api - -import ( - "regexp" - "strings" - - "github.com/go-playground/validator/v10" -) - -// Validate is a global validator instance -var Validate = validator.New() - -// RegisterCustomValidations registers custom validation functions -func RegisterCustomValidations() { - // Register custom validation for issuer - Validate.RegisterValidation("issuer", validateIssuer) - - // Register custom validation for XSS prevention - Validate.RegisterValidation("no_xss", validateNoXSS) - - // Register custom validation for OTP secret - Validate.RegisterValidation("otpsecret", validateOTPSecret) -} - -// validateOTPSecret validates that the OTP secret is in valid base32 format -func validateOTPSecret(fl validator.FieldLevel) bool { - secret := fl.Field().String() - - // Check if the secret is not empty - if secret == "" { - return false - } - - // Check if the secret is in base32 format (A-Z, 2-7) - base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`) - if !base32Regex.MatchString(secret) { - return false - } - - // Check if the length is valid (must be at least 16 characters) - if len(secret) < 16 || len(secret) > 128 { - return false - } - - return true -} - -// validateIssuer validates that the issuer field contains only allowed characters -func validateIssuer(fl validator.FieldLevel) bool { - issuer := fl.Field().String() - - // Empty issuer is valid (since it's optional) - if issuer == "" { - return true - } - - // Allow alphanumeric characters, spaces, and common punctuation - issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api - -import ( - "regexp" - "strings" - - "github.com/go-playground/validator/v10" -) - -// Validate is a global validator instance -var Validate = validator.New() - -// RegisterCustomValidations registers custom validation functions -func RegisterCustomValidations() { - // Register custom validation for issuer - Validate.RegisterValidation("issuer", validateIssuer) - - // Register custom validation for XSS prevention - Validate.RegisterValidation("no_xss", validateNoXSS) - - // Register custom validation for OTP secret - Validate.RegisterValidation("otpsecret", validateOTPSecret) -} - -// validateOTPSecret validates that the OTP secret is in valid base32 format -func validateOTPSecret(fl validator.FieldLevel) bool { - secret := fl.Field().String() - - // Check if the secret is not empty - if secret == "" { - return false - } - - // Check if the secret is in base32 format (A-Z, 2-7) - base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`) - if !base32Regex.MatchString(secret) { - return false - } - - // Check if the length is valid (must be at least 16 characters) - if len(secret) < 16 || len(secret) > 128 { - return false - } - - return true -} - -) - if !issuerRegex.MatchString(issuer) { - return false - } - - // Check length - if len(issuer) > 100 { - return false - } - - return true -} - -// validateNoXSS validates that the field doesn't contain potential XSS payloads -func validateNoXSS(fl validator.FieldLevel) bool { - value := fl.Field().String() - - // Check for HTML encoding - if strings.Contains(value, "&#") || - strings.Contains(value, "<") || - strings.Contains(value, ">") { - return false - } - - // Check for common XSS patterns - suspiciousPatterns := []*regexp.Regexp{ - regexp.MustCompile(`(?i)]*>.*?`), - regexp.MustCompile(`(?i)javascript:`), - regexp.MustCompile(`(?i)data:text/html`), - regexp.MustCompile(`(?i)on\w+\s*=`), - regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`), - regexp.MustCompile(`(?i)<\s*iframe`), - regexp.MustCompile(`(?i)<\s*object`), - regexp.MustCompile(`(?i)<\s*embed`), - regexp.MustCompile(`(?i)<\s*style`), - regexp.MustCompile(`(?i)<\s*form`), - regexp.MustCompile(`(?i)<\s*applet`), - regexp.MustCompile(`(?i)<\s*meta`), - } - - for _, pattern := range suspiciousPatterns { - if pattern.MatchString(value) { - return false - } - } - - return true -} \ No newline at end of file diff --git a/api_server.go b/api_server.go new file mode 100644 index 0000000..26a53bb --- /dev/null +++ b/api_server.go @@ -0,0 +1,202 @@ +package main + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + + "auth" + "config" + + "github.com/gorilla/mux" +) + +// Server represents the API server +type Server struct { + config *config.Config + router *mux.Router +} + +type WechatLoginResponse struct { + OpenID string `json:"openid"` + SessionKey string `json:"session_key"` + UnionID string `json:"unionid,omitempty"` + ErrCode int `json:"errcode,omitempty"` + ErrMsg string `json:"errmsg,omitempty"` +} + +// NewServer creates a new instance of Server +func NewServer(cfg *config.Config) *Server { + s := &Server{ + config: cfg, + router: mux.NewRouter(), + } + s.setupRoutes() + return s +} + +// setupRoutes configures all the routes for the server +func (s *Server) setupRoutes() { + // 公开路由 + s.router.HandleFunc("/auth/login", s.WechatLoginHandler) + + // 受保护路由(需要JWT) + authRouter := s.router.PathPrefix("").Subrouter() + authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey)) + authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST") + authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST") + + // 添加CORS中间件 + s.router.Use(s.corsMiddleware) +} + +// corsMiddleware handles CORS +func (s *Server) corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + // Check if the origin is allowed + allowed := false + for _, allowedOrigin := range s.config.CORS.AllowedOrigins { + if origin == allowedOrigin { + allowed = true + break + } + } + + if allowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", + joinStrings(s.config.CORS.AllowedMethods)) + w.Header().Set("Access-Control-Allow-Headers", + joinStrings(s.config.CORS.AllowedHeaders)) + } + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +} + +// joinStrings joins string slice with commas +func joinStrings(slice []string) string { + result := "" + for i, s := range slice { + if i > 0 { + result += ", " + } + result += s + } + return result +} + +func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 读取请求体 + body, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + var req struct { + Code string `json:"code"` + } + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + // 向微信服务器请求session_key + url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", + s.config.Wechat.AppID, s.config.Wechat.AppSecret, req.Code) + + resp, err := http.Get(url) + if err != nil { + http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable) + return + } + defer resp.Body.Close() + + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + http.Error(w, "Wechat service error", http.StatusInternalServerError) + return + } + + var wechatResp WechatLoginResponse + if err := json.Unmarshal(body, &wechatResp); err != nil { + http.Error(w, "Wechat response parse error", http.StatusInternalServerError) + return + } + + if wechatResp.ErrCode != 0 { + http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized) + return + } + + // 生成JWT token + token, err := s.generateSessionToken(wechatResp.OpenID) + if err != nil { + http.Error(w, "Failed to generate token", http.StatusInternalServerError) + return + } + + // 返回响应 + response := map[string]interface{}{ + "token": token, + "openid": wechatResp.OpenID, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } +} + +func (s *Server) generateSessionToken(openid string) (string, error) { + return auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry) +} + +// Start starts the HTTP server +func (s *Server) Start() error { + addr := fmt.Sprintf(":%d", s.config.Server.Port) + log.Printf("Starting server on %s", addr) + + srv := &http.Server{ + Handler: s.router, + Addr: addr, + WriteTimeout: s.config.Server.Timeout, + ReadTimeout: s.config.Server.Timeout, + } + + return srv.ListenAndServe() +} + +func main() { + // 加载配置 + cfg, err := config.LoadConfig("config") + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + // 初始化数据库连接 + log.Println("Initializing database connection...") + if err := InitDB(cfg.Database); err != nil { + log.Fatalf("Failed to initialize database: %v", err) + } + log.Println("Database connection established successfully") + + // 创建并启动服务器 + server := NewServer(cfg) + if err := server.Start(); err != nil { + log.Fatalf("Server failed to start: %v", err) + } +} diff --git a/auth/middleware.go b/auth/middleware.go new file mode 100644 index 0000000..a5f8429 --- /dev/null +++ b/auth/middleware.go @@ -0,0 +1,105 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +type contextKey string + +const ( + UserIDContextKey contextKey = "userID" +) + +// Claims represents the JWT claims +type Claims struct { + UserID string `json:"user_id"` + jwt.RegisteredClaims +} + +// GenerateToken generates a new JWT token +func GenerateToken(userID string, signingKey string, expiry time.Duration) (string, error) { + claims := Claims{ + UserID: userID, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(signingKey)) +} + +// NewAuthMiddleware creates a new authentication middleware +func NewAuthMiddleware(signingKey string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Extract token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, "Authorization header required", http.StatusUnauthorized) + return + } + + // Check if the header has the correct format + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + http.Error(w, "Invalid authorization header format", http.StatusUnauthorized) + return + } + + tokenString := parts[1] + + // Parse and validate the token + claims := &Claims{} + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + // Validate signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(signingKey), nil + }) + + if err != nil { + if err == jwt.ErrSignatureInvalid { + http.Error(w, "Invalid token signature", http.StatusUnauthorized) + return + } + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + if !token.Valid { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + // Add user ID to request context + ctx := context.WithValue(r.Context(), UserIDContextKey, claims.UserID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// GetUserIDFromContext extracts the user ID from the request context +func GetUserIDFromContext(ctx context.Context) (string, error) { + userID, ok := ctx.Value(UserIDContextKey).(string) + if !ok { + return "", fmt.Errorf("user ID not found in context") + } + return userID, nil +} + +// RequireAuth is a middleware that ensures a valid JWT token is present +func RequireAuth(signingKey string, next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authMiddleware := NewAuthMiddleware(signingKey) + authMiddleware(http.HandlerFunc(next)).ServeHTTP(w, r) + } +} diff --git a/cache/cache.go b/cache/cache.go deleted file mode 100644 index 08d3041..0000000 --- a/cache/cache.go +++ /dev/null @@ -1,206 +0,0 @@ -package cache - -import ( - "context" - "encoding/json" - "fmt" - "sync" - "time" -) - -// Item represents a cache item -type Item struct { - Value []byte - Expiration int64 -} - -// Expired returns true if the item has expired -func (item Item) Expired() bool { - if item.Expiration == 0 { - return false - } - return time.Now().UnixNano() > item.Expiration -} - -// Cache represents an in-memory cache -type Cache struct { - items map[string]Item - mu sync.RWMutex - defaultExpiration time.Duration - cleanupInterval time.Duration - stopCleanup chan bool -} - -// New creates a new cache with the given default expiration and cleanup interval -func New(defaultExpiration, cleanupInterval time.Duration) *Cache { - cache := &Cache{ - items: make(map[string]Item), - defaultExpiration: defaultExpiration, - cleanupInterval: cleanupInterval, - stopCleanup: make(chan bool), - } - - // Start cleanup goroutine if cleanup interval > 0 - if cleanupInterval > 0 { - go cache.startCleanup() - } - - return cache -} - -// startCleanup starts the cleanup process -func (c *Cache) startCleanup() { - ticker := time.NewTicker(c.cleanupInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - c.DeleteExpired() - case <-c.stopCleanup: - return - } - } -} - -// Set adds an item to the cache with the given key and expiration -func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error { - // Convert value to bytes - var valueBytes []byte - var err error - - switch v := value.(type) { - case []byte: - valueBytes = v - case string: - valueBytes = []byte(v) - default: - valueBytes, err = json.Marshal(value) - if err != nil { - return fmt.Errorf("failed to marshal value: %w", err) - } - } - - // Calculate expiration - var exp int64 - if expiration == 0 { - if c.defaultExpiration > 0 { - exp = time.Now().Add(c.defaultExpiration).UnixNano() - } - } else if expiration > 0 { - exp = time.Now().Add(expiration).UnixNano() - } - - c.mu.Lock() - c.items[key] = Item{ - Value: valueBytes, - Expiration: exp, - } - c.mu.Unlock() - - return nil -} - -// Get gets an item from the cache -func (c *Cache) Get(key string, value interface{}) (bool, error) { - c.mu.RLock() - item, found := c.items[key] - c.mu.RUnlock() - - if !found { - return false, nil - } - - // Check if item has expired - if item.Expired() { - c.mu.Lock() - delete(c.items, key) - c.mu.Unlock() - return false, nil - } - - // Unmarshal value - switch v := value.(type) { - case *[]byte: - *v = item.Value - case *string: - *v = string(item.Value) - default: - if err := json.Unmarshal(item.Value, value); err != nil { - return true, fmt.Errorf("failed to unmarshal value: %w", err) - } - } - - return true, nil -} - -// Delete deletes an item from the cache -func (c *Cache) Delete(key string) { - c.mu.Lock() - delete(c.items, key) - c.mu.Unlock() -} - -// DeleteExpired deletes all expired items from the cache -func (c *Cache) DeleteExpired() { - now := time.Now().UnixNano() - - c.mu.Lock() - for k, v := range c.items { - if v.Expiration > 0 && now > v.Expiration { - delete(c.items, k) - } - } - c.mu.Unlock() -} - -// Clear deletes all items from the cache -func (c *Cache) Clear() { - c.mu.Lock() - c.items = make(map[string]Item) - c.mu.Unlock() -} - -// Close stops the cleanup goroutine -func (c *Cache) Close() { - if c.cleanupInterval > 0 { - c.stopCleanup <- true - } -} - -// CacheService provides caching functionality -type CacheService struct { - cache *Cache -} - -// NewCacheService creates a new CacheService -func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService { - return &CacheService{ - cache: New(defaultExpiration, cleanupInterval), - } -} - -// Set adds an item to the cache -func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error { - return s.cache.Set(key, value, expiration) -} - -// Get gets an item from the cache -func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) { - return s.cache.Get(key, value) -} - -// Delete deletes an item from the cache -func (s *CacheService) Delete(ctx context.Context, key string) { - s.cache.Delete(key) -} - -// Clear deletes all items from the cache -func (s *CacheService) Clear(ctx context.Context) { - s.cache.Clear() -} - -// Close closes the cache -func (s *CacheService) Close() { - s.cache.Close() -} diff --git a/cmd/root.go b/cmd/root.go deleted file mode 100644 index 8821b91..0000000 --- a/cmd/root.go +++ /dev/null @@ -1,144 +0,0 @@ -package cmd - -import ( - "context" - "fmt" - "log" - "net/http" - "os" - "os/signal" - "syscall" - - "github.com/spf13/viper" - - "otpm/api" - "otpm/config" - "otpm/database" - "otpm/handlers" - "otpm/models" - "otpm/server" - "otpm/services" -) - -func init() { - // Set config file with multi-environment support - viper.SetConfigName("config") - viper.SetConfigType("yaml") - viper.AddConfigPath(".") - - // Set environment specific config (e.g. config.production.yaml) - env := os.Getenv("OTPM_ENV") - if env != "" { - viper.SetConfigName(fmt.Sprintf("config.%s", env)) - } - - // Set default values - viper.SetDefault("server.port", "8080") - viper.SetDefault("server.timeout.read", "15s") - viper.SetDefault("server.timeout.write", "15s") - viper.SetDefault("server.timeout.idle", "60s") - viper.SetDefault("database.max_open_conns", 25) - viper.SetDefault("database.max_idle_conns", 5) - viper.SetDefault("database.conn_max_lifetime", "5m") - - // Set environment variable prefix - viper.SetEnvPrefix("OTPM") - viper.AutomaticEnv() - - // Bind environment variables - viper.BindEnv("database.url", "OTPM_DB_URL") - viper.BindEnv("database.password", "OTPM_DB_PASSWORD") -} - -// Execute is the entry point for the application -func Execute() error { - // Load configuration - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - - // Create context with cancellation - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Setup signal handling - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - go func() { - sig := <-sigChan - log.Printf("Received signal: %v", sig) - cancel() - }() - - // Initialize database - db, err := database.New(&cfg.Database) - if err != nil { - return fmt.Errorf("failed to initialize database: %w", err) - } - defer db.Close() - - // Run database migrations - if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil { - return fmt.Errorf("failed to run migrations: %w", err) - } - - // Initialize repositories - userRepo := models.NewUserRepository(db.DB) - otpRepo := models.NewOTPRepository(db.DB) - - // Initialize services - authService := services.NewAuthService(cfg, userRepo) - otpService := services.NewOTPService(otpRepo) - - // Register custom validations - api.RegisterCustomValidations() - - // Initialize handlers - authHandler := handlers.NewAuthHandler(authService) - otpHandler := handlers.NewOTPHandler(otpService) - - // Create and configure server - srv := server.New(cfg) - - // Register health check endpoint - srv.RegisterHealthCheck() - - // Register public routes with type conversion - authRoutes := make(map[string]http.Handler) - for path, handler := range authHandler.Routes() { - authRoutes[path] = http.HandlerFunc(handler) - } - srv.RegisterRoutes(authRoutes) - - // Register authenticated routes with type conversion - otpRoutes := make(map[string]http.Handler) - for path, handler := range otpHandler.Routes() { - otpRoutes[path] = http.HandlerFunc(handler) - } - srv.RegisterAuthRoutes(otpRoutes) - - // Start server in goroutine - serverErr := make(chan error, 1) - go func() { - log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port) - if err := srv.Start(); err != nil { - serverErr <- fmt.Errorf("server error: %w", err) - } - }() - - // Wait for shutdown signal or server error - select { - case err := <-serverErr: - return err - case <-ctx.Done(): - // Graceful shutdown with timeout - log.Println("Shutting down server...") - if err := srv.Shutdown(); err != nil { - return fmt.Errorf("server shutdown error: %w", err) - } - log.Println("Server stopped gracefully") - } - - return nil -} diff --git a/config.yaml b/config.yaml index da2013f..ead2412 100644 --- a/config.yaml +++ b/config.yaml @@ -1,23 +1,44 @@ +# Server Configuration server: port: 8080 - read_timeout: 15s - write_timeout: 15s - shutdown_timeout: 5s + timeout: 30s +# Database Configuration database: - driver: sqlite3 - dsn: otpm.sqlite - max_open_conns: 25 - max_idle_conns: 25 - max_lifetime: 5m - skip_migration: false + driver: "sqlite3" # or "postgres" + sqlite: + path: "./data.db" + postgres: + host: "localhost" + port: 5432 + user: "postgres" + password: "password" + dbname: "otpdb" + sslmode: "disable" -jwt: - secret: "your-jwt-secret-key-change-this-in-production" - expire_delta: 24h - refresh_delta: 168h - signing_method: HS256 +# Security Configuration +security: + encryption_key: "your-32-byte-encryption-key-here" + jwt_signing_key: "your-jwt-signing-key-here" + token_expiry: 24h + refresh_token_expiry: 168h # 7 days +# WeChat Configuration wechat: - app_id: "your-wechat-app-id" - app_secret: "your-wechat-app-secret" \ No newline at end of file + app_id: "YOUR_APPID" + app_secret: "YOUR_APPSECRET" + +# CORS Configuration +cors: + allowed_origins: + - "http://localhost:8080" + - "https://yourdomain.com" + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + - "OPTIONS" + allowed_headers: + - "Authorization" + - "Content-Type" \ No newline at end of file diff --git a/config/config.go b/config/config.go index 4cad24d..1bc9d36 100644 --- a/config/config.go +++ b/config/config.go @@ -7,122 +7,156 @@ import ( "github.com/spf13/viper" ) -// Config holds all configuration for the application +// Config holds all configuration for our application type Config struct { - Server ServerConfig `mapstructure:"server"` - Database DatabaseConfig `mapstructure:"database"` - JWT JWTConfig `mapstructure:"jwt"` - WeChat WeChatConfig `mapstructure:"wechat"` + Server ServerConfig + Database DatabaseConfig + Security SecurityConfig + CORS CORSConfig + Wechat WechatConfig } // ServerConfig holds all server related configuration type ServerConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - ReadTimeout time.Duration `mapstructure:"read_timeout"` - WriteTimeout time.Duration `mapstructure:"write_timeout"` - ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"` - Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout + Port int + Timeout time.Duration } // DatabaseConfig holds all database related configuration type DatabaseConfig struct { - Driver string `mapstructure:"driver"` - DSN string `mapstructure:"dsn"` - MaxOpenConns int `mapstructure:"max_open_conns"` - MaxIdleConns int `mapstructure:"max_idle_conns"` - MaxLifetime time.Duration `mapstructure:"max_lifetime"` - SkipMigration bool `mapstructure:"skip_migration"` + Driver string + SQLite SQLiteConfig + Postgres PostgresConfig } -// JWTConfig holds all JWT related configuration -type JWTConfig struct { - Secret string `mapstructure:"secret"` - ExpireDelta time.Duration `mapstructure:"expire_delta"` - RefreshDelta time.Duration `mapstructure:"refresh_delta"` - SigningMethod string `mapstructure:"signing_method"` - Issuer string `mapstructure:"issuer"` - Audience string `mapstructure:"audience"` +// SQLiteConfig holds SQLite specific configuration +type SQLiteConfig struct { + Path string } -// WeChatConfig holds all WeChat related configuration -type WeChatConfig struct { +// PostgresConfig holds PostgreSQL specific configuration +type PostgresConfig struct { + Host string + Port int + User string + Password string + DBName string + SSLMode string +} + +// SecurityConfig holds all security related configuration +type SecurityConfig struct { + EncryptionKey string + JWTSigningKey string + TokenExpiry time.Duration + RefreshTokenExpiry time.Duration +} + +// CORSConfig holds CORS related configuration +type CORSConfig struct { + AllowedOrigins []string + AllowedMethods []string + AllowedHeaders []string +} + +// WechatConfig holds WeChat related configuration +type WechatConfig struct { AppID string `mapstructure:"app_id"` AppSecret string `mapstructure:"app_secret"` } -// LoadConfig loads the configuration from file and environment variables -func LoadConfig() (*Config, error) { - // Set default values - setDefaults() +// LoadConfig reads configuration from file or environment variables +func LoadConfig(configPath string) (*Config, error) { + v := viper.New() - // Read config file - if err := viper.ReadInConfig(); err != nil { + v.SetConfigName("config") + v.SetConfigType("yaml") + v.AddConfigPath(configPath) + v.AddConfigPath(".") + + // Read environment variables + v.AutomaticEnv() + + // Allow environment variables to override config file + v.SetEnvPrefix("OTP") + v.BindEnv("security.encryption_key", "OTP_ENCRYPTION_KEY") + v.BindEnv("security.jwt_signing_key", "OTP_JWT_SIGNING_KEY") + v.BindEnv("database.postgres.password", "OTP_DB_PASSWORD") + v.BindEnv("wechat.app_id", "OTP_WECHAT_APPID") + v.BindEnv("wechat.app_secret", "OTP_WECHAT_SECRET") + + if err := v.ReadInConfig(); err != nil { return nil, fmt.Errorf("failed to read config file: %w", err) } var config Config - if err := viper.Unmarshal(&config); err != nil { + if err := v.Unmarshal(&config); err != nil { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } - // Validate config + // Validate required configurations if err := validateConfig(&config); err != nil { - return nil, fmt.Errorf("invalid configuration: %w", err) + return nil, fmt.Errorf("config validation failed: %w", err) } return &config, nil } -// setDefaults sets default values for configuration -func setDefaults() { - // Server defaults - viper.SetDefault("server.port", 8080) - viper.SetDefault("server.read_timeout", "15s") - viper.SetDefault("server.write_timeout", "15s") - viper.SetDefault("server.shutdown_timeout", "5s") - viper.SetDefault("server.timeout", "30s") // Default request processing timeout - - // Database defaults - viper.SetDefault("database.driver", "sqlite3") - viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection - viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection - viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling - viper.SetDefault("database.skip_migration", false) - - // JWT defaults - viper.SetDefault("jwt.expire_delta", "24h") - viper.SetDefault("jwt.refresh_delta", "168h") // 7 days - viper.SetDefault("jwt.signing_method", "HS256") - viper.SetDefault("jwt.issuer", "otpm") - viper.SetDefault("jwt.audience", "otpm-client") -} - -// validateConfig validates the configuration +// validateConfig ensures all required configuration values are provided func validateConfig(config *Config) error { - if config.Server.Port < 1 || config.Server.Port > 65535 { - return fmt.Errorf("invalid port number: %d", config.Server.Port) + if config.Security.EncryptionKey == "" { + return fmt.Errorf("encryption key is required") + } + + if config.Security.JWTSigningKey == "" { + return fmt.Errorf("JWT signing key is required") } if config.Database.Driver == "" { return fmt.Errorf("database driver is required") } - if config.Database.DSN == "" { - return fmt.Errorf("database DSN is required") + switch config.Database.Driver { + case "sqlite3": + if config.Database.SQLite.Path == "" { + return fmt.Errorf("SQLite database path is required") + } + case "postgres": + if config.Database.Postgres.Host == "" || + config.Database.Postgres.User == "" || + config.Database.Postgres.Password == "" || + config.Database.Postgres.DBName == "" { + return fmt.Errorf("incomplete PostgreSQL configuration") + } + default: + return fmt.Errorf("unsupported database driver: %s", config.Database.Driver) } - if config.JWT.Secret == "" { - return fmt.Errorf("JWT secret is required") - } - - if config.WeChat.AppID == "" { + // Validate WeChat configuration + if config.Wechat.AppID == "" { return fmt.Errorf("WeChat AppID is required") } - - if config.WeChat.AppSecret == "" { + if config.Wechat.AppSecret == "" { return fmt.Errorf("WeChat AppSecret is required") } return nil } + +// GetDSN returns the appropriate database connection string based on the driver +func (c *DatabaseConfig) GetDSN() string { + switch c.Driver { + case "sqlite3": + return c.SQLite.Path + case "postgres": + return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + c.Postgres.Host, + c.Postgres.Port, + c.Postgres.User, + c.Postgres.Password, + c.Postgres.DBName, + c.Postgres.SSLMode) + default: + return "" + } +} diff --git a/database/db.go b/database/db.go deleted file mode 100644 index e882713..0000000 --- a/database/db.go +++ /dev/null @@ -1,212 +0,0 @@ -package database - -import ( - "context" - "database/sql" - "fmt" - "log" - "strings" - "time" - - "otpm/config" - - "github.com/jmoiron/sqlx" -) - -// DB wraps sqlx.DB to provide additional functionality -type DB struct { - *sqlx.DB -} - -// New creates a new database connection -func New(cfg *config.DatabaseConfig) (*DB, error) { - db, err := sqlx.Open(cfg.Driver, cfg.DSN) - if err != nil { - return nil, fmt.Errorf("failed to connect to database: %w", err) - } - - // Configure connection pool based on database type - if cfg.Driver == "sqlite3" { - // SQLite is a file-based database - simpler connection settings - db.SetMaxOpenConns(1) - db.SetMaxIdleConns(1) - db.SetConnMaxLifetime(0) // Connections don't need to be recycled - db.SetConnMaxIdleTime(0) - } else { - // For other databases (MySQL, PostgreSQL etc.) - db.SetMaxOpenConns(cfg.MaxOpenConns) - db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections - db.SetConnMaxLifetime(30 * time.Minute) - db.SetConnMaxIdleTime(5 * time.Minute) - } - - // Verify connection with timeout - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - if err := db.PingContext(ctx); err != nil { - return nil, fmt.Errorf("failed to ping database: %w", err) - } - - log.Println("Successfully connected to database") - return &DB{db}, nil -} - -// WithTx executes a function within a transaction with retry logic -func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error { - var maxRetries int - var lastErr error - - // Adjust retry settings based on database type - if db.DriverName() == "sqlite3" { - maxRetries = 5 // SQLite needs more retries due to busy timeouts - } else { - maxRetries = 3 - } - - // Default transaction options - opts := &sql.TxOptions{ - Isolation: sql.LevelReadCommitted, - } - - for attempt := 1; attempt <= maxRetries; attempt++ { - start := time.Now() - - tx, err := db.BeginTxx(ctx, opts) - if err != nil { - if isRetryableError(err) && attempt < maxRetries { - lastErr = err - time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff - continue - } - return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err) - } - - defer func() { - if p := recover(); p != nil { - tx.Rollback() - panic(p) - } - }() - - if err := fn(tx); err != nil { - if rbErr := tx.Rollback(); rbErr != nil { - log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err) - } - - if isRetryableError(err) && attempt < maxRetries { - lastErr = err - time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) - continue - } - return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err) - } - - // Log long-running transactions - if elapsed := time.Since(start); elapsed > 500*time.Millisecond { - log.Printf("Transaction completed in %v", elapsed) - } - - if err := tx.Commit(); err != nil { - if isRetryableError(err) && attempt < maxRetries { - lastErr = err - time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) - continue - } - return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err) - } - - return nil - } - - return lastErr -} - -// isRetryableError checks if an error is likely to succeed on retry -func isRetryableError(err error) bool { - if err == nil { - return false - } - - errStr := strings.ToLower(err.Error()) - return strings.Contains(errStr, "deadlock") || - strings.Contains(errStr, "timeout") || - strings.Contains(errStr, "try again") || - strings.Contains(errStr, "connection reset") || - strings.Contains(errStr, "busy") || - strings.Contains(errStr, "locked") -} - -// ExecContext executes a query with adaptive timeout -func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error { - // Set timeout based on query complexity - timeout := 5 * time.Second - if strings.Contains(strings.ToLower(query), "insert") || - strings.Contains(strings.ToLower(query), "update") || - strings.Contains(strings.ToLower(query), "delete") { - timeout = 10 * time.Second - } - - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - start := time.Now() - _, err := db.DB.ExecContext(ctx, query, args...) - elapsed := time.Since(start) - - // Log slow queries - if elapsed > timeout/2 { - log.Printf("Slow query execution detected: %s (took %v)", query, elapsed) - } - - if err != nil { - return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err) - } - - return nil -} - -// QueryRowContext executes a query that returns a single row with timeout -func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - return db.DB.QueryRowxContext(ctx, query, args...) -} - -// QueryContext executes a query that returns multiple rows with adaptive timeout -func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { - // Set timeout based on query complexity - timeout := 5 * time.Second - if strings.Contains(strings.ToLower(query), "join") || - strings.Contains(strings.ToLower(query), "group by") || - strings.Contains(strings.ToLower(query), "order by") { - timeout = 15 * time.Second - } - - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - start := time.Now() - rows, err := db.DB.QueryxContext(ctx, query, args...) - elapsed := time.Since(start) - - // Log slow queries - if elapsed > timeout/2 { - log.Printf("Slow query detected: %s (took %v)", query, elapsed) - } - - if err != nil { - return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err) - } - - return rows, nil -} - -// Close closes the database connection -func (db *DB) Close() error { - if err := db.DB.Close(); err != nil { - return fmt.Errorf("failed to close database connection: %w", err) - } - return nil -} diff --git a/database/init/otp.sql b/database/init/otp.sql deleted file mode 100644 index e0e1ba9..0000000 --- a/database/init/otp.sql +++ /dev/null @@ -1,26 +0,0 @@ -CREATE TABLE IF NOT EXISTS otp ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id VARCHAR(255) NOT NULL, - openid VARCHAR(255) NOT NULL, - name VARCHAR(100) NOT NULL, - issuer VARCHAR(255), - secret VARCHAR(255) NOT NULL, - algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1', - digits INTEGER NOT NULL DEFAULT 6, - period INTEGER NOT NULL DEFAULT 30, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - UNIQUE(user_id, name), - UNIQUE(openid) -); - --- Add index for faster lookups -CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id); -CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid); - --- Trigger to update the updated_at timestamp -CREATE TRIGGER IF NOT EXISTS update_otp_timestamp - AFTER UPDATE ON otp -BEGIN - UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; -END; \ No newline at end of file diff --git a/database/init/users.sql b/database/init/users.sql deleted file mode 100644 index ae4532a..0000000 --- a/database/init/users.sql +++ /dev/null @@ -1,6 +0,0 @@ -CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - openid VARCHAR(255) UNIQUE NOT NULL, - session_key VARCHAR(255) UNIQUE NOT NULL -); -CREATE UNIQUE INDEX idx_users_openid ON users(openid); \ No newline at end of file diff --git a/database/migration.go b/database/migration.go deleted file mode 100644 index df15a97..0000000 --- a/database/migration.go +++ /dev/null @@ -1,160 +0,0 @@ -package database - -import ( - "context" - "fmt" - "log" - "time" - - _ "embed" - - "github.com/jmoiron/sqlx" -) - -var ( - //go:embed init/users.sql - userTable string - //go:embed init/otp.sql - otpTable string -) - -// Migration represents a database migration -type Migration struct { - Name string - SQL string - Version int -} - -// Migrations is a list of all migrations -var Migrations = []Migration{ - { - Name: "Create users table", - SQL: userTable, - Version: 1, - }, - { - Name: "Create OTP table", - SQL: otpTable, - Version: 2, - }, -} - -// MigrationRecord represents a record in the migrations table -type MigrationRecord struct { - ID int `db:"id"` - Version int `db:"version"` - Name string `db:"name"` - AppliedAt time.Time `db:"applied_at"` -} - -// ensureMigrationsTable ensures that the migrations table exists -func ensureMigrationsTable(db *sqlx.DB) error { - query := ` - CREATE TABLE IF NOT EXISTS migrations ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - version INTEGER NOT NULL, - name TEXT NOT NULL, - applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - ` - - _, err := db.Exec(query) - if err != nil { - return fmt.Errorf("failed to create migrations table: %w", err) - } - - return nil -} - -// getAppliedMigrations gets all applied migrations -func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) { - var records []MigrationRecord - if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil { - return nil, fmt.Errorf("failed to get applied migrations: %w", err) - } - - result := make(map[int]MigrationRecord) - for _, record := range records { - result[record.Version] = record - } - - return result, nil -} - -// Migrate runs all pending migrations -func Migrate(db *sqlx.DB, skipMigration bool) error { - if skipMigration { - log.Println("Skipping database migration as configured") - return nil - } - - // Ensure migrations table exists - if err := ensureMigrationsTable(db); err != nil { - return err - } - - // Get applied migrations - appliedMigrations, err := getAppliedMigrations(db) - if err != nil { - return err - } - - // Apply pending migrations - for _, migration := range Migrations { - if _, ok := appliedMigrations[migration.Version]; ok { - log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name) - continue - } - - log.Printf("Applying migration %d: %s", migration.Version, migration.Name) - - // Start a transaction for this migration - tx, err := db.Beginx() - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - - // Execute migration - if _, err := tx.Exec(migration.SQL); err != nil { - tx.Rollback() - return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err) - } - - // Record migration - if _, err := tx.Exec( - "INSERT INTO migrations (version, name) VALUES (?, ?)", - migration.Version, migration.Name, - ); err != nil { - tx.Rollback() - return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err) - } - - // Commit transaction - if err := tx.Commit(); err != nil { - return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err) - } - - log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name) - } - - return nil -} - -// MigrateWithContext runs all pending migrations with context -func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error { - // Create a channel to signal completion - done := make(chan error, 1) - - // Run migration in a goroutine - go func() { - done <- Migrate(db, skipMigration) - }() - - // Wait for migration to complete or context to be canceled - select { - case err := <-done: - return err - case <-ctx.Done(): - return fmt.Errorf("migration canceled: %w", ctx.Err()) - } -} diff --git a/docs/swagger.go b/docs/swagger.go deleted file mode 100644 index ccf05ac..0000000 --- a/docs/swagger.go +++ /dev/null @@ -1,663 +0,0 @@ -// Package docs provides API documentation using Swagger/OpenAPI -package docs - -import ( - "encoding/json" - "net/http" -) - -// SwaggerInfo holds the API information -var SwaggerInfo = struct { - Title string - Description string - Version string - Host string - BasePath string - Schemes []string -}{ - Title: "OTPM API", - Description: "API for One-Time Password Manager", - Version: "1.0.0", - Host: "localhost:8080", - BasePath: "/", - Schemes: []string{"http", "https"}, -} - -// SwaggerJSON returns the Swagger JSON -func SwaggerJSON() []byte { - swagger := map[string]interface{}{ - "swagger": "2.0", - "info": map[string]interface{}{ - "title": SwaggerInfo.Title, - "description": SwaggerInfo.Description, - "version": SwaggerInfo.Version, - }, - "host": SwaggerInfo.Host, - "basePath": SwaggerInfo.BasePath, - "schemes": SwaggerInfo.Schemes, - "paths": getPaths(), - "definitions": map[string]interface{}{ - "LoginRequest": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "code": map[string]interface{}{ - "type": "string", - "description": "WeChat authorization code", - }, - }, - "required": []string{"code"}, - }, - "LoginResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "token": map[string]interface{}{ - "type": "string", - "description": "JWT token", - }, - "openid": map[string]interface{}{ - "type": "string", - "description": "WeChat OpenID", - }, - }, - }, - "CreateOTPRequest": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "OTP name", - }, - "issuer": map[string]interface{}{ - "type": "string", - "description": "OTP issuer", - }, - "secret": map[string]interface{}{ - "type": "string", - "description": "OTP secret", - }, - "algorithm": map[string]interface{}{ - "type": "string", - "description": "OTP algorithm", - "enum": []string{"SHA1", "SHA256", "SHA512"}, - }, - "digits": map[string]interface{}{ - "type": "integer", - "description": "OTP digits", - "enum": []int{6, 8}, - }, - "period": map[string]interface{}{ - "type": "integer", - "description": "OTP period in seconds", - "enum": []int{30, 60}, - }, - }, - "required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"}, - }, - "OTP": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "OTP ID", - }, - "user_id": map[string]interface{}{ - "type": "string", - "description": "User ID", - }, - "name": map[string]interface{}{ - "type": "string", - "description": "OTP name", - }, - "issuer": map[string]interface{}{ - "type": "string", - "description": "OTP issuer", - }, - "algorithm": map[string]interface{}{ - "type": "string", - "description": "OTP algorithm", - }, - "digits": map[string]interface{}{ - "type": "integer", - "description": "OTP digits", - }, - "period": map[string]interface{}{ - "type": "integer", - "description": "OTP period in seconds", - }, - "created_at": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "Creation time", - }, - "updated_at": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "Last update time", - }, - }, - }, - "OTPCodeResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "code": map[string]interface{}{ - "type": "string", - "description": "OTP code", - }, - "expires_in": map[string]interface{}{ - "type": "integer", - "description": "Seconds until expiration", - }, - }, - }, - "VerifyOTPRequest": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "code": map[string]interface{}{ - "type": "string", - "description": "OTP code to verify", - }, - }, - "required": []string{"code"}, - }, - "VerifyOTPResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "valid": map[string]interface{}{ - "type": "boolean", - "description": "Whether the code is valid", - }, - }, - }, - "UpdateOTPRequest": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "OTP name", - }, - "issuer": map[string]interface{}{ - "type": "string", - "description": "OTP issuer", - }, - "algorithm": map[string]interface{}{ - "type": "string", - "description": "OTP algorithm", - "enum": []string{"SHA1", "SHA256", "SHA512"}, - }, - "digits": map[string]interface{}{ - "type": "integer", - "description": "OTP digits", - "enum": []int{6, 8}, - }, - "period": map[string]interface{}{ - "type": "integer", - "description": "OTP period in seconds", - "enum": []int{30, 60}, - }, - }, - }, - "ErrorResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "code": map[string]interface{}{ - "type": "integer", - "description": "Error code", - }, - "message": map[string]interface{}{ - "type": "string", - "description": "Error message", - }, - }, - }, - }, - "securityDefinitions": map[string]interface{}{ - "Bearer": map[string]interface{}{ - "type": "apiKey", - "name": "Authorization", - "in": "header", - "description": "JWT token with Bearer prefix", - }, - }, - } - - data, _ := json.MarshalIndent(swagger, "", " ") - return data -} - -// getPaths returns the API paths -func getPaths() map[string]interface{} { - return map[string]interface{}{ - "/login": map[string]interface{}{ - "post": map[string]interface{}{ - "summary": "Login with WeChat", - "description": "Login with WeChat authorization code", - "tags": []string{"auth"}, - "parameters": []map[string]interface{}{ - { - "name": "body", - "in": "body", - "description": "Login request", - "required": true, - "schema": map[string]interface{}{ - "$ref": "#/definitions/LoginRequest", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "Successful login", - "schema": map[string]interface{}{ - "$ref": "#/definitions/LoginResponse", - }, - }, - "400": map[string]interface{}{ - "description": "Invalid request", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "500": map[string]interface{}{ - "description": "Internal server error", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - }, - }, - }, - "/verify-token": map[string]interface{}{ - "post": map[string]interface{}{ - "summary": "Verify token", - "description": "Verify JWT token", - "tags": []string{"auth"}, - "security": []map[string]interface{}{ - { - "Bearer": []string{}, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "Token is valid", - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "valid": map[string]interface{}{ - "type": "boolean", - "description": "Whether the token is valid", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "Unauthorized", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - }, - }, - }, - "/otp": map[string]interface{}{ - "get": map[string]interface{}{ - "summary": "List OTPs", - "description": "List all OTPs for the authenticated user", - "tags": []string{"otp"}, - "security": []map[string]interface{}{ - { - "Bearer": []string{}, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "List of OTPs", - "schema": map[string]interface{}{ - "type": "array", - "items": map[string]interface{}{ - "$ref": "#/definitions/OTP", - }, - }, - }, - "401": map[string]interface{}{ - "description": "Unauthorized", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "500": map[string]interface{}{ - "description": "Internal server error", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - }, - }, - "post": map[string]interface{}{ - "summary": "Create OTP", - "description": "Create a new OTP", - "tags": []string{"otp"}, - "security": []map[string]interface{}{ - { - "Bearer": []string{}, - }, - }, - "parameters": []map[string]interface{}{ - { - "name": "body", - "in": "body", - "description": "OTP creation request", - "required": true, - "schema": map[string]interface{}{ - "$ref": "#/definitions/CreateOTPRequest", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "OTP created", - "schema": map[string]interface{}{ - "$ref": "#/definitions/OTP", - }, - }, - "400": map[string]interface{}{ - "description": "Invalid request", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "401": map[string]interface{}{ - "description": "Unauthorized", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "500": map[string]interface{}{ - "description": "Internal server error", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - }, - }, - }, - "/otp/{id}": map[string]interface{}{ - "put": map[string]interface{}{ - "summary": "Update OTP", - "description": "Update an existing OTP", - "tags": []string{"otp"}, - "security": []map[string]interface{}{ - { - "Bearer": []string{}, - }, - }, - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "description": "OTP ID", - "required": true, - "type": "string", - }, - { - "name": "body", - "in": "body", - "description": "OTP update request", - "required": true, - "schema": map[string]interface{}{ - "$ref": "#/definitions/UpdateOTPRequest", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "OTP updated", - "schema": map[string]interface{}{ - "$ref": "#/definitions/OTP", - }, - }, - "400": map[string]interface{}{ - "description": "Invalid request", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "401": map[string]interface{}{ - "description": "Unauthorized", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "404": map[string]interface{}{ - "description": "OTP not found", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "500": map[string]interface{}{ - "description": "Internal server error", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - }, - }, - "delete": map[string]interface{}{ - "summary": "Delete OTP", - "description": "Delete an existing OTP", - "tags": []string{"otp"}, - "security": []map[string]interface{}{ - { - "Bearer": []string{}, - }, - }, - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "description": "OTP ID", - "required": true, - "type": "string", - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "OTP deleted", - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "Success message", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "Unauthorized", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "404": map[string]interface{}{ - "description": "OTP not found", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "500": map[string]interface{}{ - "description": "Internal server error", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - }, - }, - }, - "/otp/{id}/code": map[string]interface{}{ - "get": map[string]interface{}{ - "summary": "Get OTP code", - "description": "Get the current OTP code", - "tags": []string{"otp"}, - "security": []map[string]interface{}{ - { - "Bearer": []string{}, - }, - }, - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "description": "OTP ID", - "required": true, - "type": "string", - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "OTP code", - "schema": map[string]interface{}{ - "$ref": "#/definitions/OTPCodeResponse", - }, - }, - "401": map[string]interface{}{ - "description": "Unauthorized", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "404": map[string]interface{}{ - "description": "OTP not found", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "500": map[string]interface{}{ - "description": "Internal server error", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - }, - }, - }, - "/otp/{id}/verify": map[string]interface{}{ - "post": map[string]interface{}{ - "summary": "Verify OTP code", - "description": "Verify an OTP code", - "tags": []string{"otp"}, - "security": []map[string]interface{}{ - { - "Bearer": []string{}, - }, - }, - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "description": "OTP ID", - "required": true, - "type": "string", - }, - { - "name": "body", - "in": "body", - "description": "OTP verification request", - "required": true, - "schema": map[string]interface{}{ - "$ref": "#/definitions/VerifyOTPRequest", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "OTP verification result", - "schema": map[string]interface{}{ - "$ref": "#/definitions/VerifyOTPResponse", - }, - }, - "400": map[string]interface{}{ - "description": "Invalid request", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "401": map[string]interface{}{ - "description": "Unauthorized", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "404": map[string]interface{}{ - "description": "OTP not found", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - "500": map[string]interface{}{ - "description": "Internal server error", - "schema": map[string]interface{}{ - "$ref": "#/definitions/ErrorResponse", - }, - }, - }, - }, - }, - "/health": map[string]interface{}{ - "get": map[string]interface{}{ - "summary": "Health check", - "description": "Check if the API is healthy", - "tags": []string{"system"}, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "API is healthy", - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "status": map[string]interface{}{ - "type": "string", - "description": "Health status", - }, - "time": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "Current time", - }, - }, - }, - }, - }, - }, - }, - "/metrics": map[string]interface{}{ - "get": map[string]interface{}{ - "summary": "Metrics", - "description": "Get application metrics", - "tags": []string{"system"}, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "Application metrics", - }, - }, - }, - }, - "/swagger.json": map[string]interface{}{ - "get": map[string]interface{}{ - "summary": "Swagger JSON", - "description": "Get Swagger JSON", - "tags": []string{"system"}, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "Swagger JSON", - }, - }, - }, - }, - } -} - -// Handler returns an HTTP handler for Swagger JSON -func Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write(SwaggerJSON()) - } -} diff --git a/go.mod b/go.mod deleted file mode 100644 index a55608f..0000000 --- a/go.mod +++ /dev/null @@ -1,50 +0,0 @@ -module otpm - -go 1.23.0 - -toolchain go1.23.9 - -require ( - github.com/go-playground/validator/v10 v10.26.0 - github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/google/uuid v1.6.0 - github.com/jmoiron/sqlx v1.4.0 - github.com/julienschmidt/httprouter v1.3.0 - github.com/prometheus/client_golang v1.22.0 - github.com/spf13/viper v1.19.0 - golang.org/x/crypto v0.38.0 -) - -require ( - github.com/beorn7/perks v1.0.1 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.8 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/hashicorp/hcl v1.0.0 // indirect - github.com/leodido/go-urn v1.4.0 // indirect - github.com/magiconair/properties v1.8.7 // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/common v0.62.0 // indirect - github.com/prometheus/procfs v0.15.1 // indirect - github.com/sagikazarmark/locafero v0.4.0 // indirect - github.com/sagikazarmark/slog-shim v0.1.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect - github.com/subosito/gotenv v1.6.0 // indirect - go.uber.org/atomic v1.9.0 // indirect - go.uber.org/multierr v1.9.0 // indirect - golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect - golang.org/x/net v0.34.0 // indirect - golang.org/x/sys v0.33.0 // indirect - golang.org/x/text v0.25.0 // indirect - google.golang.org/protobuf v1.36.5 // indirect - gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/go.sum b/go.sum deleted file mode 100644 index c30de80..0000000 --- a/go.sum +++ /dev/null @@ -1,124 +0,0 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= -github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= -github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= -github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= -github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= -github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= -github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= -github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= -github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= -github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= -github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= -github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= -github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= -github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= -github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= -github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= -github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= -github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= -github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= -github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= -github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= -github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= -github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= -github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= -github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= -github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= -github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= -github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= -github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= -github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= -github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= -github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= -github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= -github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= -go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= -golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= -golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w= -golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= -google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= -google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= -gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handlers/auth_handler.go b/handlers/auth_handler.go deleted file mode 100644 index f79979f..0000000 --- a/handlers/auth_handler.go +++ /dev/null @@ -1,156 +0,0 @@ -package handlers - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "time" - - "otpm/api" - "otpm/services" - - "github.com/golang-jwt/jwt" - "github.com/julienschmidt/httprouter" -) - -// AuthHandler handles authentication related requests -type AuthHandler struct { - authService *services.AuthService -} - -// NewAuthHandler creates a new AuthHandler -func NewAuthHandler(authService *services.AuthService) *AuthHandler { - return &AuthHandler{ - authService: authService, - } -} - -// LoginRequest represents a login request -type LoginRequest struct { - Code string `json:"code" validate:"required,min=32,max=128"` -} - -// LoginResponse represents a login response -type LoginResponse struct { - Token string `json:"token"` - OpenID string `json:"openid"` -} - -// TokenRequest represents a token verification request -type TokenRequest struct { - Token string `validate:"required,min=32"` -} - -// Login handles WeChat login -func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - start := time.Now() - - // Limit request body size to prevent DOS - r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request - - // Parse and validate request - var req LoginRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - fmt.Sprintf("Invalid request body: %v", err)) - log.Printf("Login request parse error: %v", err) - return - } - - // Validate using validator - if err := api.Validate.Struct(req); err != nil { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - fmt.Sprintf("Invalid request parameters: %v", err)) - log.Printf("Login request validation failed: %v", err) - return - } - - // Login with WeChat code - token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code) - if err != nil { - api.NewResponseWriter(w).WriteError(api.InternalError(err)) - log.Printf("Login failed for code %s: %v", req.Code, err) - return - } - - // Log successful login - log.Printf("Login successful for code %s (took %v)", - req.Code, time.Since(start)) - - // Return token - api.NewResponseWriter(w).WriteSuccess(LoginResponse{ - Token: token, - }) -} - -// VerifyToken handles token verification -func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - start := time.Now() - - // Get token from Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - "Authorization header is required") - log.Printf("Token verification failed: missing Authorization header") - return - } - - // Validate token format - if len(authHeader) < 7 || authHeader[:7] != "Bearer " { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - "Invalid token format. Expected 'Bearer '") - log.Printf("Token verification failed: invalid token format") - return - } - - token := authHeader[7:] - - // Validate token using validator - tokenReq := TokenRequest{Token: token} - if err := api.Validate.Struct(tokenReq); err != nil { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - "Invalid token format") - log.Printf("Token verification failed: %v", err) - return - } - - // Validate token - claims, err := h.authService.ValidateToken(token) - if err != nil { - api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) - log.Printf("Token verification failed for token %s: %v", - maskToken(token), err) // Mask token in logs - return - } - - // Log successful verification - userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string) - if !ok { - log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start)) - } else { - log.Printf("Token verified for user %s (took %v)", userID, time.Since(start)) - } - - // Token is valid - api.NewResponseWriter(w).WriteSuccess(map[string]bool{ - "valid": true, - }) -} - -// maskToken masks sensitive parts of token for logging -func maskToken(token string) string { - if len(token) < 8 { - return "****" - } - return token[:4] + "****" + token[len(token)-4:] -} - -// Routes returns all routes for the auth handler -func (h *AuthHandler) Routes() map[string]httprouter.Handle { - return map[string]httprouter.Handle{ - "/api/login": h.Login, - "/api/verify-token": h.VerifyToken, - } -} diff --git a/handlers/otp_handler.go b/handlers/otp_handler.go deleted file mode 100644 index c1d99bb..0000000 --- a/handlers/otp_handler.go +++ /dev/null @@ -1,114 +0,0 @@ -package handlers - -import ( - "encoding/json" - "net/http" - - "github.com/julienschmidt/httprouter" - - "otpm/api" - "otpm/middleware" - "otpm/models" - "otpm/services" -) - -// OTPHandler handles OTP-related HTTP requests -type OTPHandler struct { - otpService *services.OTPService -} - -// NewOTPHandler creates a new OTPHandler -func NewOTPHandler(otpService *services.OTPService) *OTPHandler { - return &OTPHandler{ - otpService: otpService, - } -} - -// Routes returns the routes for OTP operations -func (h *OTPHandler) Routes() map[string]httprouter.Handle { - return map[string]httprouter.Handle{ - "POST /api/otp": h.CreateOTP, - "GET /api/otps": h.ListOTPs, - "GET /api/otp/:id": h.GetOTP, - } -} - -// CreateOTP handles the creation of a new OTP -func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - // Get user ID from context - userID, ok := r.Context().Value(middleware.UserIDKey).(string) - if !ok { - api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) - return - } - - // Parse request body - var params models.OTPParams - if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { - api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body")) - return - } - - // Validate request - if err := api.Validate.Struct(params); err != nil { - api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error())) - return - } - - // Create OTP - otp, err := h.otpService.CreateOTP(r.Context(), userID, params) - if err != nil { - api.NewResponseWriter(w).WriteError(api.InternalError(err)) - return - } - - // Return response - api.NewResponseWriter(w).WriteSuccess(otp) -} - -// ListOTPs handles listing all OTPs for a user -func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - // Get user ID from context - userID, ok := r.Context().Value(middleware.UserIDKey).(string) - if !ok { - api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) - return - } - - // Get OTPs - otps, err := h.otpService.ListOTPs(r.Context(), userID) - if err != nil { - api.NewResponseWriter(w).WriteError(api.InternalError(err)) - return - } - - // Return response - api.NewResponseWriter(w).WriteSuccess(otps) -} - -// GetOTP handles getting a specific OTP -func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - // Get user ID from context - userID, ok := r.Context().Value(middleware.UserIDKey).(string) - if !ok { - api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) - return - } - - // Get OTP ID from URL - otpID := ps.ByName("id") - if otpID == "" { - api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID")) - return - } - - // Get OTP - otp, err := h.otpService.GetOTP(r.Context(), otpID, userID) - if err != nil { - api.NewResponseWriter(w).WriteError(api.InternalError(err)) - return - } - - // Return response - api.NewResponseWriter(w).WriteSuccess(otp) -} diff --git a/init/postgresql/init.sql b/init/postgresql/init.sql new file mode 100644 index 0000000..f734604 --- /dev/null +++ b/init/postgresql/init.sql @@ -0,0 +1,51 @@ +-- 创建tokens表 +CREATE TABLE IF NOT EXISTS tokens ( + id VARCHAR(255) NOT NULL, -- token的唯一标识符 + user_id VARCHAR(255) NOT NULL, -- 用户ID + issuer VARCHAR(255) NOT NULL, -- 令牌发行者 + account VARCHAR(255) NOT NULL, -- 账户名称 + secret TEXT NOT NULL, -- 密钥 + type VARCHAR(10) NOT NULL, -- 令牌类型(totp/hotp) + counter INTEGER, -- HOTP计数器(可选) + period INTEGER NOT NULL, -- TOTP周期(秒) + digits INTEGER NOT NULL, -- 验证码位数 + algo VARCHAR(10) NOT NULL, -- 使用的哈希算法 + timestamp BIGINT NOT NULL, -- 最后更新时间戳 + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id, user_id) +); + +-- 创建更新时间戳的触发器 +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ language 'plpgsql'; + +CREATE TRIGGER update_tokens_updated_at + BEFORE UPDATE ON tokens + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + +-- 创建索引 +CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id); +CREATE INDEX IF NOT EXISTS idx_tokens_timestamp ON tokens(timestamp); + +-- 添加注释 +COMMENT ON TABLE tokens IS 'OTP令牌数据表'; +COMMENT ON COLUMN tokens.id IS '令牌的唯一标识符'; +COMMENT ON COLUMN tokens.user_id IS '用户ID'; +COMMENT ON COLUMN tokens.issuer IS '令牌发行者'; +COMMENT ON COLUMN tokens.account IS '账户名称'; +COMMENT ON COLUMN tokens.secret IS '密钥'; +COMMENT ON COLUMN tokens.type IS '令牌类型(totp/hotp)'; +COMMENT ON COLUMN tokens.counter IS 'HOTP计数器(可选)'; +COMMENT ON COLUMN tokens.period IS 'TOTP周期(秒)'; +COMMENT ON COLUMN tokens.digits IS '验证码位数'; +COMMENT ON COLUMN tokens.algo IS '使用的哈希算法'; +COMMENT ON COLUMN tokens.timestamp IS '最后更新时间戳'; +COMMENT ON COLUMN tokens.created_at IS '创建时间'; +COMMENT ON COLUMN tokens.updated_at IS '最后更新时间'; \ No newline at end of file diff --git a/init/sqlite3/init.sql b/init/sqlite3/init.sql new file mode 100644 index 0000000..6f9f3fa --- /dev/null +++ b/init/sqlite3/init.sql @@ -0,0 +1,50 @@ +-- SQLite3 initialization SQL + +-- Enable WAL mode for better concurrency (simple performance boost) +PRAGMA journal_mode = WAL; +PRAGMA synchronous = NORMAL; + +-- Enable foreign key support +PRAGMA foreign_keys = ON; + +-- 创建tokens表 +CREATE TABLE IF NOT EXISTS tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + issuer TEXT NOT NULL, + account TEXT NOT NULL, + secret TEXT NOT NULL CHECK (length(secret) >= 16 AND secret REGEXP '^[A-Z2-7]+=*$'), + type TEXT NOT NULL CHECK (type IN ('HOTP', 'TOTP')), + counter INTEGER CHECK ( + (type = 'HOTP' AND counter >= 0) OR + (type = 'TOTP' AND counter IS NULL) + ), + period INTEGER DEFAULT 30 CHECK ( + (type = 'TOTP' AND period >= 30) OR + (type = 'HOTP' AND period IS NULL) + ), + digits INTEGER NOT NULL DEFAULT 6 CHECK (digits IN (6, 8)), + algo TEXT NOT NULL DEFAULT 'SHA1' CHECK (algo IN ('SHA1', 'SHA256', 'SHA512')), + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + updated_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + UNIQUE(user_id, issuer, account) +); + +-- 基本索引 +CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id); +CREATE INDEX IF NOT EXISTS idx_tokens_lookup ON tokens(user_id, issuer, account); +CREATE INDEX IF NOT EXISTS idx_tokens_hotp ON tokens(user_id) WHERE type = 'HOTP'; +CREATE INDEX IF NOT EXISTS idx_tokens_totp ON tokens(user_id) WHERE type = 'TOTP'; + +-- 简化统计视图 +CREATE VIEW IF NOT EXISTS v_token_stats AS +SELECT + user_id, + COUNT(*) as total_tokens, + SUM(type = 'HOTP') as hotp_count, + SUM(type = 'TOTP') as totp_count +FROM tokens +GROUP BY user_id; + +-- 设置版本号 +PRAGMA user_version = 1; \ No newline at end of file diff --git a/logger/logger.go b/logger/logger.go deleted file mode 100644 index d7d76e4..0000000 --- a/logger/logger.go +++ /dev/null @@ -1,204 +0,0 @@ -package logger - -import ( - "context" - "fmt" - "io" - "os" - "runtime" - "strings" - "time" - - "github.com/google/uuid" -) - -// Level represents a log level -type Level int - -const ( - // DEBUG level - DEBUG Level = iota - // INFO level - INFO - // WARN level - WARN - // ERROR level - ERROR - // FATAL level - FATAL -) - -// String returns the string representation of the log level -func (l Level) String() string { - switch l { - case DEBUG: - return "DEBUG" - case INFO: - return "INFO" - case WARN: - return "WARN" - case ERROR: - return "ERROR" - case FATAL: - return "FATAL" - default: - return "UNKNOWN" - } -} - -// Logger represents a logger -type Logger struct { - level Level - output io.Writer -} - -// contextKey is a type for context keys -type contextKey string - -// requestIDKey is the key for request ID in context -const requestIDKey = contextKey("request_id") - -// New creates a new logger -func New(level Level, output io.Writer) *Logger { - if output == nil { - output = os.Stdout - } - return &Logger{ - level: level, - output: output, - } -} - -// WithLevel creates a new logger with the specified level -func (l *Logger) WithLevel(level Level) *Logger { - return &Logger{ - level: level, - output: l.output, - } -} - -// WithOutput creates a new logger with the specified output -func (l *Logger) WithOutput(output io.Writer) *Logger { - return &Logger{ - level: l.level, - output: output, - } -} - -// log logs a message with the specified level -func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) { - if level < l.level { - return - } - - // Get request ID from context - requestID := getRequestID(ctx) - - // Get caller information - _, file, line, ok := runtime.Caller(2) - if !ok { - file = "unknown" - line = 0 - } - // Extract just the filename - if idx := strings.LastIndex(file, "/"); idx >= 0 { - file = file[idx+1:] - } - - // Format message - message := fmt.Sprintf(format, args...) - - // Format log entry - timestamp := time.Now().Format(time.RFC3339) - logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n", - timestamp, level.String(), file, line, requestID, message) - - // Write log entry - _, _ = l.output.Write([]byte(logEntry)) - - // Exit if fatal - if level == FATAL { - os.Exit(1) - } -} - -// Debug logs a debug message -func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) { - l.log(ctx, DEBUG, format, args...) -} - -// Info logs an info message -func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) { - l.log(ctx, INFO, format, args...) -} - -// Warn logs a warning message -func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) { - l.log(ctx, WARN, format, args...) -} - -// Error logs an error message -func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) { - l.log(ctx, ERROR, format, args...) -} - -// Fatal logs a fatal message and exits -func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) { - l.log(ctx, FATAL, format, args...) -} - -// WithRequestID adds a request ID to the context -func WithRequestID(ctx context.Context) context.Context { - requestID := uuid.New().String() - return context.WithValue(ctx, requestIDKey, requestID) -} - -// GetRequestID gets the request ID from the context -func GetRequestID(ctx context.Context) string { - return getRequestID(ctx) -} - -// getRequestID gets the request ID from the context -func getRequestID(ctx context.Context) string { - if ctx == nil { - return "-" - } - requestID, ok := ctx.Value(requestIDKey).(string) - if !ok { - return "-" - } - return requestID -} - -// Default logger -var defaultLogger = New(INFO, os.Stdout) - -// SetDefaultLogger sets the default logger -func SetDefaultLogger(logger *Logger) { - defaultLogger = logger -} - -// Debug logs a debug message using the default logger -func Debug(ctx context.Context, format string, args ...interface{}) { - defaultLogger.Debug(ctx, format, args...) -} - -// Info logs an info message using the default logger -func Info(ctx context.Context, format string, args ...interface{}) { - defaultLogger.Info(ctx, format, args...) -} - -// Warn logs a warning message using the default logger -func Warn(ctx context.Context, format string, args ...interface{}) { - defaultLogger.Warn(ctx, format, args...) -} - -// Error logs an error message using the default logger -func Error(ctx context.Context, format string, args ...interface{}) { - defaultLogger.Error(ctx, format, args...) -} - -// Fatal logs a fatal message and exits using the default logger -func Fatal(ctx context.Context, format string, args ...interface{}) { - defaultLogger.Fatal(ctx, format, args...) -} diff --git a/main.go b/main.go deleted file mode 100644 index 1b34398..0000000 --- a/main.go +++ /dev/null @@ -1,7 +0,0 @@ -package main - -import "otpm/cmd" - -func main() { - cmd.Execute() -} diff --git a/metrics/metrics.go b/metrics/metrics.go deleted file mode 100644 index 8e1d18a..0000000 --- a/metrics/metrics.go +++ /dev/null @@ -1,193 +0,0 @@ -package metrics - -import ( - "fmt" - "net/http" - "sync" - "time" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" -) - -var ( - // Default metrics - requestDuration = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Name: "http_request_duration_seconds", - Help: "Duration of HTTP requests in seconds", - Buckets: prometheus.DefBuckets, - }, - []string{"method", "path", "status"}, - ) - - requestTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "http_requests_total", - Help: "Total number of HTTP requests", - }, - []string{"method", "path", "status"}, - ) - - otpGenerationTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "otp_generation_total", - Help: "Total number of OTP generations", - }, - []string{"user_id", "otp_id"}, - ) - - otpVerificationTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "otp_verification_total", - Help: "Total number of OTP verifications", - }, - []string{"user_id", "otp_id", "success"}, - ) - - activeUsers = prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "active_users", - Help: "Number of active users", - }, - ) - - cacheHits = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "cache_hits_total", - Help: "Total number of cache hits", - }, - []string{"cache"}, - ) - - cacheMisses = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "cache_misses_total", - Help: "Total number of cache misses", - }, - []string{"cache"}, - ) -) - -func init() { - // Register metrics with prometheus - prometheus.MustRegister( - requestDuration, - requestTotal, - otpGenerationTotal, - otpVerificationTotal, - activeUsers, - cacheHits, - cacheMisses, - ) -} - -// MetricsService provides metrics functionality -type MetricsService struct { - activeUsersMutex sync.RWMutex - activeUserIDs map[string]bool -} - -// NewMetricsService creates a new MetricsService -func NewMetricsService() *MetricsService { - return &MetricsService{ - activeUserIDs: make(map[string]bool), - } -} - -// Handler returns an HTTP handler for metrics -func (s *MetricsService) Handler() http.Handler { - return promhttp.Handler() -} - -// RecordRequest records metrics for an HTTP request -func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) { - labels := prometheus.Labels{ - "method": method, - "path": path, - "status": fmt.Sprintf("%d", status), - } - - requestDuration.With(labels).Observe(duration.Seconds()) - requestTotal.With(labels).Inc() -} - -// RecordOTPGeneration records metrics for OTP generation -func (s *MetricsService) RecordOTPGeneration(userID, otpID string) { - otpGenerationTotal.With(prometheus.Labels{ - "user_id": userID, - "otp_id": otpID, - }).Inc() -} - -// RecordOTPVerification records metrics for OTP verification -func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) { - otpVerificationTotal.With(prometheus.Labels{ - "user_id": userID, - "otp_id": otpID, - "success": fmt.Sprintf("%t", success), - }).Inc() -} - -// RecordUserActivity records user activity -func (s *MetricsService) RecordUserActivity(userID string) { - s.activeUsersMutex.Lock() - defer s.activeUsersMutex.Unlock() - - if !s.activeUserIDs[userID] { - s.activeUserIDs[userID] = true - activeUsers.Inc() - } -} - -// RecordUserInactivity records user inactivity -func (s *MetricsService) RecordUserInactivity(userID string) { - s.activeUsersMutex.Lock() - defer s.activeUsersMutex.Unlock() - - if s.activeUserIDs[userID] { - delete(s.activeUserIDs, userID) - activeUsers.Dec() - } -} - -// RecordCacheHit records a cache hit -func (s *MetricsService) RecordCacheHit(cache string) { - cacheHits.With(prometheus.Labels{ - "cache": cache, - }).Inc() -} - -// RecordCacheMiss records a cache miss -func (s *MetricsService) RecordCacheMiss(cache string) { - cacheMisses.With(prometheus.Labels{ - "cache": cache, - }).Inc() -} - -// Middleware creates a middleware that records request metrics -func (s *MetricsService) Middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - - // Create response writer that captures status code - rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} - - // Call next handler - next.ServeHTTP(rw, r) - - // Record metrics - s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start)) - }) -} - -// responseWriter wraps http.ResponseWriter to capture status code -type responseWriter struct { - http.ResponseWriter - status int -} - -func (rw *responseWriter) WriteHeader(code int) { - rw.status = code - rw.ResponseWriter.WriteHeader(code) -} diff --git a/middleware/middleware.go b/middleware/middleware.go deleted file mode 100644 index 6afce13..0000000 --- a/middleware/middleware.go +++ /dev/null @@ -1,353 +0,0 @@ -package middleware - -import ( - "context" - "encoding/json" - "fmt" - "log" - "math/rand" - "net/http" - "runtime/debug" - "strings" - "time" - - "github.com/golang-jwt/jwt" -) - -// Response represents a standard API response -type Response struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` -} - -// ErrorResponse sends a JSON error response -func ErrorResponse(w http.ResponseWriter, code int, message string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - json.NewEncoder(w).Encode(Response{ - Code: code, - Message: message, - }) -} - -// SuccessResponse sends a JSON success response -func SuccessResponse(w http.ResponseWriter, data interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(Response{ - Code: http.StatusOK, - Message: "success", - Data: data, - }) -} - -// Logger is a middleware that logs request details with structured format -func Logger(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - requestID := r.Header.Get("X-Request-ID") - if requestID == "" { - requestID = generateRequestID() - r.Header.Set("X-Request-ID", requestID) - } - - // Create a custom response writer to capture status code - rw := &responseWriter{ - ResponseWriter: w, - status: http.StatusOK, - } - - // Process request - next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID))) - - // Log structured request details - log.Printf( - "method=%s path=%s status=%d duration=%s ip=%s request_id=%s", - r.Method, - r.URL.Path, - rw.status, - time.Since(start).String(), - r.RemoteAddr, - requestID, - ) - }) -} - -// generateRequestID creates a unique request identifier -func generateRequestID() string { - return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8)) -} - -// randomString generates a random string of given length -func randomString(length int) string { - const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - b := make([]byte, length) - for i := range b { - b[i] = charset[rand.Intn(len(charset))] - } - return string(b) -} - -// Recover is a middleware that recovers from panics with detailed logging -func Recover(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - if err := recover(); err != nil { - // Get request ID from context - requestID := "" - if ctx := r.Context(); ctx != nil { - if id, ok := ctx.Value("request_id").(string); ok { - requestID = id - } - } - - // Log error with stack trace and request context - log.Printf( - "panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s", - err, - requestID, - r.Method, - r.URL.Path, - r.RemoteAddr, - debug.Stack(), - ) - - // Determine error type - var message string - var status int - - switch e := err.(type) { - case error: - message = e.Error() - if isClientError(e) { - status = http.StatusBadRequest - } else { - status = http.StatusInternalServerError - } - case string: - message = e - status = http.StatusInternalServerError - default: - message = "Internal Server Error" - status = http.StatusInternalServerError - } - - ErrorResponse(w, status, message) - } - }() - next.ServeHTTP(w, r) - }) -} - -// isClientError checks if error should be treated as client error -func isClientError(err error) bool { - // Add more client error types as needed - return strings.Contains(err.Error(), "validation") || - strings.Contains(err.Error(), "invalid") || - strings.Contains(err.Error(), "missing") -} - -// CORS is a middleware that handles CORS -func CORS(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - next.ServeHTTP(w, r) - }) -} - -// Timeout is a middleware that safely handles request timeouts -func Timeout(duration time.Duration) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), duration) - defer cancel() - - // Use buffered channels to prevent goroutine leaks - done := make(chan struct{}, 1) - panicChan := make(chan interface{}, 1) - - // Track request processing in goroutine - go func() { - defer func() { - if p := recover(); p != nil { - panicChan <- p - } - }() - next.ServeHTTP(w, r.WithContext(ctx)) - done <- struct{}{} - }() - - // Wait for completion, timeout or panic - select { - case <-done: - return - case p := <-panicChan: - panic(p) // Re-throw panic to be caught by Recover middleware - case <-ctx.Done(): - // Get request context for logging - requestID := "" - if ctx := r.Context(); ctx != nil { - if id, ok := ctx.Value("request_id").(string); ok { - requestID = id - } - } - - // Log timeout details - log.Printf( - "request_timeout: request_id=%s method=%s path=%s timeout=%s", - requestID, - r.Method, - r.URL.Path, - duration.String(), - ) - - // Send timeout response - ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf( - "Request timed out after %s", duration.String(), - )) - } - }) - } -} - -// Auth is a middleware that validates JWT tokens with enhanced security -func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Get request ID for logging - requestID := "" - if ctx := r.Context(); ctx != nil { - if id, ok := ctx.Value("request_id").(string); ok { - requestID = id - } - } - - // Get token from Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID) - ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required") - return - } - - // Validate header format - parts := strings.Split(authHeader, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID) - ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer '") - return - } - - tokenString := parts[1] - - // Parse and validate token - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - // Validate signing method - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return []byte(jwtSecret), nil - }) - - if err != nil { - log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err) - ErrorResponse(w, http.StatusUnauthorized, "Invalid token") - return - } - - if !token.Valid { - log.Printf("auth_failed: request_id=%s error=invalid_token", requestID) - ErrorResponse(w, http.StatusUnauthorized, "Invalid token") - return - } - - // Validate claims - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID) - ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims") - return - } - - // Check required claims - userID, ok := claims["user_id"].(string) - if !ok || userID == "" { - log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID) - ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token") - return - } - - // Check token expiration - if exp, ok := claims["exp"].(float64); ok { - if time.Now().Unix() > int64(exp) { - log.Printf("auth_failed: request_id=%s error=token_expired", requestID) - ErrorResponse(w, http.StatusUnauthorized, "Token has expired") - return - } - } - - // Check required roles if specified - if len(requiredRoles) > 0 { - roles, ok := claims["roles"].([]interface{}) - if !ok { - log.Printf("auth_failed: request_id=%s error=missing_roles", requestID) - ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles") - return - } - - hasRequiredRole := false - for _, requiredRole := range requiredRoles { - for _, role := range roles { - if r, ok := role.(string); ok && r == requiredRole { - hasRequiredRole = true - break - } - } - } - - if !hasRequiredRole { - log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID) - ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions") - return - } - } - - // Add claims to context - ctx := r.Context() - ctx = context.WithValue(ctx, "user_id", userID) - ctx = context.WithValue(ctx, "claims", claims) - - log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - -// responseWriter is a custom response writer that captures the status code -type responseWriter struct { - http.ResponseWriter - status int -} - -func (rw *responseWriter) WriteHeader(code int) { - rw.status = code - rw.ResponseWriter.WriteHeader(code) -} - -// GetUserID gets the user ID from the request context -func GetUserID(r *http.Request) (string, error) { - userID, ok := r.Context().Value("user_id").(string) - if !ok { - return "", fmt.Errorf("user ID not found in context") - } - return userID, nil -} diff --git a/miniprogram-example/app.js b/miniprogram-example/app.js deleted file mode 100644 index 34eba08..0000000 --- a/miniprogram-example/app.js +++ /dev/null @@ -1,50 +0,0 @@ -// app.js -App({ - onLaunch() { - // 检查更新 - if (wx.canIUse('getUpdateManager')) { - const updateManager = wx.getUpdateManager(); - updateManager.onCheckForUpdate(function (res) { - if (res.hasUpdate) { - updateManager.onUpdateReady(function () { - wx.showModal({ - title: '更新提示', - content: '新版本已经准备好,是否重启应用?', - success: function (res) { - if (res.confirm) { - updateManager.applyUpdate(); - } - } - }); - }); - - updateManager.onUpdateFailed(function () { - wx.showModal({ - title: '更新提示', - content: '新版本下载失败,请检查网络后重试', - showCancel: false - }); - }); - } - }); - } - - // 获取系统信息 - try { - const systemInfo = wx.getSystemInfoSync(); - this.globalData.systemInfo = systemInfo; - - // 计算安全区域 - const { screenHeight, safeArea } = systemInfo; - this.globalData.safeAreaBottom = screenHeight - safeArea.bottom; - } catch (e) { - console.error('获取系统信息失败', e); - } - }, - - globalData: { - userInfo: null, - systemInfo: {}, - safeAreaBottom: 0 - } -}); \ No newline at end of file diff --git a/miniprogram-example/app.json b/miniprogram-example/app.json deleted file mode 100644 index 89f79ae..0000000 --- a/miniprogram-example/app.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "pages": [ - "pages/login/login", - "pages/otp-list/index", - "pages/otp-add/index" - ], - "window": { - "backgroundTextStyle": "light", - "navigationBarBackgroundColor": "#fff", - "navigationBarTitleText": "OTPM", - "navigationBarTextStyle": "black", - "backgroundColor": "#F8F8F8" - }, - "permission": { - "scope.camera": { - "desc": "需要使用相机扫描二维码" - } - }, - "usingComponents": {}, - "style": "v2", - "sitemapLocation": "sitemap.json", - "lazyCodeLoading": "requiredComponents" -} \ No newline at end of file diff --git a/miniprogram-example/app.wxss b/miniprogram-example/app.wxss deleted file mode 100644 index fb73b4a..0000000 --- a/miniprogram-example/app.wxss +++ /dev/null @@ -1,238 +0,0 @@ -/**app.wxss**/ -page { - --primary-color: #1890ff; - --danger-color: #ff4d4f; - --success-color: #52c41a; - --warning-color: #faad14; - --text-color: #333333; - --text-color-secondary: #666666; - --text-color-light: #999999; - --border-color: #e8e8e8; - --background-color: #f8f8f8; - --border-radius: 8rpx; - --safe-area-bottom: env(safe-area-inset-bottom); - - font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica, - Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei', - sans-serif; - font-size: 28rpx; - line-height: 1.5; - color: var(--text-color); - background-color: var(--background-color); -} - -/* 清除默认样式 */ -button { - padding: 0; - margin: 0; - background: none; - border: none; - text-align: left; - line-height: inherit; - overflow: visible; -} - -button::after { - border: none; -} - -/* 通用样式类 */ -.container { - min-height: 100vh; - box-sizing: border-box; -} - -.safe-area-bottom { - padding-bottom: var(--safe-area-bottom); -} - -.flex-center { - display: flex; - align-items: center; - justify-content: center; -} - -.flex-between { - display: flex; - align-items: center; - justify-content: space-between; -} - -.flex-column { - display: flex; - flex-direction: column; -} - -.text-primary { - color: var(--primary-color); -} - -.text-danger { - color: var(--danger-color); -} - -.text-success { - color: var(--success-color); -} - -.text-warning { - color: var(--warning-color); -} - -.text-secondary { - color: var(--text-color-secondary); -} - -.text-light { - color: var(--text-color-light); -} - -.text-center { - text-align: center; -} - -.text-left { - text-align: left; -} - -.text-right { - text-align: right; -} - -.text-bold { - font-weight: bold; -} - -.text-ellipsis { - overflow: hidden; - text-overflow: ellipsis; - white-space: nowrap; -} - -.bg-white { - background-color: #ffffff; -} - -.bg-primary { - background-color: var(--primary-color); -} - -.bg-danger { - background-color: var(--danger-color); -} - -.bg-success { - background-color: var(--success-color); -} - -.bg-warning { - background-color: var(--warning-color); -} - -.shadow { - box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); -} - -.rounded { - border-radius: var(--border-radius); -} - -.border { - border: 2rpx solid var(--border-color); -} - -.border-top { - border-top: 2rpx solid var(--border-color); -} - -.border-bottom { - border-bottom: 2rpx solid var(--border-color); -} - -/* 动画类 */ -.fade-in { - animation: fadeIn 0.3s ease-in-out; -} - -.fade-out { - animation: fadeOut 0.3s ease-in-out; -} - -@keyframes fadeIn { - from { - opacity: 0; - } - to { - opacity: 1; - } -} - -@keyframes fadeOut { - from { - opacity: 1; - } - to { - opacity: 0; - } -} - -/* 间距类 */ -.m-0 { margin: 0; } -.m-1 { margin: 10rpx; } -.m-2 { margin: 20rpx; } -.m-3 { margin: 30rpx; } -.m-4 { margin: 40rpx; } - -.mt-0 { margin-top: 0; } -.mt-1 { margin-top: 10rpx; } -.mt-2 { margin-top: 20rpx; } -.mt-3 { margin-top: 30rpx; } -.mt-4 { margin-top: 40rpx; } - -.mb-0 { margin-bottom: 0; } -.mb-1 { margin-bottom: 10rpx; } -.mb-2 { margin-bottom: 20rpx; } -.mb-3 { margin-bottom: 30rpx; } -.mb-4 { margin-bottom: 40rpx; } - -.ml-0 { margin-left: 0; } -.ml-1 { margin-left: 10rpx; } -.ml-2 { margin-left: 20rpx; } -.ml-3 { margin-left: 30rpx; } -.ml-4 { margin-left: 40rpx; } - -.mr-0 { margin-right: 0; } -.mr-1 { margin-right: 10rpx; } -.mr-2 { margin-right: 20rpx; } -.mr-3 { margin-right: 30rpx; } -.mr-4 { margin-right: 40rpx; } - -.p-0 { padding: 0; } -.p-1 { padding: 10rpx; } -.p-2 { padding: 20rpx; } -.p-3 { padding: 30rpx; } -.p-4 { padding: 40rpx; } - -.pt-0 { padding-top: 0; } -.pt-1 { padding-top: 10rpx; } -.pt-2 { padding-top: 20rpx; } -.pt-3 { padding-top: 30rpx; } -.pt-4 { padding-top: 40rpx; } - -.pb-0 { padding-bottom: 0; } -.pb-1 { padding-bottom: 10rpx; } -.pb-2 { padding-bottom: 20rpx; } -.pb-3 { padding-bottom: 30rpx; } -.pb-4 { padding-bottom: 40rpx; } - -.pl-0 { padding-left: 0; } -.pl-1 { padding-left: 10rpx; } -.pl-2 { padding-left: 20rpx; } -.pl-3 { padding-left: 30rpx; } -.pl-4 { padding-left: 40rpx; } - -.pr-0 { padding-right: 0; } -.pr-1 { padding-right: 10rpx; } -.pr-2 { padding-right: 20rpx; } -.pr-3 { padding-right: 30rpx; } -.pr-4 { padding-right: 40rpx; } \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.js b/miniprogram-example/pages/login/login.js deleted file mode 100644 index e4ad173..0000000 --- a/miniprogram-example/pages/login/login.js +++ /dev/null @@ -1,48 +0,0 @@ -// login.js -import { wxLogin } from '../../services/auth'; - -Page({ - data: { - loading: false - }, - - onLoad() { - // 页面加载时检查是否已经登录 - const token = wx.getStorageSync('token'); - if (token) { - this.redirectToHome(); - } - }, - - // 处理登录按钮点击 - handleLogin() { - if (this.data.loading) return; - - this.setData({ loading: true }); - - wxLogin() - .then(() => { - wx.showToast({ - title: '登录成功', - icon: 'success' - }); - this.redirectToHome(); - }) - .catch(err => { - wx.showToast({ - title: err.message || '登录失败', - icon: 'none' - }); - }) - .finally(() => { - this.setData({ loading: false }); - }); - }, - - // 跳转到首页 - redirectToHome() { - wx.reLaunch({ - url: '/pages/otp-list/index' - }); - } -}); \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.json b/miniprogram-example/pages/login/login.json deleted file mode 100644 index 8835af0..0000000 --- a/miniprogram-example/pages/login/login.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "usingComponents": {} -} \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.wxml b/miniprogram-example/pages/login/login.wxml deleted file mode 100644 index d73a711..0000000 --- a/miniprogram-example/pages/login/login.wxml +++ /dev/null @@ -1,30 +0,0 @@ - - - - - OTPM 小程序 - - - - \ No newline at end of file diff --git a/miniprogram-example/pages/login/login.wxss b/miniprogram-example/pages/login/login.wxss deleted file mode 100644 index 26cece6..0000000 --- a/miniprogram-example/pages/login/login.wxss +++ /dev/null @@ -1,97 +0,0 @@ -/* login.wxss */ -.container { - display: flex; - flex-direction: column; - align-items: center; - justify-content: space-between; - height: 100vh; - padding: 60rpx 40rpx; - box-sizing: border-box; - background-color: #f8f8f8; -} - -.logo-container { - display: flex; - flex-direction: column; - align-items: center; - margin-top: 80rpx; -} - -.logo { - width: 180rpx; - height: 180rpx; - margin-bottom: 20rpx; -} - -.app-name { - font-size: 36rpx; - font-weight: bold; - color: #333; -} - -.login-container { - width: 100%; - display: flex; - flex-direction: column; - align-items: center; - margin-bottom: 100rpx; -} - -.login-title { - font-size: 48rpx; - font-weight: bold; - color: #333; - margin-bottom: 20rpx; -} - -.login-subtitle { - font-size: 28rpx; - color: #666; - margin-bottom: 80rpx; -} - -.login-button { - width: 80%; - height: 88rpx; - border-radius: 44rpx; - font-size: 32rpx; - display: flex; - align-items: center; - justify-content: center; -} - -.login-button.loading { - background-color: #8cc4ff; -} - -.loading-container { - display: flex; - align-items: center; - justify-content: center; -} - -.loading-icon { - width: 36rpx; - height: 36rpx; - margin-right: 10rpx; - border: 4rpx solid #ffffff; - border-radius: 50%; - border-top-color: transparent; - animation: spin 1s linear infinite; -} - -@keyframes spin { - 0% { transform: rotate(0deg); } - 100% { transform: rotate(360deg); } -} - -.privacy-policy { - margin-top: 40rpx; - font-size: 24rpx; - color: #999; -} - -.policy-link { - color: #1890ff; - display: inline; -} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.js b/miniprogram-example/pages/otp-add/index.js deleted file mode 100644 index e89c623..0000000 --- a/miniprogram-example/pages/otp-add/index.js +++ /dev/null @@ -1,169 +0,0 @@ -// otp-add/index.js -import { createOTP } from '../../services/otp'; - -Page({ - data: { - form: { - name: '', - issuer: '', - secret: '', - algorithm: 'SHA1', - digits: 6, - period: 30 - }, - algorithms: ['SHA1', 'SHA256', 'SHA512'], - digitOptions: [6, 8], - periodOptions: [30, 60], - submitting: false, - scanMode: false - }, - - // 处理输入变化 - handleInputChange(e) { - const { field } = e.currentTarget.dataset; - const { value } = e.detail; - - this.setData({ - [`form.${field}`]: value - }); - }, - - // 处理选择器变化 - handlePickerChange(e) { - const { field } = e.currentTarget.dataset; - const { value } = e.detail; - - const options = this.data[`${field}Options`] || this.data[field]; - const selectedValue = options[value]; - - this.setData({ - [`form.${field}`]: selectedValue - }); - }, - - // 扫描二维码 - handleScanQRCode() { - this.setData({ scanMode: true }); - - wx.scanCode({ - scanType: ['qrCode'], - success: (res) => { - try { - // 解析otpauth://协议的URL - const url = res.result; - if (url.startsWith('otpauth://totp/')) { - const parsedUrl = new URL(url); - const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠 - - // 解析路径中的issuer和name - let issuer = ''; - let name = path; - - if (path.includes(':')) { - const parts = path.split(':'); - issuer = parts[0]; - name = parts[1]; - } - - // 从查询参数中获取其他信息 - const secret = parsedUrl.searchParams.get('secret') || ''; - const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1'; - const digits = parseInt(parsedUrl.searchParams.get('digits') || '6'); - const period = parseInt(parsedUrl.searchParams.get('period') || '30'); - - // 如果查询参数中有issuer,优先使用 - if (parsedUrl.searchParams.get('issuer')) { - issuer = parsedUrl.searchParams.get('issuer'); - } - - this.setData({ - form: { - name, - issuer, - secret, - algorithm, - digits, - period - } - }); - - wx.showToast({ - title: '二维码解析成功', - icon: 'success' - }); - } else { - wx.showToast({ - title: '不支持的二维码格式', - icon: 'none' - }); - } - } catch (err) { - wx.showToast({ - title: '二维码解析失败', - icon: 'none' - }); - } - }, - fail: () => { - wx.showToast({ - title: '扫描取消', - icon: 'none' - }); - }, - complete: () => { - this.setData({ scanMode: false }); - } - }); - }, - - // 提交表单 - handleSubmit() { - const { form } = this.data; - - // 表单验证 - if (!form.name) { - wx.showToast({ - title: '请输入名称', - icon: 'none' - }); - return; - } - - if (!form.secret) { - wx.showToast({ - title: '请输入密钥', - icon: 'none' - }); - return; - } - - this.setData({ submitting: true }); - - createOTP(form) - .then(() => { - wx.showToast({ - title: '添加成功', - icon: 'success' - }); - - // 返回上一页 - setTimeout(() => { - wx.navigateBack(); - }, 1500); - }) - .catch(err => { - wx.showToast({ - title: err.message || '添加失败', - icon: 'none' - }); - }) - .finally(() => { - this.setData({ submitting: false }); - }); - }, - - // 取消 - handleCancel() { - wx.navigateBack(); - } -}); \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.json b/miniprogram-example/pages/otp-add/index.json deleted file mode 100644 index 8835af0..0000000 --- a/miniprogram-example/pages/otp-add/index.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "usingComponents": {} -} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.wxml b/miniprogram-example/pages/otp-add/index.wxml deleted file mode 100644 index e5a2b01..0000000 --- a/miniprogram-example/pages/otp-add/index.wxml +++ /dev/null @@ -1,119 +0,0 @@ - - - - 添加OTP - - - - - 名称 * - - - - - 发行方 - - - - - 密钥 * - - - - 🔍 - - - - - - - - - 算法 - - - {{form.algorithm}} - - - - - - - - 位数 - - - {{form.digits}} - - - - - - - 周期(秒) - - - {{form.period}} - - - - - - - - - - - - - \ No newline at end of file diff --git a/miniprogram-example/pages/otp-add/index.wxss b/miniprogram-example/pages/otp-add/index.wxss deleted file mode 100644 index 3906a97..0000000 --- a/miniprogram-example/pages/otp-add/index.wxss +++ /dev/null @@ -1,176 +0,0 @@ -/* otp-add/index.wxss */ -.container { - min-height: 100vh; - background-color: #f8f8f8; - padding: 0 0 40rpx 0; -} - -.header { - display: flex; - justify-content: space-between; - align-items: center; - padding: 40rpx 32rpx; - background-color: #ffffff; - position: sticky; - top: 0; - z-index: 100; - box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05); -} - -.title { - font-size: 36rpx; - font-weight: bold; - color: #333333; -} - -.form-container { - background-color: #ffffff; - padding: 32rpx; - margin: 32rpx; - border-radius: 16rpx; - box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); -} - -.form-group { - margin-bottom: 32rpx; -} - -.form-row { - display: flex; - justify-content: space-between; -} - -.form-group.half { - width: 48%; -} - -.form-label { - display: block; - font-size: 28rpx; - color: #666666; - margin-bottom: 12rpx; -} - -.required { - color: #ff4d4f; -} - -.form-input { - width: 100%; - height: 80rpx; - background-color: #f5f5f5; - border-radius: 8rpx; - padding: 0 24rpx; - font-size: 28rpx; - color: #333333; - box-sizing: border-box; -} - -.secret-input-container { - position: relative; -} - -.scan-button { - position: absolute; - right: 20rpx; - top: 50%; - transform: translateY(-50%); - width: 60rpx; - height: 60rpx; - display: flex; - align-items: center; - justify-content: center; -} - -.scan-icon { - font-size: 40rpx; - color: #1890ff; -} - -.scanning-indicator { - position: absolute; - right: 20rpx; - top: 50%; - transform: translateY(-50%); - width: 40rpx; - height: 40rpx; -} - -.scanning-spinner { - width: 40rpx; - height: 40rpx; - border: 4rpx solid #f3f3f3; - border-top: 4rpx solid #1890ff; - border-radius: 50%; - animation: spin 1s linear infinite; -} - -@keyframes spin { - 0% { transform: rotate(0deg); } - 100% { transform: rotate(360deg); } -} - -.picker-view { - width: 100%; - height: 80rpx; - background-color: #f5f5f5; - border-radius: 8rpx; - padding: 0 24rpx; - font-size: 28rpx; - color: #333333; - display: flex; - align-items: center; - justify-content: space-between; - box-sizing: border-box; -} - -.picker-arrow { - font-size: 24rpx; - color: #999999; -} - -.button-group { - display: flex; - justify-content: space-between; - padding: 32rpx; -} - -.cancel-button, .submit-button { - width: 48%; - height: 88rpx; - border-radius: 44rpx; - font-size: 32rpx; - display: flex; - align-items: center; - justify-content: center; -} - -.cancel-button { - background-color: #f5f5f5; - color: #666666; -} - -.submit-button { - background-color: #1890ff; - color: #ffffff; -} - -.submit-button.loading { - background-color: #8cc4ff; -} - -.loading-container { - display: flex; - align-items: center; - justify-content: center; -} - -.loading-icon { - width: 36rpx; - height: 36rpx; - margin-right: 10rpx; - border: 4rpx solid #ffffff; - border-radius: 50%; - border-top-color: transparent; - animation: spin 1s linear infinite; -} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.js b/miniprogram-example/pages/otp-list/index.js deleted file mode 100644 index 9ed6fa1..0000000 --- a/miniprogram-example/pages/otp-list/index.js +++ /dev/null @@ -1,213 +0,0 @@ -// otp-list/index.js -import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp'; -import { checkLoginStatus } from '../../services/auth'; - -Page({ - data: { - otpList: [], - loading: true, - refreshing: false - }, - - onLoad() { - this.checkLogin(); - }, - - onShow() { - // 每次页面显示时刷新OTP列表 - if (!this.data.loading) { - this.fetchOTPList(); - } - }, - - // 下拉刷新 - onPullDownRefresh() { - this.setData({ refreshing: true }); - this.fetchOTPList().finally(() => { - wx.stopPullDownRefresh(); - this.setData({ refreshing: false }); - }); - }, - - // 检查登录状态 - checkLogin() { - checkLoginStatus().then(isLoggedIn => { - if (isLoggedIn) { - this.fetchOTPList(); - } else { - wx.redirectTo({ - url: '/pages/login/login' - }); - } - }); - }, - - // 获取OTP列表 - fetchOTPList() { - this.setData({ loading: true }); - return getOTPList() - .then(res => { - if (res.data && Array.isArray(res.data)) { - this.setData({ - otpList: res.data, - loading: false - }); - - // 获取每个OTP的当前验证码 - this.refreshOTPCodes(); - } - }) - .catch(err => { - wx.showToast({ - title: '获取OTP列表失败', - icon: 'none' - }); - this.setData({ loading: false }); - }); - }, - - // 刷新所有OTP的验证码 - refreshOTPCodes() { - const { otpList } = this.data; - - // 为每个OTP获取当前验证码 - const promises = otpList.map(otp => { - return getOTPCode(otp.id) - .then(res => { - if (res.data && res.data.code) { - return { - id: otp.id, - code: res.data.code, - expiresIn: res.data.expires_in || 30 - }; - } - return null; - }) - .catch(() => null); - }); - - Promise.all(promises).then(results => { - const updatedList = [...this.data.otpList]; - - results.forEach(result => { - if (result) { - const index = updatedList.findIndex(otp => otp.id === result.id); - if (index !== -1) { - updatedList[index] = { - ...updatedList[index], - currentCode: result.code, - expiresIn: result.expiresIn - }; - } - } - }); - - this.setData({ otpList: updatedList }); - - // 设置定时器,每秒更新倒计时 - this.startCountdown(); - }); - }, - - // 开始倒计时 - startCountdown() { - // 清除之前的定时器 - if (this.countdownTimer) { - clearInterval(this.countdownTimer); - } - - // 创建新的定时器,每秒更新一次 - this.countdownTimer = setInterval(() => { - const { otpList } = this.data; - let needRefresh = false; - - const updatedList = otpList.map(otp => { - if (!otp.countdown) { - otp.countdown = otp.expiresIn || 30; - } - - otp.countdown -= 1; - - // 如果倒计时结束,标记需要刷新 - if (otp.countdown <= 0) { - needRefresh = true; - } - - return otp; - }); - - this.setData({ otpList: updatedList }); - - // 如果有OTP需要刷新,重新获取验证码 - if (needRefresh) { - this.refreshOTPCodes(); - } - }, 1000); - }, - - // 添加新的OTP - handleAddOTP() { - wx.navigateTo({ - url: '/pages/otp-add/index' - }); - }, - - // 编辑OTP - handleEditOTP(e) { - const { id } = e.currentTarget.dataset; - wx.navigateTo({ - url: `/pages/otp-edit/index?id=${id}` - }); - }, - - // 删除OTP - handleDeleteOTP(e) { - const { id, name } = e.currentTarget.dataset; - - wx.showModal({ - title: '确认删除', - content: `确定要删除 ${name} 吗?`, - confirmColor: '#ff4d4f', - success: (res) => { - if (res.confirm) { - deleteOTP(id) - .then(() => { - wx.showToast({ - title: '删除成功', - icon: 'success' - }); - this.fetchOTPList(); - }) - .catch(err => { - wx.showToast({ - title: '删除失败', - icon: 'none' - }); - }); - } - } - }); - }, - - // 复制验证码 - handleCopyCode(e) { - const { code } = e.currentTarget.dataset; - - wx.setClipboardData({ - data: code, - success: () => { - wx.showToast({ - title: '验证码已复制', - icon: 'success' - }); - } - }); - }, - - onUnload() { - // 页面卸载时清除定时器 - if (this.countdownTimer) { - clearInterval(this.countdownTimer); - } - } -}); \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.json b/miniprogram-example/pages/otp-list/index.json deleted file mode 100644 index 8835af0..0000000 --- a/miniprogram-example/pages/otp-list/index.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "usingComponents": {} -} \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.wxml b/miniprogram-example/pages/otp-list/index.wxml deleted file mode 100644 index 27cb5ea..0000000 --- a/miniprogram-example/pages/otp-list/index.wxml +++ /dev/null @@ -1,59 +0,0 @@ - - - - 我的OTP列表 - - + - - - - - - - 加载中... - - - - - - - - - {{item.name}} - {{item.issuer}} - - - - {{item.currentCode || '******'}} - 点击复制 - - - - - {{item.countdown || 0}}s - - - - - - - - - - - - - - - - - - 暂无OTP,点击右上角添加 - - - \ No newline at end of file diff --git a/miniprogram-example/pages/otp-list/index.wxss b/miniprogram-example/pages/otp-list/index.wxss deleted file mode 100644 index 40436d5..0000000 --- a/miniprogram-example/pages/otp-list/index.wxss +++ /dev/null @@ -1,201 +0,0 @@ -/* otp-list/index.wxss */ -.container { - min-height: 100vh; - background-color: #f8f8f8; - padding: 0 0 40rpx 0; -} - -.header { - display: flex; - justify-content: space-between; - align-items: center; - padding: 40rpx 32rpx; - background-color: #ffffff; - position: sticky; - top: 0; - z-index: 100; - box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05); -} - -.title { - font-size: 36rpx; - font-weight: bold; - color: #333333; -} - -.add-button { - width: 64rpx; - height: 64rpx; - border-radius: 32rpx; - background-color: #1890ff; - display: flex; - align-items: center; - justify-content: center; - box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3); -} - -.add-icon { - color: #ffffff; - font-size: 40rpx; - line-height: 1; -} - -/* 加载状态 */ -.loading-container { - display: flex; - flex-direction: column; - align-items: center; - justify-content: center; - padding: 120rpx 0; -} - -.loading-spinner { - width: 64rpx; - height: 64rpx; - border: 6rpx solid #f3f3f3; - border-top: 6rpx solid #1890ff; - border-radius: 50%; - animation: spin 1s linear infinite; -} - -.loading-text { - margin-top: 20rpx; - font-size: 28rpx; - color: #999999; -} - -@keyframes spin { - 0% { transform: rotate(0deg); } - 100% { transform: rotate(360deg); } -} - -/* OTP列表 */ -.otp-list { - padding: 20rpx 32rpx; -} - -.otp-item { - background-color: #ffffff; - border-radius: 16rpx; - padding: 32rpx; - margin-bottom: 20rpx; - display: flex; - justify-content: space-between; - box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05); -} - -.otp-info { - flex: 1; - margin-right: 20rpx; -} - -.otp-name-row { - display: flex; - align-items: center; - margin-bottom: 16rpx; -} - -.otp-name { - font-size: 32rpx; - font-weight: bold; - color: #333333; - margin-right: 16rpx; -} - -.otp-issuer { - font-size: 24rpx; - color: #666666; - background-color: #f5f5f5; - padding: 4rpx 12rpx; - border-radius: 8rpx; -} - -.otp-code-row { - display: flex; - align-items: center; - margin-bottom: 20rpx; -} - -.otp-code { - font-size: 44rpx; - font-family: monospace; - font-weight: bold; - color: #1890ff; - letter-spacing: 4rpx; - margin-right: 16rpx; -} - -.copy-hint { - font-size: 24rpx; - color: #999999; -} - -.otp-countdown { - position: relative; - width: 100%; -} - -.countdown-text { - position: absolute; - right: 0; - top: -30rpx; - font-size: 24rpx; - color: #999999; -} - -/* OTP操作按钮 */ -.otp-actions { - display: flex; - flex-direction: column; - justify-content: space-between; -} - -.action-button { - width: 56rpx; - height: 56rpx; - border-radius: 28rpx; - display: flex; - align-items: center; - justify-content: center; - margin-bottom: 16rpx; -} - -.action-button.edit { - background-color: #f0f7ff; -} - -.action-button.delete { - background-color: #fff1f0; -} - -.action-icon { - font-size: 32rpx; -} - -.edit .action-icon { - color: #1890ff; -} - -.delete .action-icon { - color: #ff4d4f; -} - -/* 空状态 */ -.empty-state { - display: flex; - flex-direction: column; - align-items: center; - justify-content: center; - padding: 120rpx 0; -} - -.empty-image { - width: 240rpx; - height: 240rpx; - margin-bottom: 32rpx; -} - -.empty-text { - font-size: 28rpx; - color: #999999; -} \ No newline at end of file diff --git a/miniprogram-example/project.config.json b/miniprogram-example/project.config.json deleted file mode 100644 index 3fb79ad..0000000 --- a/miniprogram-example/project.config.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "description": "项目配置文件", - "packOptions": { - "ignore": [], - "include": [] - }, - "miniprogramRoot": "", - "compileType": "miniprogram", - "projectname": "OTPM", - "setting": { - "useCompilerPlugins": [ - "sass" - ], - "babelSetting": { - "ignore": [], - "disablePlugins": [], - "outputPath": "" - }, - "es6": true, - "enhance": true, - "minified": true, - "postcss": true, - "minifyWXSS": true, - "minifyWXML": true, - "uglifyFileName": true, - "packNpmManually": false, - "packNpmRelationList": [], - "ignoreUploadUnusedFiles": true, - "compileWorklet": false, - "uploadWithSourceMap": true, - "localPlugins": false, - "disableUseStrict": false, - "condition": false, - "swc": false, - "disableSWC": true - }, - "simulatorType": "wechat", - "simulatorPluginLibVersion": {}, - "condition": {}, - "srcMiniprogramRoot": "", - "appid": "wxb6599459668b6b55", - "libVersion": "2.30.2", - "editorSetting": { - "tabIndent": "insertSpaces", - "tabSize": 2 - } -} \ No newline at end of file diff --git a/miniprogram-example/project.private.config.json b/miniprogram-example/project.private.config.json deleted file mode 100644 index 6b2a738..0000000 --- a/miniprogram-example/project.private.config.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "libVersion": "3.8.5", - "projectname": "OTPM", - "condition": {}, - "setting": { - "urlCheck": true, - "coverView": false, - "lazyloadPlaceholderEnable": false, - "skylineRenderEnable": false, - "preloadBackgroundData": false, - "autoAudits": false, - "useApiHook": true, - "useApiHostProcess": true, - "showShadowRootInWxmlPanel": false, - "useStaticServer": false, - "useLanDebug": false, - "showES6CompileOption": false, - "compileHotReLoad": true, - "checkInvalidKey": true, - "ignoreDevUnusedFiles": true, - "bigPackageSizeSupport": false - } -} \ No newline at end of file diff --git a/miniprogram-example/services/auth.js b/miniprogram-example/services/auth.js deleted file mode 100644 index 4c95fa4..0000000 --- a/miniprogram-example/services/auth.js +++ /dev/null @@ -1,84 +0,0 @@ -// auth.js - 认证相关服务 - -import request from '../utils/request'; - -/** - * 微信登录 - * 1. 调用wx.login获取code - * 2. 发送code到服务端换取token和openid - * 3. 保存token和openid到本地存储 - */ -export const wxLogin = () => { - return new Promise((resolve, reject) => { - wx.login({ - success: (res) => { - if (res.code) { - // 发送code到服务端 - request({ - url: '/login', - method: 'POST', - data: { - code: res.code - } - }).then(response => { - // 保存token和openid - if (response.data && response.data.token && response.data.openid) { - wx.setStorageSync('token', response.data.token); - wx.setStorageSync('openid', response.data.openid); - resolve(response.data); - } else { - reject(new Error('登录失败,服务器返回数据格式错误')); - } - }).catch(err => { - reject(err); - }); - } else { - reject(new Error('登录失败,获取code失败: ' + res.errMsg)); - } - }, - fail: (err) => { - reject(new Error('微信登录失败: ' + err.errMsg)); - } - }); - }); -}; - -/** - * 检查登录状态 - * 1. 检查本地是否有token和openid - * 2. 如果有,验证token是否有效 - * 3. 如果无效,清除本地存储并返回false - */ -export const checkLoginStatus = () => { - return new Promise((resolve, reject) => { - const token = wx.getStorageSync('token'); - const openid = wx.getStorageSync('openid'); - - if (!token || !openid) { - resolve(false); - return; - } - - // 验证token有效性 - request({ - url: '/verify-token', - method: 'POST' - }).then(() => { - resolve(true); - }).catch(() => { - // token无效,清除本地存储 - wx.removeStorageSync('token'); - wx.removeStorageSync('openid'); - resolve(false); - }); - }); -}; - -/** - * 退出登录 - */ -export const logout = () => { - wx.removeStorageSync('token'); - wx.removeStorageSync('openid'); - return Promise.resolve(); -}; \ No newline at end of file diff --git a/miniprogram-example/services/otp.js b/miniprogram-example/services/otp.js deleted file mode 100644 index 9da9925..0000000 --- a/miniprogram-example/services/otp.js +++ /dev/null @@ -1,119 +0,0 @@ -// otp.js - OTP相关服务 - -import request from '../utils/request'; - -/** - * 创建新的OTP - * @param {Object} params - 创建OTP的参数 - * @param {string} params.name - OTP名称 - * @param {string} params.issuer - 发行方 - * @param {string} params.secret - 密钥 - * @param {string} params.algorithm - 算法,默认为SHA1 - * @param {number} params.digits - 位数,默认为6 - * @param {number} params.period - 周期,默认为30秒 - * @returns {Promise} - 返回创建结果 - */ -export const createOTP = (params) => { - if (!params || !params.secret) { - return Promise.reject(new Error('缺少必要的参数: secret')); - } - return request({ - url: '/otp', - method: 'POST', - data: { - name: params.name || '', - issuer: params.issuer || '', - secret: params.secret, - algorithm: params.algorithm || 'SHA1', - digits: params.digits || 6, - period: params.period || 30 - } - }).catch(err => { - console.error('创建OTP失败:', err); - throw new Error('创建OTP失败: ' + (err.message || '未知错误')); - }); -}; - -/** - * 获取用户所有OTP列表 - * @returns {Promise} - 返回OTP列表 - */ -export const getOTPList = () => { - return request({ - url: '/otp', - method: 'GET' - }).catch(err => { - console.error('获取OTP列表失败:', err); - throw new Error('获取OTP列表失败: ' + (err.message || '未知错误')); - }); -}; - -/** - * 获取指定OTP的当前验证码 - * @param {string} id - OTP的ID - * @returns {Promise} - 返回当前验证码 - */ -export const getOTPCode = (id) => { - if (!id) { - return Promise.reject(new Error('缺少必要的参数: id')); - } - return request({ - url: `/otp/${id}/code`, - method: 'GET' - }).catch(err => { - console.error('获取OTP代码失败:', err); - throw new Error('获取OTP代码失败: ' + (err.message || '未知错误')); - }); -}; - -/** - * 验证OTP - * @param {string} id - OTP的ID - * @param {string} code - 用户输入的验证码 - * @returns {Promise} - 返回验证结果 - */ -export const verifyOTP = (id, code) => { - if (!id || !code) { - return Promise.reject(new Error('缺少必要的参数: id或code')); - } - return request({ - url: `/otp/${id}/verify`, - method: 'POST', - data: { code } - }).catch(err => { - console.error('验证OTP失败:', err); - throw new Error('验证OTP失败: ' + (err.message || '未知错误')); - }); -}; - -/** - * 更新OTP信息 - * @param {string} id - OTP的ID - * @param {Object} params - 更新的参数 - * @returns {Promise} - 返回更新结果 - */ -export const updateOTP = (id, params) => { - if (!id || !params) { - return Promise.reject(new Error('缺少必要的参数: id或params')); - } - return request({ - url: `/otp/${id}`, - method: 'PUT', - data: params - }).catch(err => { - console.error('更新OTP失败:', err); - throw new Error('更新OTP失败: ' + (err.message || '未知错误')); - }); -}; - -/** - * 删除OTP - * @param {string} id - OTP的ID - * @returns {Promise} - 返回删除结果 - */ -export const deleteOTP = (id) => { - return request({ - url: `/otp/${id}`, - method: 'DELETE' - }); -}; \ No newline at end of file diff --git a/miniprogram-example/utils/request.js b/miniprogram-example/utils/request.js deleted file mode 100644 index 9ea7e2e..0000000 --- a/miniprogram-example/utils/request.js +++ /dev/null @@ -1,58 +0,0 @@ -// request.js - 网络请求工具类 - -const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名 - -// 请求拦截器 -const request = (options) => { - return new Promise((resolve, reject) => { - const token = wx.getStorageSync('token'); - const header = { - 'Content-Type': 'application/json', - ...options.header - }; - - // 如果有token,添加到请求头 - if (token) { - header['Authorization'] = `Bearer ${token}`; - } - - wx.request({ - url: `${BASE_URL}${options.url}`, - method: options.method || 'GET', - data: options.data, - header: header, - success: (res) => { - // 处理业务错误 - if (res.data.code !== 0) { - // token过期,直接清除并跳转登录 - if (res.statusCode === 401) { - wx.removeStorageSync('token'); - wx.removeStorageSync('openid'); - reject(new Error('登录已过期,请重新登录')); - return; - } - reject(new Error(res.data.message || '请求失败')); - return; - } - resolve(res.data); - }, - fail: reject - }); - }); -}; - -// 刷新token -const refreshToken = () => { - return request({ - url: '/refresh-token', - method: 'POST' - }).then(res => { - if (res.data && res.data.token) { - wx.setStorageSync('token', res.data.token); - return res.data.token; - } - throw new Error('Failed to refresh token'); - }); -}; - -export default request; \ No newline at end of file diff --git a/models/otp.go b/models/otp.go deleted file mode 100644 index 8eaab88..0000000 --- a/models/otp.go +++ /dev/null @@ -1,66 +0,0 @@ -package models - -import ( - "context" - "time" -) - -// OTP represents a TOTP configuration -type OTP struct { - ID int64 `json:"id" db:"id"` - UserID string `json:"user_id" db:"user_id" validate:"required"` - OpenID string `json:"openid" db:"openid" validate:"required"` - Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"` - Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"` - Secret string `json:"secret" db:"secret" validate:"required,otpsecret"` - Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"` - Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"` - Period int `json:"period" db:"period" validate:"required,min=30,max=60"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` -} - -// OTPParams represents common OTP parameters used in creation and update -type OTPParams struct { - Name string `json:"name" validate:"required,min=1,max=100,no_xss"` - Issuer string `json:"issuer" validate:"omitempty,issuer"` - Secret string `json:"secret" validate:"required,otpsecret"` - Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"` - Digits int `json:"digits" validate:"omitempty,min=6,max=8"` - Period int `json:"period" validate:"omitempty,min=30,max=60"` -} - -// OTPRepository handles OTP data storage -type OTPRepository struct { - // Add your database connection or ORM here -} - -// Create creates a new OTP record -func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error { - // Implement database creation logic - return nil -} - -// FindByID finds an OTP by ID and user ID -func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) { - // Implement database lookup logic - return nil, nil -} - -// FindAllByUserID finds all OTPs for a user -func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) { - // Implement database query logic - return nil, nil -} - -// Update updates an existing OTP record -func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error { - // Implement database update logic - return nil -} - -// Delete deletes an OTP record -func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error { - // Implement database deletion logic - return nil -} diff --git a/models/user.go b/models/user.go deleted file mode 100644 index f12399b..0000000 --- a/models/user.go +++ /dev/null @@ -1,114 +0,0 @@ -package models - -import ( - "context" - "database/sql" - "fmt" - "time" - - "github.com/jmoiron/sqlx" -) - -// User represents a user in the system -type User struct { - ID string `db:"id" json:"id"` - OpenID string `db:"openid" json:"openid"` - SessionKey string `db:"session_key" json:"-"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` -} - -// UserRepository handles user data operations -type UserRepository struct { - db *sqlx.DB -} - -// NewUserRepository creates a new UserRepository -func NewUserRepository(db *sqlx.DB) *UserRepository { - return &UserRepository{db: db} -} - -// FindByID finds a user by ID -func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) { - var user User - query := `SELECT * FROM users WHERE id = ?` - err := r.db.GetContext(ctx, &user, query, id) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("user not found: %w", err) - } - return nil, fmt.Errorf("failed to find user: %w", err) - } - return &user, nil -} - -// FindByOpenID finds a user by OpenID -func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) { - var user User - query := `SELECT * FROM users WHERE openid = ?` - err := r.db.GetContext(ctx, &user, query, openID) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil // User not found, but not an error - } - return nil, fmt.Errorf("failed to find user: %w", err) - } - return &user, nil -} - -// Create creates a new user -func (r *UserRepository) Create(ctx context.Context, user *User) error { - query := ` - INSERT INTO users (id, openid, session_key, created_at, updated_at) - VALUES (?, ?, ?, ?, ?) - ` - now := time.Now() - user.CreatedAt = now - user.UpdatedAt = now - - _, err := r.db.ExecContext( - ctx, - query, - user.ID, - user.OpenID, - user.SessionKey, - user.CreatedAt, - user.UpdatedAt, - ) - if err != nil { - return fmt.Errorf("failed to create user: %w", err) - } - return nil -} - -// Update updates an existing user -func (r *UserRepository) Update(ctx context.Context, user *User) error { - query := ` - UPDATE users - SET session_key = ?, updated_at = ? - WHERE id = ? - ` - user.UpdatedAt = time.Now() - - _, err := r.db.ExecContext( - ctx, - query, - user.SessionKey, - user.UpdatedAt, - user.ID, - ) - if err != nil { - return fmt.Errorf("failed to update user: %w", err) - } - return nil -} - -// Delete deletes a user -func (r *UserRepository) Delete(ctx context.Context, id string) error { - query := `DELETE FROM users WHERE id = ?` - _, err := r.db.ExecContext(ctx, query, id) - if err != nil { - return fmt.Errorf("failed to delete user: %w", err) - } - return nil -} diff --git a/otp_api.go b/otp_api.go new file mode 100644 index 0000000..573ad69 --- /dev/null +++ b/otp_api.go @@ -0,0 +1,417 @@ +package main + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + + _ "github.com/lib/pq" +) + +// 加密密钥(32字节AES-256) +var encryptionKey = []byte("example-key-32-bytes-long!1234") // 实际应用中应从安全配置获取 + +// encryptTokenSecret 加密令牌密钥 +func encryptTokenSecret(secret string) (string, error) { + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", err + } + + ciphertext := make([]byte, aes.BlockSize+len(secret)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return "", err + } + + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(secret)) + + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// decryptTokenSecret 解密令牌密钥 +func decryptTokenSecret(encrypted string) (string, error) { + ciphertext, err := base64.StdEncoding.DecodeString(encrypted) + if err != nil { + return "", err + } + + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", err + } + + if len(ciphertext) < aes.BlockSize { + return "", fmt.Errorf("ciphertext too short") + } + + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(ciphertext, ciphertext) + + return string(ciphertext), nil +} + +// SaveRequest 保存请求的数据结构 +type SaveRequest struct { + Tokens []TokenData `json:"tokens"` + UserID string `json:"userId"` + Timestamp int64 `json:"timestamp"` +} + +// TokenData token数据结构 +type TokenData struct { + ID string `json:"id"` + Issuer string `json:"issuer"` + Account string `json:"account"` + Secret string `json:"secret"` + Type string `json:"type"` + Counter int `json:"counter,omitempty"` + Period int `json:"period"` + Digits int `json:"digits"` + Algo string `json:"algo"` +} + +// SaveResponse 保存响应的数据结构 +type SaveResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + Data struct { + ID string `json:"id"` + } `json:"data"` +} + +// RecoverResponse 恢复响应的数据结构 +type RecoverResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + Data struct { + Tokens []TokenData `json:"tokens"` + Timestamp int64 `json:"timestamp"` + } `json:"data"` +} + +var db *sql.DB + +// InitDB 初始化数据库连接 +func InitDB() error { + connStr := "postgres://postgres:postgres@localhost/otp_db?sslmode=disable" + var err error + db, err = sql.Open("postgres", connStr) + if err != nil { + return fmt.Errorf("error opening database: %v", err) + } + + if err = db.Ping(); err != nil { + return fmt.Errorf("error connecting to the database: %v", err) + } + + return nil +} + +// SaveHandler 保存token的接口处理函数 +func SaveHandler(w http.ResponseWriter, r *http.Request) { + // 设置CORS头 + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + // 处理OPTIONS请求 + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + // 检查请求方法 + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 解析请求 + var req SaveRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Printf("Error decoding request: %v", err) + sendErrorResponse(w, "Invalid request body", http.StatusBadRequest) + return + } + + // 验证请求数据 + if req.UserID == "" { + sendErrorResponse(w, "Missing user ID", http.StatusBadRequest) + return + } + + if len(req.Tokens) == 0 { + sendErrorResponse(w, "No tokens provided", http.StatusBadRequest) + return + } + + // 开始数据库事务 + tx, err := db.Begin() + if err != nil { + log.Printf("Error starting transaction: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + defer tx.Rollback() + + // 删除用户现有的tokens + _, err = tx.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID) + if err != nil { + log.Printf("Error deleting existing tokens: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + + // 插入新的tokens + stmt, err := tx.Prepare(` + INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + `) + if err != nil { + log.Printf("Error preparing statement: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + defer stmt.Close() + + for _, token := range req.Tokens { + // 加密secret + encryptedSecret, err := encryptTokenSecret(token.Secret) + if err != nil { + log.Printf("Error encrypting token secret: %v", err) + sendErrorResponse(w, "Encryption error", http.StatusInternalServerError) + return + } + + _, err = stmt.Exec( + token.ID, + req.UserID, + token.Issuer, + token.Account, + encryptedSecret, + token.Type, + token.Counter, + token.Period, + token.Digits, + token.Algo, + req.Timestamp, + ) + if err != nil { + log.Printf("Error inserting token: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + } + + // 提交事务 + if err = tx.Commit(); err != nil { + log.Printf("Error committing transaction: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + + // 返回成功响应 + resp := SaveResponse{ + Success: true, + Message: "Tokens saved successfully", + } + resp.Data.ID = req.UserID + + sendJSONResponse(w, resp, http.StatusOK) +} + +// RecoverHandler 恢复token的接口处理函数 +func RecoverHandler(w http.ResponseWriter, r *http.Request) { + // 设置CORS头 + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + // 处理OPTIONS请求 + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + // 检查请求方法 + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 解析请求体 + var req struct { + UserID string `json:"userId"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Printf("Error decoding request: %v", err) + sendErrorResponse(w, "Invalid request body", http.StatusBadRequest) + return + } + + // 验证用户ID + if req.UserID == "" { + sendErrorResponse(w, "Missing user ID", http.StatusBadRequest) + return + } + + // 查询数据库 + rows, err := db.Query(` + SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp + FROM tokens + WHERE user_id = $1 + ORDER BY timestamp DESC + `, req.UserID) + if err != nil { + log.Printf("Error querying database: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + defer rows.Close() + + // 读取查询结果 + var tokens []TokenData + var timestamp int64 + for rows.Next() { + var token TokenData + var encryptedSecret string + err := rows.Scan( + &token.ID, + &token.Issuer, + &token.Account, + &encryptedSecret, + &token.Type, + &token.Counter, + &token.Period, + &token.Digits, + &token.Algo, + ×tamp, + ) + if err != nil { + log.Printf("Error scanning row: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + + // 解密secret + token.Secret, err = decryptTokenSecret(encryptedSecret) + if err != nil { + log.Printf("Error decrypting token secret: %v", err) + sendErrorResponse(w, "Decryption error", http.StatusInternalServerError) + return + } + tokens = append(tokens, token) + } + + if err = rows.Err(); err != nil { + log.Printf("Error iterating rows: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + + // 返回响应 + resp := RecoverResponse{ + Success: true, + Message: "Tokens recovered successfully", + } + resp.Data.Tokens = tokens + resp.Data.Timestamp = timestamp + + sendJSONResponse(w, resp, http.StatusOK) +} + +// sendErrorResponse 发送错误响应 +func sendErrorResponse(w http.ResponseWriter, message string, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": false, + "message": message, + }) +} + +// DeleteTokenHandler 删除单个token的接口处理函数 +func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) { + // 设置CORS头 + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + // 处理OPTIONS请求 + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + // 检查请求方法 + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 解析请求 + var req struct { + UserID string `json:"userId"` + TokenID string `json:"tokenId"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Printf("Error decoding request: %v", err) + sendErrorResponse(w, "Invalid request body", http.StatusBadRequest) + return + } + + // 验证请求数据 + if req.UserID == "" || req.TokenID == "" { + sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest) + return + } + + // 执行删除操作 + result, err := db.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID) + if err != nil { + log.Printf("Error deleting token: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + + // 检查是否真的删除了记录 + rowsAffected, err := result.RowsAffected() + if err != nil { + log.Printf("Error getting rows affected: %v", err) + sendErrorResponse(w, "Database error", http.StatusInternalServerError) + return + } + + if rowsAffected == 0 { + sendErrorResponse(w, "Token not found", http.StatusNotFound) + return + } + + // 返回成功响应 + sendJSONResponse(w, map[string]interface{}{ + "success": true, + "message": "Token deleted successfully", + }, http.StatusOK) +} + +// sendJSONResponse 发送JSON响应 +func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(data); err != nil { + log.Printf("Error encoding response: %v", err) + sendErrorResponse(w, "Internal server error", http.StatusInternalServerError) + } +} diff --git a/otp_api_test.go b/otp_api_test.go new file mode 100644 index 0000000..7c366d9 --- /dev/null +++ b/otp_api_test.go @@ -0,0 +1,111 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestSaveHandler(t *testing.T) { + // 创建测试服务器 + srv := httptest.NewServer(http.HandlerFunc(SaveHandler)) + defer srv.Close() + + // 准备测试数据 + testData := SaveRequest{ + UserID: "test_user_123", + Tokens: []TokenData{ + { + Issuer: "TestOrg", + Account: "user@test.com", + Secret: "JBSWY3DPEHPK3PXP", + Type: "totp", + Period: 30, + Digits: 6, + Algo: "SHA1", + }, + }, + } + + // 序列化请求体 + body, _ := json.Marshal(testData) + + // 发送请求 + resp, err := http.Post(srv.URL, "application/json", bytes.NewBuffer(body)) + if err != nil { + t.Fatalf("Error making request to server: %v\n", err) + } + defer resp.Body.Close() + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status: %d, got: %d\n", http.StatusOK, resp.StatusCode) + } + + // 解析响应 + var saveResp SaveResponse + err = json.NewDecoder(resp.Body).Decode(&saveResp) + if err != nil { + t.Errorf("Error decoding response: %v\n", err) + } + + // 验证响应数据 + if !saveResp.Success { + t.Errorf("Expected success to be true, got false\n") + } + if saveResp.Message != "Tokens saved successfully" { + t.Errorf("Expected message to be 'Tokens saved successfully', got '%s'\n", saveResp.Message) + } +} + +func TestRecoverHandler(t *testing.T) { + // 创建测试服务器 + srv := httptest.NewServer(http.HandlerFunc(RecoverHandler)) + defer srv.Close() + + // 发送请求(没有user_id参数) + resp, err := http.Get(srv.URL) + if err != nil { + t.Fatalf("Error making request to server: %v\n", err) + } + defer resp.Body.Close() + + // 检查响应状态码(应该返回错误) + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status: %d, got: %d\n", http.StatusBadRequest, resp.StatusCode) + } + + // 发送带user_id的请求 + urlWithID := fmt.Sprintf("%s?user_id=test_user_123", srv.URL) + respWithID, err := http.Get(urlWithID) + if err != nil { + t.Fatalf("Error making request to server: %v\n", err) + } + defer respWithID.Body.Close() + + // 检查响应状态码 + if respWithID.StatusCode != http.StatusOK { + t.Errorf("Expected status: %d, got: %d\n", http.StatusOK, respWithID.StatusCode) + } + + // 解析响应 + var recoverResp RecoverResponse + err = json.NewDecoder(respWithID.Body).Decode(&recoverResp) + if err != nil { + t.Errorf("Error decoding response: %v\n", err) + } + + // 验证响应数据 + if !recoverResp.Success { + t.Errorf("Expected success to be true, got false\n") + } + if recoverResp.Message != "Tokens recovered successfully" { + t.Errorf("Expected message to be 'Tokens recovered successfully', got '%s'\n", recoverResp.Message) + } + if len(recoverResp.Tokens) != 1 { + t.Errorf("Expected 1 token, got %d\n", len(recoverResp.Tokens)) + } +} diff --git a/security/security.go b/security/security.go deleted file mode 100644 index 7b69b81..0000000 --- a/security/security.go +++ /dev/null @@ -1,332 +0,0 @@ -package security - -import ( - "context" - "crypto/rand" - "crypto/subtle" - "encoding/base64" - "fmt" - "net/http" - "strings" - "time" - - "golang.org/x/crypto/argon2" -) - -// SecurityService provides security functionality -type SecurityService struct { - config *Config -} - -// Config represents security configuration -type Config struct { - // CSRF protection - CSRFTokenLength int - CSRFTokenExpiry time.Duration - CSRFCookieName string - CSRFHeaderName string - CSRFCookieSecure bool - CSRFCookieHTTPOnly bool - CSRFCookieSameSite http.SameSite - - // Rate limiting - RateLimitRequests int - RateLimitWindow time.Duration - - // Password hashing - Argon2Time uint32 - Argon2Memory uint32 - Argon2Threads uint8 - Argon2KeyLen uint32 - Argon2SaltLen uint32 -} - -// DefaultConfig returns the default security configuration -func DefaultConfig() *Config { - return &Config{ - // CSRF protection - CSRFTokenLength: 32, - CSRFTokenExpiry: 24 * time.Hour, - CSRFCookieName: "csrf_token", - CSRFHeaderName: "X-CSRF-Token", - CSRFCookieSecure: true, - CSRFCookieHTTPOnly: true, - CSRFCookieSameSite: http.SameSiteStrictMode, - - // Rate limiting - RateLimitRequests: 100, - RateLimitWindow: time.Minute, - - // Password hashing - Argon2Time: 1, - Argon2Memory: 64 * 1024, - Argon2Threads: 4, - Argon2KeyLen: 32, - Argon2SaltLen: 16, - } -} - -// NewSecurityService creates a new SecurityService -func NewSecurityService(config *Config) *SecurityService { - if config == nil { - config = DefaultConfig() - } - return &SecurityService{ - config: config, - } -} - -// GenerateCSRFToken generates a CSRF token -func (s *SecurityService) GenerateCSRFToken() (string, error) { - // Generate random bytes - bytes := make([]byte, s.config.CSRFTokenLength) - if _, err := rand.Read(bytes); err != nil { - return "", fmt.Errorf("failed to generate random bytes: %w", err) - } - - // Encode as base64 - token := base64.StdEncoding.EncodeToString(bytes) - return token, nil -} - -// SetCSRFCookie sets a CSRF cookie -func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) { - http.SetCookie(w, &http.Cookie{ - Name: s.config.CSRFCookieName, - Value: token, - Path: "/", - Expires: time.Now().Add(s.config.CSRFTokenExpiry), - Secure: s.config.CSRFCookieSecure, - HttpOnly: s.config.CSRFCookieHTTPOnly, - SameSite: s.config.CSRFCookieSameSite, - }) -} - -// ValidateCSRFToken validates a CSRF token -func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool { - // Get token from cookie - cookie, err := r.Cookie(s.config.CSRFCookieName) - if err != nil { - return false - } - cookieToken := cookie.Value - - // Get token from header - headerToken := r.Header.Get(s.config.CSRFHeaderName) - - // Compare tokens - return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1 -} - -// CSRFMiddleware creates a middleware that validates CSRF tokens -func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Skip for GET, HEAD, OPTIONS, TRACE - if r.Method == http.MethodGet || - r.Method == http.MethodHead || - r.Method == http.MethodOptions || - r.Method == http.MethodTrace { - next.ServeHTTP(w, r) - return - } - - // Validate CSRF token - if !s.ValidateCSRFToken(r) { - http.Error(w, "Invalid CSRF token", http.StatusForbidden) - return - } - - next.ServeHTTP(w, r) - }) -} - -// RateLimiter represents a rate limiter -type RateLimiter struct { - requests map[string][]time.Time - config *Config -} - -// NewRateLimiter creates a new RateLimiter -func NewRateLimiter(config *Config) *RateLimiter { - return &RateLimiter{ - requests: make(map[string][]time.Time), - config: config, - } -} - -// Allow checks if a request is allowed -func (r *RateLimiter) Allow(key string) bool { - now := time.Now() - windowStart := now.Add(-r.config.RateLimitWindow) - - // Get requests for key - requests := r.requests[key] - - // Filter out old requests - var newRequests []time.Time - for _, t := range requests { - if t.After(windowStart) { - newRequests = append(newRequests, t) - } - } - - // Check if rate limit is exceeded - if len(newRequests) >= r.config.RateLimitRequests { - return false - } - - // Add current request - newRequests = append(newRequests, now) - r.requests[key] = newRequests - - return true -} - -// RateLimitMiddleware creates a middleware that limits request rate -func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Get client IP - ip := getClientIP(r) - - // Check if request is allowed - if !limiter.Allow(ip) { - http.Error(w, "Too many requests", http.StatusTooManyRequests) - return - } - - next.ServeHTTP(w, r) - }) - } -} - -// getClientIP gets the client IP address -func getClientIP(r *http.Request) string { - // Check X-Forwarded-For header - xForwardedFor := r.Header.Get("X-Forwarded-For") - if xForwardedFor != "" { - // X-Forwarded-For can contain multiple IPs, use the first one - ips := strings.Split(xForwardedFor, ",") - return strings.TrimSpace(ips[0]) - } - - // Check X-Real-IP header - xRealIP := r.Header.Get("X-Real-IP") - if xRealIP != "" { - return xRealIP - } - - // Use RemoteAddr - return r.RemoteAddr -} - -// HashPassword hashes a password using Argon2 -func (s *SecurityService) HashPassword(password string) (string, error) { - // Generate salt - salt := make([]byte, s.config.Argon2SaltLen) - if _, err := rand.Read(salt); err != nil { - return "", fmt.Errorf("failed to generate salt: %w", err) - } - - // Hash password - hash := argon2.IDKey( - []byte(password), - salt, - s.config.Argon2Time, - s.config.Argon2Memory, - s.config.Argon2Threads, - s.config.Argon2KeyLen, - ) - - // Encode as base64 - saltBase64 := base64.StdEncoding.EncodeToString(salt) - hashBase64 := base64.StdEncoding.EncodeToString(hash) - - // Format as $argon2id$v=19$m=65536,t=1,p=4$$ - return fmt.Sprintf( - "$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s", - s.config.Argon2Memory, - s.config.Argon2Time, - s.config.Argon2Threads, - saltBase64, - hashBase64, - ), nil -} - -// VerifyPassword verifies a password against a hash -func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) { - // Parse encoded hash - parts := strings.Split(encodedHash, "$") - if len(parts) != 6 { - return false, fmt.Errorf("invalid hash format") - } - - // Extract parameters - if parts[1] != "argon2id" { - return false, fmt.Errorf("unsupported hash algorithm") - } - - var memory uint32 - var time uint32 - var threads uint8 - _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) - if err != nil { - return false, fmt.Errorf("failed to parse hash parameters: %w", err) - } - - // Decode salt and hash - salt, err := base64.StdEncoding.DecodeString(parts[4]) - if err != nil { - return false, fmt.Errorf("failed to decode salt: %w", err) - } - - hash, err := base64.StdEncoding.DecodeString(parts[5]) - if err != nil { - return false, fmt.Errorf("failed to decode hash: %w", err) - } - - // Hash password with same parameters - newHash := argon2.IDKey( - []byte(password), - salt, - time, - memory, - threads, - uint32(len(hash)), - ) - - // Compare hashes - return subtle.ConstantTimeCompare(hash, newHash) == 1, nil -} - -// SecureHeadersMiddleware adds security headers to responses -func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Add security headers - w.Header().Set("X-Content-Type-Options", "nosniff") - w.Header().Set("X-Frame-Options", "DENY") - w.Header().Set("X-XSS-Protection", "1; mode=block") - w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") - w.Header().Set("Content-Security-Policy", "default-src 'self'") - w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") - - next.ServeHTTP(w, r) - }) -} - -// contextKey is a type for context keys -type contextKey string - -// userIDKey is the key for user ID in context -const userIDKey = contextKey("user_id") - -// WithUserID adds a user ID to the context -func WithUserID(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, userIDKey, userID) -} - -// GetUserID gets the user ID from the context -func GetUserID(ctx context.Context) (string, bool) { - userID, ok := ctx.Value(userIDKey).(string) - return userID, ok -} diff --git a/server/server.go b/server/server.go deleted file mode 100644 index a10e990..0000000 --- a/server/server.go +++ /dev/null @@ -1,190 +0,0 @@ -package server - -import ( - "context" - "fmt" - "log" - "net/http" - "os" - "os/signal" - "runtime" - "syscall" - "time" - - "otpm/config" - "otpm/middleware" - - "github.com/julienschmidt/httprouter" -) - -// Server represents the HTTP server -type Server struct { - server *http.Server - router *httprouter.Router - config *config.Config -} - -// New creates a new server -func New(cfg *config.Config) *Server { - router := httprouter.New() - - server := &http.Server{ - Addr: fmt.Sprintf(":%d", cfg.Server.Port), - Handler: router, - ReadTimeout: cfg.Server.ReadTimeout, - WriteTimeout: cfg.Server.WriteTimeout, - IdleTimeout: 120 * time.Second, - } - - return &Server{ - server: server, - router: router, - config: cfg, - } -} - -// Start starts the server -func (s *Server) Start() error { - // Apply global middleware in correct order with enhanced error handling - var handler http.Handler = s.router - - // Logger should be first to capture all request details - handler = middleware.Logger(handler) - - // CORS next to handle pre-flight requests - handler = middleware.CORS(handler) - - // Then Timeout to enforce request deadlines - handler = middleware.Timeout(s.config.Server.Timeout)(handler) - - // Recover should be outermost to catch any panics - handler = middleware.Recover(handler) - - s.server.Handler = handler - - // Log server configuration at startup - log.Printf("Server configuration:\n"+ - "Address: %s\n"+ - "Read Timeout: %v\n"+ - "Write Timeout: %v\n"+ - "Idle Timeout: %v\n"+ - "Request Timeout: %v", - s.server.Addr, - s.server.ReadTimeout, - s.server.WriteTimeout, - s.server.IdleTimeout, - s.config.Server.Timeout, - ) - - // Start server in a goroutine - serverErr := make(chan error, 1) - go func() { - log.Printf("Server starting on %s", s.server.Addr) - if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - serverErr <- fmt.Errorf("server error: %w", err) - } - }() - - // Wait for interrupt signal or server error - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - - select { - case err := <-serverErr: - return err - case <-quit: - return s.Shutdown() - } -} - -// Shutdown gracefully stops the server -func (s *Server) Shutdown() error { - log.Println("Shutting down server...") - - ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout) - defer cancel() - - if err := s.server.Shutdown(ctx); err != nil { - return fmt.Errorf("graceful shutdown failed: %w", err) - } - - log.Println("Server stopped gracefully") - return nil -} - -// Router returns the router -func (s *Server) Router() *httprouter.Router { - return s.router -} - -// RegisterRoutes registers all routes -func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) { - for pattern, handler := range routes { - s.router.Handle("GET", pattern, handler) - s.router.Handle("POST", pattern, handler) - s.router.Handle("PUT", pattern, handler) - s.router.Handle("DELETE", pattern, handler) - } -} - -// RegisterAuthRoutes registers routes that require authentication -func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) { - for pattern, handler := range routes { - // Apply authentication middleware - authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - // Convert httprouter.Handle to http.HandlerFunc for middleware - wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Store params in request context - ctx := context.WithValue(r.Context(), "params", ps) - handler(w, r.WithContext(ctx), ps) - }) - - // Apply auth middleware - middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r) - } - - s.router.Handle("GET", pattern, authHandler) - s.router.Handle("POST", pattern, authHandler) - s.router.Handle("PUT", pattern, authHandler) - s.router.Handle("DELETE", pattern, authHandler) - } -} - -// RegisterHealthCheck registers an enhanced health check endpoint -func (s *Server) RegisterHealthCheck() { - s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - response := map[string]interface{}{ - "status": "ok", - "timestamp": time.Now().Format(time.RFC3339), - "version": "1.0.0", // Hardcoded version instead of from config - "system": map[string]interface{}{ - "goroutines": runtime.NumGoroutine(), - "memory": getMemoryUsage(), - }, - } - - // Add database status if configured - if s.config.Database.DSN != "" { - dbStatus := "ok" - response["database"] = dbStatus - } - - middleware.SuccessResponse(w, response) - }) -} - -// getMemoryUsage returns current memory usage in MB -func getMemoryUsage() map[string]interface{} { - var m runtime.MemStats - runtime.ReadMemStats(&m) - return map[string]interface{}{ - "alloc_mb": bToMb(m.Alloc), - "total_alloc_mb": bToMb(m.TotalAlloc), - "sys_mb": bToMb(m.Sys), - "num_gc": m.NumGC, - } -} - -func bToMb(b uint64) float64 { - return float64(b) / 1024 / 1024 -} diff --git a/services/auth.go b/services/auth.go deleted file mode 100644 index 9ad2a85..0000000 --- a/services/auth.go +++ /dev/null @@ -1,230 +0,0 @@ -package services - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "time" - - "github.com/golang-jwt/jwt" - "github.com/google/uuid" - - "otpm/config" - "otpm/models" -) - -// WeChatCode2SessionResponse represents the response from WeChat code2session API -type WeChatCode2SessionResponse struct { - OpenID string `json:"openid"` - SessionKey string `json:"session_key"` - UnionID string `json:"unionid"` - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` -} - -// AuthService handles authentication related operations -type AuthService struct { - config *config.Config - userRepo *models.UserRepository - httpClient *http.Client -} - -// NewAuthService creates a new AuthService -func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService { - return &AuthService{ - config: cfg, - userRepo: userRepo, - httpClient: &http.Client{ - Timeout: 10 * time.Second, - }, - } -} - -// LoginWithWeChatCode handles WeChat login -func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) { - start := time.Now() - - // Get OpenID and SessionKey from WeChat - sessionInfo, err := s.getWeChatSession(code) - if err != nil { - log.Printf("WeChat login failed for code %s: %v", maskCode(code), err) - return "", fmt.Errorf("failed to get WeChat session: %w", err) - } - log.Printf("WeChat session obtained for code %s (took %v)", - maskCode(code), time.Since(start)) - - // Find or create user - user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID) - if err != nil { - log.Printf("User lookup failed for OpenID %s: %v", - maskOpenID(sessionInfo.OpenID), err) - return "", fmt.Errorf("failed to find user: %w", err) - } - - if user == nil { - // Create new user - user = &models.User{ - ID: uuid.New().String(), - OpenID: sessionInfo.OpenID, - SessionKey: sessionInfo.SessionKey, - } - if err := s.userRepo.Create(ctx, user); err != nil { - log.Printf("User creation failed for OpenID %s: %v", - maskOpenID(sessionInfo.OpenID), err) - return "", fmt.Errorf("failed to create user: %w", err) - } - log.Printf("New user created with ID %s for OpenID %s", - user.ID, maskOpenID(sessionInfo.OpenID)) - } else { - // Update session key - user.SessionKey = sessionInfo.SessionKey - if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("User update failed for ID %s: %v", user.ID, err) - return "", fmt.Errorf("failed to update user: %w", err) - } - log.Printf("User %s session key updated", user.ID) - } - - // Generate JWT token - token, err := s.generateToken(user) - if err != nil { - log.Printf("Token generation failed for user %s: %v", user.ID, err) - return "", fmt.Errorf("failed to generate token: %w", err) - } - - log.Printf("WeChat login completed for user %s (total time %v)", - user.ID, time.Since(start)) - return token, nil -} - -// maskCode masks sensitive parts of WeChat code for logging -func maskCode(code string) string { - if len(code) < 8 { - return "****" - } - return code[:2] + "****" + code[len(code)-2:] -} - -// maskOpenID masks sensitive parts of OpenID for logging -func maskOpenID(openID string) string { - if len(openID) < 8 { - return "****" - } - return openID[:2] + "****" + openID[len(openID)-2:] -} - -// getWeChatSession calls WeChat's code2session API -func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) { - url := fmt.Sprintf( - "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", - s.config.WeChat.AppID, - s.config.WeChat.AppSecret, - code, - ) - - resp, err := s.httpClient.Get(url) - if err != nil { - return nil, fmt.Errorf("failed to call WeChat API: %w", err) - } - defer resp.Body.Close() - - var result WeChatCode2SessionResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("failed to decode WeChat response: %w", err) - } - - if result.ErrCode != 0 { - return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg) - } - - return &result, nil -} - -// generateToken generates a JWT token for a user -func (s *AuthService) generateToken(user *models.User) (string, error) { - now := time.Now() - claims := jwt.MapClaims{ - "user_id": user.ID, - "exp": now.Add(s.config.JWT.ExpireDelta).Unix(), - "iat": now.Unix(), - "iss": s.config.JWT.Issuer, - "aud": s.config.JWT.Audience, - "token_id": uuid.New().String(), // Unique token ID for tracking - } - - // Use stronger signing method - token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) - signedToken, err := token.SignedString([]byte(s.config.JWT.Secret)) - if err != nil { - return "", fmt.Errorf("failed to sign token: %w", err) - } - - log.Printf("Token generated for user %s (expires at %v)", - user.ID, now.Add(s.config.JWT.ExpireDelta)) - return signedToken, nil -} - -// ValidateToken validates a JWT token with additional checks -func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) { - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - // Verify signing method - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return []byte(s.config.JWT.Secret), nil - }) - - if err != nil { - if ve, ok := err.(*jwt.ValidationError); ok { - switch { - case ve.Errors&jwt.ValidationErrorMalformed != 0: - return nil, fmt.Errorf("malformed token") - case ve.Errors&jwt.ValidationErrorExpired != 0: - return nil, fmt.Errorf("token expired") - case ve.Errors&jwt.ValidationErrorNotValidYet != 0: - return nil, fmt.Errorf("token not active yet") - default: - return nil, fmt.Errorf("token validation error: %w", err) - } - } - return nil, fmt.Errorf("failed to parse token: %w", err) - } - - // Additional claims validation - if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - // Check issuer - if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer { - return nil, fmt.Errorf("invalid token issuer") - } - // Check audience - if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience { - return nil, fmt.Errorf("invalid token audience") - } - } else { - return nil, fmt.Errorf("invalid token claims") - } - - return token, nil -} - -// GetUserFromToken gets user information from a JWT token -func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) { - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("invalid token claims") - } - - userID, ok := claims["user_id"].(string) - if !ok { - return nil, fmt.Errorf("user_id not found in token") - } - - user, err := s.userRepo.FindByID(ctx, userID) - if err != nil { - return nil, fmt.Errorf("failed to find user: %w", err) - } - - return user, nil -} diff --git a/services/otp.go b/services/otp.go deleted file mode 100644 index c345bea..0000000 --- a/services/otp.go +++ /dev/null @@ -1,358 +0,0 @@ -package services - -import ( - "context" - "crypto/hmac" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "encoding/base32" - "encoding/binary" - "fmt" - "hash" - "log" - "strings" - "time" - - "otpm/models" - - "github.com/google/uuid" -) - -// OTPService handles OTP related operations -type OTPService struct { - otpRepo *models.OTPRepository -} - -// NewOTPService creates a new OTPService -func NewOTPService(otpRepo *models.OTPRepository) *OTPService { - return &OTPService{ - otpRepo: otpRepo, - } -} - -// CreateOTP creates a new OTP with performance monitoring and logging -func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) { - start := time.Now() - - // Validate input - if err := s.validateOTPInput(input); err != nil { - log.Printf("OTP validation failed for user %s: %v", userID, err) - return nil, err - } - - // Clean and standardize secret - secret := cleanSecret(input.Secret) - - // Set defaults for optional fields - algorithm := strings.ToUpper(input.Algorithm) - if algorithm == "" { - algorithm = "SHA1" - } - - digits := input.Digits - if digits == 0 { - digits = 6 - } - - period := input.Period - if period == 0 { - period = 30 - } - - // Create OTP - otp := &models.OTP{ - ID: uuid.New().String(), - UserID: userID, - Name: input.Name, - Issuer: input.Issuer, - Secret: secret, - Algorithm: algorithm, - Digits: digits, - Period: period, - } - - if err := s.otpRepo.Create(ctx, otp); err != nil { - log.Printf("Failed to create OTP for user %s: %v", userID, err) - return nil, fmt.Errorf("failed to create OTP: %w", err) - } - - // Log successful creation (without exposing secret) - log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)", - otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period) - - return otp, nil -} - -// GetOTP gets an OTP by ID -func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) { - otp, err := s.otpRepo.FindByID(ctx, id, userID) - if err != nil { - return nil, fmt.Errorf("failed to get OTP: %w", err) - } - return otp, nil -} - -// ListOTPs lists all OTPs for a user -func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) { - otps, err := s.otpRepo.FindAllByUserID(ctx, userID) - if err != nil { - return nil, fmt.Errorf("failed to list OTPs: %w", err) - } - return otps, nil -} - -// UpdateOTP updates an OTP -func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) { - // Get existing OTP - otp, err := s.otpRepo.FindByID(ctx, id, userID) - if err != nil { - return nil, fmt.Errorf("failed to get OTP: %w", err) - } - - // Update fields - if input.Name != "" { - otp.Name = input.Name - } - if input.Issuer != "" { - otp.Issuer = input.Issuer - } - if input.Algorithm != "" { - otp.Algorithm = strings.ToUpper(input.Algorithm) - } - if input.Digits > 0 { - otp.Digits = input.Digits - } - if input.Period > 0 { - otp.Period = input.Period - } - - // Validate updated OTP - if err := s.validateOTPInput(models.OTPParams{ - Name: otp.Name, - Issuer: otp.Issuer, - Secret: otp.Secret, - Algorithm: otp.Algorithm, - Digits: otp.Digits, - Period: otp.Period, - }); err != nil { - return nil, err - } - - if err := s.otpRepo.Update(ctx, otp); err != nil { - return nil, fmt.Errorf("failed to update OTP: %w", err) - } - - return otp, nil -} - -// DeleteOTP deletes an OTP -func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error { - if err := s.otpRepo.Delete(ctx, id, userID); err != nil { - return fmt.Errorf("failed to delete OTP: %w", err) - } - return nil -} - -// GenerateCode generates a TOTP code with enhanced logging and error handling -func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) { - start := time.Now() - - otp, err := s.otpRepo.FindByID(ctx, id, userID) - if err != nil { - log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err) - return "", 0, fmt.Errorf("failed to get OTP: %w", err) - } - - // Get current time step - now := time.Now().Unix() - timeStep := now / int64(otp.Period) - - // Generate code - code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits) - if err != nil { - log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err) - return "", 0, fmt.Errorf("failed to generate code: %w", err) - } - - // Calculate remaining seconds - remainingSeconds := otp.Period - int(now%int64(otp.Period)) - - // Log successful generation (without actual code) - log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)", - id, userID, time.Since(start), remainingSeconds) - - return code, remainingSeconds, nil -} - -// VerifyCode verifies a TOTP code with enhanced security and logging -func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) { - start := time.Now() - - // Basic input validation - if len(code) == 0 { - log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID) - return false, fmt.Errorf("code is required") - } - - otp, err := s.otpRepo.FindByID(ctx, id, userID) - if err != nil { - log.Printf("Failed to find OTP %s for user %s during verification: %v", - id, userID, err) - return false, fmt.Errorf("failed to get OTP: %w", err) - } - - // Get current and adjacent time steps - now := time.Now().Unix() - timeSteps := []int64{ - (now - int64(otp.Period)) / int64(otp.Period), - now / int64(otp.Period), - (now + int64(otp.Period)) / int64(otp.Period), - } - - // Check code against all time steps - for _, ts := range timeSteps { - expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits) - if err != nil { - log.Printf("Code generation failed for time step %d: %v", ts, err) - continue - } - if expectedCode == code { - // Log successful verification - log.Printf("Code verified successfully for OTP %s (user %s) in %v", - id, userID, time.Since(start)) - return true, nil - } - } - - // Log failed verification attempt - log.Printf("Invalid code provided for OTP %s (user %s) in %v", - id, userID, time.Since(start)) - - return false, nil -} - -// validateOTPInput validates OTP input with detailed error messages -func (s *OTPService) validateOTPInput(input models.OTPParams) error { - if input.Name == "" { - return fmt.Errorf("name is required") - } - - if len(input.Name) > 100 { - return fmt.Errorf("name is too long (maximum 100 characters)") - } - - if input.Secret == "" { - return fmt.Errorf("secret is required") - } - - if !isValidBase32(input.Secret) { - return fmt.Errorf("invalid secret format: must be a valid base32 string") - } - - // Secret length check (after base32 decoding) - secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "=")) - if len(secretBytes) < 10 { - return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)") - } - - if input.Algorithm != "" { - if !isValidAlgorithm(input.Algorithm) { - return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm) - } - } - - if input.Digits != 0 { - if input.Digits < 6 || input.Digits > 8 { - return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits) - } - } - - if input.Period != 0 { - if input.Period < 30 || input.Period > 60 { - return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period) - } - } - - return nil -} - -// Helper functions - -func cleanSecret(secret string) string { - // Remove spaces and convert to upper case - secret = strings.TrimSpace(strings.ToUpper(secret)) - // Remove any padding characters - return strings.TrimRight(secret, "=") -} - -func isValidBase32(s string) bool { - // Try to decode the secret - _, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "=")) - return err == nil -} - -func isValidAlgorithm(algorithm string) bool { - switch strings.ToUpper(algorithm) { - case "SHA1", "SHA256", "SHA512": - return true - default: - return false - } -} - -func getHasher(algorithm string, key []byte) (hash.Hash, error) { - switch strings.ToUpper(algorithm) { - case "SHA1": - return hmac.New(sha1.New, key), nil - case "SHA256": - return hmac.New(sha256.New, key), nil - case "SHA512": - return hmac.New(sha512.New, key), nil - default: - return nil, fmt.Errorf("unsupported algorithm: %s", algorithm) - } -} - -func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) { - // Decode secret - secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "=")) - if err != nil { - return "", fmt.Errorf("invalid secret: %w", err) - } - - // Get initialized HMAC hasher with secret - hasher, err := getHasher(algorithm, secretBytes) - if err != nil { - return "", err - } - - // Convert time step to bytes - timeBytes := make([]byte, 8) - binary.BigEndian.PutUint64(timeBytes, uint64(timeStep)) - - // Calculate HMAC - hasher.Write(timeBytes) - hash := hasher.Sum(nil) - - // Get offset - offset := hash[len(hash)-1] & 0xf - - // Generate 4-byte code - code := binary.BigEndian.Uint32(hash[offset : offset+4]) - code = code & 0x7fffffff - - // Get the specified number of digits - code = code % uint32(pow10(digits)) - - // Format code with leading zeros - return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil -} - -func pow10(n int) uint32 { - result := uint32(1) - for i := 0; i < n; i++ { - result *= 10 - } - return result -} diff --git a/utils/utils.go b/utils/utils.go deleted file mode 100644 index efd6713..0000000 --- a/utils/utils.go +++ /dev/null @@ -1,105 +0,0 @@ -package utils - -import ( - "context" - "crypto/aes" - "crypto/cipher" - "encoding/base64" - "fmt" - "net/http" - "strings" - - "github.com/golang-jwt/jwt" - "github.com/julienschmidt/httprouter" - "github.com/spf13/viper" -) - -// AdaptHandler函数将一个http.Handler转换为httprouter.Handle -func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle { - // 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数 - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - // 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递 - h(w, r) - } -} - -func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized) - return - } - - tokenStr := strings.TrimPrefix(authHeader, "Bearer ") - secret := viper.GetString("auth.secret") - - token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method") - } - return []byte(secret), nil - }) - - if err != nil || !token.Valid { - http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized) - return - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized) - return - } - - type contextKey string - // 将 openid 存入上下文 - ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"]) - next.ServeHTTP(w, r.WithContext(ctx)) - } -} - -// AesDecrypt 函数用于AES解密 -func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) { - //Base64解码 - keyBytes, err := base64.StdEncoding.DecodeString(sessionKey) - if err != nil { - return nil, err - } - ivBytes, err := base64.StdEncoding.DecodeString(iv) - if err != nil { - return nil, err - } - cryptData, err := base64.StdEncoding.DecodeString(encryptedData) - if err != nil { - return nil, err - } - origData := make([]byte, len(cryptData)) - //AES - block, err := aes.NewCipher(keyBytes) - if err != nil { - return nil, err - } - //CBC - mode := cipher.NewCBCDecrypter(block, ivBytes) - //解密 - mode.CryptBlocks(origData, cryptData) - //去除填充位 - origData = PKCS7UnPadding(origData) - return origData, nil -} - -// PKCS7UnPadding 函数用于去除PKCS7填充的密文 -func PKCS7UnPadding(plantText []byte) []byte { - // 获取密文的长度 - length := len(plantText) - // 如果密文长度大于0 - if length > 0 { - // 获取最后一个字节的值,即填充的位数 - unPadding := int(plantText[length-1]) - // 返回去除填充后的密文 - return plantText[:(length - unPadding)] - } - // 如果密文长度为0,则返回原密文 - return plantText -} diff --git a/validator/validator.go b/validator/validator.go deleted file mode 100644 index 13cff58..0000000 --- a/validator/validator.go +++ /dev/null @@ -1,316 +0,0 @@ -package validator - -import ( - "encoding/json" - "fmt" - "net/http" - "reflect" - "regexp" - "strings" - - "github.com/go-playground/validator/v10" -) - -var ( - validate *validator.Validate - - // 自定义验证规则 - customValidations = map[string]validator.Func{ - "otpsecret": validateOTPSecret, - "password": validatePassword, - "issuer": validateIssuer, - "otpauth_uri": validateOTPAuthURI, - "no_xss": validateNoXSS, - } - - // 常见的弱密码列表(实际使用时应该使用更完整的列表) - commonPasswords = map[string]bool{ - "password123": true, - "12345678": true, - "qwerty123": true, - "admin123": true, - "letmein": true, - "welcome": true, - "password": true, - "admin": true, - } - - // 预编译的XSS检测正则表达式 - xssPatterns = []*regexp.Regexp{ - regexp.MustCompile(`(?i)]*>.*?`), - regexp.MustCompile(`(?i)javascript:`), - regexp.MustCompile(`(?i)data:text/html`), - regexp.MustCompile(`(?i)on\w+\s*=`), - regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`), - regexp.MustCompile(`(?i)<\s*iframe`), - regexp.MustCompile(`(?i)<\s*object`), - regexp.MustCompile(`(?i)<\s*embed`), - regexp.MustCompile(`(?i)<\s*style`), - regexp.MustCompile(`(?i)<\s*form`), - regexp.MustCompile(`(?i)<\s*applet`), - regexp.MustCompile(`(?i)<\s*meta`), - regexp.MustCompile(`(?i)expression\s*\(`), - regexp.MustCompile(`(?i)url\s*\(`), - } - - // 预编译的正则表达式 - base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`) - issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`) - otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`) - upperRegex = regexp.MustCompile(`[A-Z]`) - lowerRegex = regexp.MustCompile(`[a-z]`) - numberRegex = regexp.MustCompile(`[0-9]`) - specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`) -) - -func init() { - validate = validator.New() - - // 注册自定义验证规则 - for tag, fn := range customValidations { - if err := validate.RegisterValidation(tag, fn); err != nil { - panic(fmt.Sprintf("failed to register validation %s: %v", tag, err)) - } - } - - // 使用json tag作为字段名 - validate.RegisterTagNameFunc(func(fld reflect.StructField) string { - name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] - if name == "-" { - return "" - } - return name - }) -} - -// ValidateRequest validates a request body against a struct -func ValidateRequest(r *http.Request, v interface{}) error { - if err := json.NewDecoder(r.Body).Decode(v); err != nil { - return fmt.Errorf("invalid request body: %w", err) - } - - if err := validate.Struct(v); err != nil { - if validationErrors, ok := err.(validator.ValidationErrors); ok { - return NewValidationError(validationErrors) - } - return fmt.Errorf("validation error: %w", err) - } - - return nil -} - -// ValidationError represents a validation error -type ValidationError struct { - Fields map[string]string `json:"fields"` -} - -// Error implements the error interface -func (e *ValidationError) Error() string { - var errors []string - for field, msg := range e.Fields { - errors = append(errors, fmt.Sprintf("%s: %s", field, msg)) - } - return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; ")) -} - -// NewValidationError creates a new ValidationError from validator.ValidationErrors -func NewValidationError(errors validator.ValidationErrors) *ValidationError { - fields := make(map[string]string) - for _, err := range errors { - fields[err.Field()] = getErrorMessage(err) - } - return &ValidationError{Fields: fields} -} - -// getErrorMessage returns a human-readable error message for a validation error -func getErrorMessage(err validator.FieldError) string { - switch err.Tag() { - case "required": - return "此字段为必填项" - case "email": - return "请输入有效的电子邮件地址" - case "min": - if err.Type().Kind() == reflect.String { - return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param()) - } - return fmt.Sprintf("必须大于或等于 %s", err.Param()) - case "max": - if err.Type().Kind() == reflect.String { - return fmt.Sprintf("长度不能超过 %s 个字符", err.Param()) - } - return fmt.Sprintf("必须小于或等于 %s", err.Param()) - case "len": - return fmt.Sprintf("长度必须为 %s 个字符", err.Param()) - case "oneof": - return fmt.Sprintf("必须是以下值之一: %s", err.Param()) - case "otpsecret": - return "OTP密钥格式无效,必须是有效的Base32编码" - case "password": - return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符" - case "issuer": - return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号" - case "otpauth_uri": - return "OTP认证URI格式无效" - case "no_xss": - return "输入包含潜在的不安全内容" - case "numeric": - return "必须是数字" - default: - return fmt.Sprintf("验证失败: %s", err.Tag()) - } -} - -// Custom validation functions - -// validateOTPSecret validates an OTP secret -func validateOTPSecret(fl validator.FieldLevel) bool { - secret := fl.Field().String() - - if secret == "" { - return false - } - - // OTP secret should be base32 encoded - if !base32Regex.MatchString(secret) { - return false - } - - // Check length (typical OTP secrets are 16-64 characters) - validLength := len(secret) >= 16 && len(secret) <= 128 - - return validLength -} - -// validatePassword validates a password -func validatePassword(fl validator.FieldLevel) bool { - password := fl.Field().String() - - // At least 10 characters long - if len(password) < 10 { - return false - } - - // Check if it's a common password - if commonPasswords[strings.ToLower(password)] { - return false - } - - // Check character types - hasUpper := upperRegex.MatchString(password) - hasLower := lowerRegex.MatchString(password) - hasNumber := numberRegex.MatchString(password) - hasSpecial := specialRegex.MatchString(password) - - // Ensure password has enough complexity - complexity := 0 - if hasUpper { - complexity++ - } - if hasLower { - complexity++ - } - if hasNumber { - complexity++ - } - if hasSpecial { - complexity++ - } - - return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial) -} - -// validateIssuer validates an issuer name -func validateIssuer(fl validator.FieldLevel) bool { - issuer := fl.Field().String() - - if issuer == "" { - return false - } - - // Issuer should not contain special characters that could cause problems in URLs - if !issuerRegex.MatchString(issuer) { - return false - } - - // Check length - validLength := len(issuer) >= 1 && len(issuer) <= 100 - - return validLength -} - -// validateOTPAuthURI validates an otpauth URI -func validateOTPAuthURI(fl validator.FieldLevel) bool { - uri := fl.Field().String() - - if uri == "" { - return false - } - - // Basic format check for otpauth URI - // Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD - return otpauthRegex.MatchString(uri) -} - -// validateNoXSS checks if a string contains potential XSS payloads -func validateNoXSS(fl validator.FieldLevel) bool { - value := fl.Field().String() - - // 检查基本的HTML编码 - if strings.Contains(value, "&#") || - strings.Contains(value, "<") || - strings.Contains(value, ">") { - return false - } - - // 检查十六进制编码 - if strings.Contains(strings.ToLower(value), "\\x3c") || // < - strings.Contains(strings.ToLower(value), "\\x3e") { // > - return false - } - - // 检查Unicode编码 - if strings.Contains(strings.ToLower(value), "\\u003c") || // < - strings.Contains(strings.ToLower(value), "\\u003e") { // > - return false - } - - // 使用预编译的正则表达式检查XSS模式 - for _, pattern := range xssPatterns { - if pattern.MatchString(value) { - return false - } - } - - return true -} - -// Request validation structs - -// LoginRequest represents a login request -type LoginRequest struct { - Code string `json:"code" validate:"required,len=6|len=8,numeric"` -} - -// CreateOTPRequest represents a request to create an OTP -type CreateOTPRequest struct { - Name string `json:"name" validate:"required,min=1,max=100,no_xss"` - Issuer string `json:"issuer" validate:"required,issuer,no_xss"` - Secret string `json:"secret" validate:"required,otpsecret"` - Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"` - Digits int `json:"digits" validate:"required,oneof=6 8"` - Period int `json:"period" validate:"required,oneof=30 60"` -} - -// UpdateOTPRequest represents a request to update an OTP -type UpdateOTPRequest struct { - Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"` - Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"` - Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"` - Digits int `json:"digits" validate:"omitempty,oneof=6 8"` - Period int `json:"period" validate:"omitempty,oneof=30 60"` -} - -// VerifyOTPRequest represents a request to verify an OTP code -type VerifyOTPRequest struct { - Code string `json:"code" validate:"required,len=6|len=8,numeric"` -}