add branch v1

This commit is contained in:
“xHuPo” 2025-06-09 11:20:07 +08:00
parent 5d370e1077
commit 01b8951dd5
53 changed files with 1079 additions and 6481 deletions

View file

@ -1,149 +0,0 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"net/http"
)
// Common error codes
const (
CodeSuccess = 0
CodeInvalidParams = 400
CodeUnauthorized = 401
CodeForbidden = 403
CodeNotFound = 404
CodeInternalError = 500
CodeServiceUnavail = 503
)
// Error represents an API error
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
}
// Error implements the error interface
func (e *Error) Error() string {
return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message)
}
// NewError creates a new API error
func NewError(code int, message string) *Error {
return &Error{
Code: code,
Message: message,
}
}
// Response represents a standard API response
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// ResponseWriter wraps common response writing functions
type ResponseWriter struct {
http.ResponseWriter
}
// NewResponseWriter creates a new ResponseWriter
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
return &ResponseWriter{w}
}
// WriteJSON writes a JSON response
func (w *ResponseWriter) WriteJSON(code int, data interface{}) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
return json.NewEncoder(w).Encode(data)
}
// WriteSuccess writes a success response
func (w *ResponseWriter) WriteSuccess(data interface{}) error {
return w.WriteJSON(http.StatusOK, Response{
Code: CodeSuccess,
Message: "success",
Data: data,
})
}
// WriteError writes an error response
func (w *ResponseWriter) WriteError(err error) error {
var apiErr *Error
if errors.As(err, &apiErr) {
return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{
Code: apiErr.Code,
Message: apiErr.Message,
})
}
// Handle unknown errors
return w.WriteJSON(http.StatusInternalServerError, Response{
Code: CodeInternalError,
Message: "Internal Server Error",
})
}
// WriteErrorWithCode writes an error response with a specific code
func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error {
return w.WriteJSON(getHTTPStatus(code), Response{
Code: code,
Message: message,
})
}
// getHTTPStatus maps API error codes to HTTP status codes
func getHTTPStatus(code int) int {
switch code {
case CodeSuccess:
return http.StatusOK
case CodeInvalidParams:
return http.StatusBadRequest
case CodeUnauthorized:
return http.StatusUnauthorized
case CodeForbidden:
return http.StatusForbidden
case CodeNotFound:
return http.StatusNotFound
case CodeServiceUnavail:
return http.StatusServiceUnavailable
default:
return http.StatusInternalServerError
}
}
// Common errors
var (
ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters")
ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized")
ErrForbidden = NewError(CodeForbidden, "Forbidden")
ErrNotFound = NewError(CodeNotFound, "Resource not found")
ErrInternalError = NewError(CodeInternalError, "Internal server error")
ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable")
)
// ValidationError creates an error for invalid parameters
func ValidationError(message string) *Error {
return NewError(CodeInvalidParams, message)
}
// NotFoundError creates an error for not found resources
func NotFoundError(resource string) *Error {
return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource))
}
// ForbiddenError creates an error for forbidden actions
func ForbiddenError(message string) *Error {
return NewError(CodeForbidden, message)
}
// InternalError creates an error for internal server errors
func InternalError(err error) *Error {
if err == nil {
return ErrInternalError
}
return NewError(CodeInternalError, err.Error())
}

View file

@ -1,152 +0,0 @@
package api
import (
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
// Validate is a global validator instance
var Validate = validator.New()
// RegisterCustomValidations registers custom validation functions
func RegisterCustomValidations() {
// Register custom validation for issuer
Validate.RegisterValidation("issuer", validateIssuer)
// Register custom validation for XSS prevention
Validate.RegisterValidation("no_xss", validateNoXSS)
// Register custom validation for OTP secret
Validate.RegisterValidation("otpsecret", validateOTPSecret)
}
// validateOTPSecret validates that the OTP secret is in valid base32 format
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
// Check if the secret is not empty
if secret == "" {
return false
}
// Check if the secret is in base32 format (A-Z, 2-7)
base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
if !base32Regex.MatchString(secret) {
return false
}
// Check if the length is valid (must be at least 16 characters)
if len(secret) < 16 || len(secret) > 128 {
return false
}
return true
}
// validateIssuer validates that the issuer field contains only allowed characters
func validateIssuer(fl validator.FieldLevel) bool {
issuer := fl.Field().String()
// Empty issuer is valid (since it's optional)
if issuer == "" {
return true
}
// Allow alphanumeric characters, spaces, and common punctuation
issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api
import (
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
// Validate is a global validator instance
var Validate = validator.New()
// RegisterCustomValidations registers custom validation functions
func RegisterCustomValidations() {
// Register custom validation for issuer
Validate.RegisterValidation("issuer", validateIssuer)
// Register custom validation for XSS prevention
Validate.RegisterValidation("no_xss", validateNoXSS)
// Register custom validation for OTP secret
Validate.RegisterValidation("otpsecret", validateOTPSecret)
}
// validateOTPSecret validates that the OTP secret is in valid base32 format
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
// Check if the secret is not empty
if secret == "" {
return false
}
// Check if the secret is in base32 format (A-Z, 2-7)
base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
if !base32Regex.MatchString(secret) {
return false
}
// Check if the length is valid (must be at least 16 characters)
if len(secret) < 16 || len(secret) > 128 {
return false
}
return true
}
)
if !issuerRegex.MatchString(issuer) {
return false
}
// Check length
if len(issuer) > 100 {
return false
}
return true
}
// validateNoXSS validates that the field doesn't contain potential XSS payloads
func validateNoXSS(fl validator.FieldLevel) bool {
value := fl.Field().String()
// Check for HTML encoding
if strings.Contains(value, "&#") ||
strings.Contains(value, "&lt;") ||
strings.Contains(value, "&gt;") {
return false
}
// Check for common XSS patterns
suspiciousPatterns := []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)data:text/html`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
regexp.MustCompile(`(?i)<\s*iframe`),
regexp.MustCompile(`(?i)<\s*object`),
regexp.MustCompile(`(?i)<\s*embed`),
regexp.MustCompile(`(?i)<\s*style`),
regexp.MustCompile(`(?i)<\s*form`),
regexp.MustCompile(`(?i)<\s*applet`),
regexp.MustCompile(`(?i)<\s*meta`),
}
for _, pattern := range suspiciousPatterns {
if pattern.MatchString(value) {
return false
}
}
return true
}

202
api_server.go Normal file
View file

@ -0,0 +1,202 @@
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"auth"
"config"
"github.com/gorilla/mux"
)
// Server represents the API server
type Server struct {
config *config.Config
router *mux.Router
}
type WechatLoginResponse struct {
OpenID string `json:"openid"`
SessionKey string `json:"session_key"`
UnionID string `json:"unionid,omitempty"`
ErrCode int `json:"errcode,omitempty"`
ErrMsg string `json:"errmsg,omitempty"`
}
// NewServer creates a new instance of Server
func NewServer(cfg *config.Config) *Server {
s := &Server{
config: cfg,
router: mux.NewRouter(),
}
s.setupRoutes()
return s
}
// setupRoutes configures all the routes for the server
func (s *Server) setupRoutes() {
// 公开路由
s.router.HandleFunc("/auth/login", s.WechatLoginHandler)
// 受保护路由(需要JWT)
authRouter := s.router.PathPrefix("").Subrouter()
authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))
authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST")
authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST")
// 添加CORS中间件
s.router.Use(s.corsMiddleware)
}
// corsMiddleware handles CORS
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
// Check if the origin is allowed
allowed := false
for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
if origin == allowedOrigin {
allowed = true
break
}
}
if allowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods",
joinStrings(s.config.CORS.AllowedMethods))
w.Header().Set("Access-Control-Allow-Headers",
joinStrings(s.config.CORS.AllowedHeaders))
}
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
// joinStrings joins string slice with commas
func joinStrings(slice []string) string {
result := ""
for i, s := range slice {
if i > 0 {
result += ", "
}
result += s
}
return result
}
func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 读取请求体
body, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
var req struct {
Code string `json:"code"`
}
if err := json.Unmarshal(body, &req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
// 向微信服务器请求session_key
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
s.config.Wechat.AppID, s.config.Wechat.AppSecret, req.Code)
resp, err := http.Get(url)
if err != nil {
http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable)
return
}
defer resp.Body.Close()
body, err = ioutil.ReadAll(resp.Body)
if err != nil {
http.Error(w, "Wechat service error", http.StatusInternalServerError)
return
}
var wechatResp WechatLoginResponse
if err := json.Unmarshal(body, &wechatResp); err != nil {
http.Error(w, "Wechat response parse error", http.StatusInternalServerError)
return
}
if wechatResp.ErrCode != 0 {
http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized)
return
}
// 生成JWT token
token, err := s.generateSessionToken(wechatResp.OpenID)
if err != nil {
http.Error(w, "Failed to generate token", http.StatusInternalServerError)
return
}
// 返回响应
response := map[string]interface{}{
"token": token,
"openid": wechatResp.OpenID,
}
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
}
}
func (s *Server) generateSessionToken(openid string) (string, error) {
return auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
}
// Start starts the HTTP server
func (s *Server) Start() error {
addr := fmt.Sprintf(":%d", s.config.Server.Port)
log.Printf("Starting server on %s", addr)
srv := &http.Server{
Handler: s.router,
Addr: addr,
WriteTimeout: s.config.Server.Timeout,
ReadTimeout: s.config.Server.Timeout,
}
return srv.ListenAndServe()
}
func main() {
// 加载配置
cfg, err := config.LoadConfig("config")
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 初始化数据库连接
log.Println("Initializing database connection...")
if err := InitDB(cfg.Database); err != nil {
log.Fatalf("Failed to initialize database: %v", err)
}
log.Println("Database connection established successfully")
// 创建并启动服务器
server := NewServer(cfg)
if err := server.Start(); err != nil {
log.Fatalf("Server failed to start: %v", err)
}
}

105
auth/middleware.go Normal file
View file

@ -0,0 +1,105 @@
package auth
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
type contextKey string
const (
UserIDContextKey contextKey = "userID"
)
// Claims represents the JWT claims
type Claims struct {
UserID string `json:"user_id"`
jwt.RegisteredClaims
}
// GenerateToken generates a new JWT token
func GenerateToken(userID string, signingKey string, expiry time.Duration) (string, error) {
claims := Claims{
UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(signingKey))
}
// NewAuthMiddleware creates a new authentication middleware
func NewAuthMiddleware(signingKey string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "Authorization header required", http.StatusUnauthorized)
return
}
// Check if the header has the correct format
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
return
}
tokenString := parts[1]
// Parse and validate the token
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
// Validate signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(signingKey), nil
})
if err != nil {
if err == jwt.ErrSignatureInvalid {
http.Error(w, "Invalid token signature", http.StatusUnauthorized)
return
}
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
if !token.Valid {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
// Add user ID to request context
ctx := context.WithValue(r.Context(), UserIDContextKey, claims.UserID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// GetUserIDFromContext extracts the user ID from the request context
func GetUserIDFromContext(ctx context.Context) (string, error) {
userID, ok := ctx.Value(UserIDContextKey).(string)
if !ok {
return "", fmt.Errorf("user ID not found in context")
}
return userID, nil
}
// RequireAuth is a middleware that ensures a valid JWT token is present
func RequireAuth(signingKey string, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authMiddleware := NewAuthMiddleware(signingKey)
authMiddleware(http.HandlerFunc(next)).ServeHTTP(w, r)
}
}

206
cache/cache.go vendored
View file

@ -1,206 +0,0 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
)
// Item represents a cache item
type Item struct {
Value []byte
Expiration int64
}
// Expired returns true if the item has expired
func (item Item) Expired() bool {
if item.Expiration == 0 {
return false
}
return time.Now().UnixNano() > item.Expiration
}
// Cache represents an in-memory cache
type Cache struct {
items map[string]Item
mu sync.RWMutex
defaultExpiration time.Duration
cleanupInterval time.Duration
stopCleanup chan bool
}
// New creates a new cache with the given default expiration and cleanup interval
func New(defaultExpiration, cleanupInterval time.Duration) *Cache {
cache := &Cache{
items: make(map[string]Item),
defaultExpiration: defaultExpiration,
cleanupInterval: cleanupInterval,
stopCleanup: make(chan bool),
}
// Start cleanup goroutine if cleanup interval > 0
if cleanupInterval > 0 {
go cache.startCleanup()
}
return cache
}
// startCleanup starts the cleanup process
func (c *Cache) startCleanup() {
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.DeleteExpired()
case <-c.stopCleanup:
return
}
}
}
// Set adds an item to the cache with the given key and expiration
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error {
// Convert value to bytes
var valueBytes []byte
var err error
switch v := value.(type) {
case []byte:
valueBytes = v
case string:
valueBytes = []byte(v)
default:
valueBytes, err = json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value: %w", err)
}
}
// Calculate expiration
var exp int64
if expiration == 0 {
if c.defaultExpiration > 0 {
exp = time.Now().Add(c.defaultExpiration).UnixNano()
}
} else if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
}
c.mu.Lock()
c.items[key] = Item{
Value: valueBytes,
Expiration: exp,
}
c.mu.Unlock()
return nil
}
// Get gets an item from the cache
func (c *Cache) Get(key string, value interface{}) (bool, error) {
c.mu.RLock()
item, found := c.items[key]
c.mu.RUnlock()
if !found {
return false, nil
}
// Check if item has expired
if item.Expired() {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return false, nil
}
// Unmarshal value
switch v := value.(type) {
case *[]byte:
*v = item.Value
case *string:
*v = string(item.Value)
default:
if err := json.Unmarshal(item.Value, value); err != nil {
return true, fmt.Errorf("failed to unmarshal value: %w", err)
}
}
return true, nil
}
// Delete deletes an item from the cache
func (c *Cache) Delete(key string) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
}
// DeleteExpired deletes all expired items from the cache
func (c *Cache) DeleteExpired() {
now := time.Now().UnixNano()
c.mu.Lock()
for k, v := range c.items {
if v.Expiration > 0 && now > v.Expiration {
delete(c.items, k)
}
}
c.mu.Unlock()
}
// Clear deletes all items from the cache
func (c *Cache) Clear() {
c.mu.Lock()
c.items = make(map[string]Item)
c.mu.Unlock()
}
// Close stops the cleanup goroutine
func (c *Cache) Close() {
if c.cleanupInterval > 0 {
c.stopCleanup <- true
}
}
// CacheService provides caching functionality
type CacheService struct {
cache *Cache
}
// NewCacheService creates a new CacheService
func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService {
return &CacheService{
cache: New(defaultExpiration, cleanupInterval),
}
}
// Set adds an item to the cache
func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
return s.cache.Set(key, value, expiration)
}
// Get gets an item from the cache
func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) {
return s.cache.Get(key, value)
}
// Delete deletes an item from the cache
func (s *CacheService) Delete(ctx context.Context, key string) {
s.cache.Delete(key)
}
// Clear deletes all items from the cache
func (s *CacheService) Clear(ctx context.Context) {
s.cache.Clear()
}
// Close closes the cache
func (s *CacheService) Close() {
s.cache.Close()
}

View file

@ -1,144 +0,0 @@
package cmd
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/spf13/viper"
"otpm/api"
"otpm/config"
"otpm/database"
"otpm/handlers"
"otpm/models"
"otpm/server"
"otpm/services"
)
func init() {
// Set config file with multi-environment support
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
// Set environment specific config (e.g. config.production.yaml)
env := os.Getenv("OTPM_ENV")
if env != "" {
viper.SetConfigName(fmt.Sprintf("config.%s", env))
}
// Set default values
viper.SetDefault("server.port", "8080")
viper.SetDefault("server.timeout.read", "15s")
viper.SetDefault("server.timeout.write", "15s")
viper.SetDefault("server.timeout.idle", "60s")
viper.SetDefault("database.max_open_conns", 25)
viper.SetDefault("database.max_idle_conns", 5)
viper.SetDefault("database.conn_max_lifetime", "5m")
// Set environment variable prefix
viper.SetEnvPrefix("OTPM")
viper.AutomaticEnv()
// Bind environment variables
viper.BindEnv("database.url", "OTPM_DB_URL")
viper.BindEnv("database.password", "OTPM_DB_PASSWORD")
}
// Execute is the entry point for the application
func Execute() error {
// Load configuration
cfg, err := config.LoadConfig()
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
// Create context with cancellation
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Printf("Received signal: %v", sig)
cancel()
}()
// Initialize database
db, err := database.New(&cfg.Database)
if err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
defer db.Close()
// Run database migrations
if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil {
return fmt.Errorf("failed to run migrations: %w", err)
}
// Initialize repositories
userRepo := models.NewUserRepository(db.DB)
otpRepo := models.NewOTPRepository(db.DB)
// Initialize services
authService := services.NewAuthService(cfg, userRepo)
otpService := services.NewOTPService(otpRepo)
// Register custom validations
api.RegisterCustomValidations()
// Initialize handlers
authHandler := handlers.NewAuthHandler(authService)
otpHandler := handlers.NewOTPHandler(otpService)
// Create and configure server
srv := server.New(cfg)
// Register health check endpoint
srv.RegisterHealthCheck()
// Register public routes with type conversion
authRoutes := make(map[string]http.Handler)
for path, handler := range authHandler.Routes() {
authRoutes[path] = http.HandlerFunc(handler)
}
srv.RegisterRoutes(authRoutes)
// Register authenticated routes with type conversion
otpRoutes := make(map[string]http.Handler)
for path, handler := range otpHandler.Routes() {
otpRoutes[path] = http.HandlerFunc(handler)
}
srv.RegisterAuthRoutes(otpRoutes)
// Start server in goroutine
serverErr := make(chan error, 1)
go func() {
log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port)
if err := srv.Start(); err != nil {
serverErr <- fmt.Errorf("server error: %w", err)
}
}()
// Wait for shutdown signal or server error
select {
case err := <-serverErr:
return err
case <-ctx.Done():
// Graceful shutdown with timeout
log.Println("Shutting down server...")
if err := srv.Shutdown(); err != nil {
return fmt.Errorf("server shutdown error: %w", err)
}
log.Println("Server stopped gracefully")
}
return nil
}

View file

@ -1,23 +1,44 @@
# Server Configuration
server:
port: 8080
read_timeout: 15s
write_timeout: 15s
shutdown_timeout: 5s
timeout: 30s
# Database Configuration
database:
driver: sqlite3
dsn: otpm.sqlite
max_open_conns: 25
max_idle_conns: 25
max_lifetime: 5m
skip_migration: false
driver: "sqlite3" # or "postgres"
sqlite:
path: "./data.db"
postgres:
host: "localhost"
port: 5432
user: "postgres"
password: "password"
dbname: "otpdb"
sslmode: "disable"
jwt:
secret: "your-jwt-secret-key-change-this-in-production"
expire_delta: 24h
refresh_delta: 168h
signing_method: HS256
# Security Configuration
security:
encryption_key: "your-32-byte-encryption-key-here"
jwt_signing_key: "your-jwt-signing-key-here"
token_expiry: 24h
refresh_token_expiry: 168h # 7 days
# WeChat Configuration
wechat:
app_id: "your-wechat-app-id"
app_secret: "your-wechat-app-secret"
app_id: "YOUR_APPID"
app_secret: "YOUR_APPSECRET"
# CORS Configuration
cors:
allowed_origins:
- "http://localhost:8080"
- "https://yourdomain.com"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "Authorization"
- "Content-Type"

View file

@ -7,122 +7,156 @@ import (
"github.com/spf13/viper"
)
// Config holds all configuration for the application
// Config holds all configuration for our application
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
JWT JWTConfig `mapstructure:"jwt"`
WeChat WeChatConfig `mapstructure:"wechat"`
Server ServerConfig
Database DatabaseConfig
Security SecurityConfig
CORS CORSConfig
Wechat WechatConfig
}
// ServerConfig holds all server related configuration
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
ReadTimeout time.Duration `mapstructure:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout"`
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout
Port int
Timeout time.Duration
}
// DatabaseConfig holds all database related configuration
type DatabaseConfig struct {
Driver string `mapstructure:"driver"`
DSN string `mapstructure:"dsn"`
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
MaxLifetime time.Duration `mapstructure:"max_lifetime"`
SkipMigration bool `mapstructure:"skip_migration"`
Driver string
SQLite SQLiteConfig
Postgres PostgresConfig
}
// JWTConfig holds all JWT related configuration
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireDelta time.Duration `mapstructure:"expire_delta"`
RefreshDelta time.Duration `mapstructure:"refresh_delta"`
SigningMethod string `mapstructure:"signing_method"`
Issuer string `mapstructure:"issuer"`
Audience string `mapstructure:"audience"`
// SQLiteConfig holds SQLite specific configuration
type SQLiteConfig struct {
Path string
}
// WeChatConfig holds all WeChat related configuration
type WeChatConfig struct {
// PostgresConfig holds PostgreSQL specific configuration
type PostgresConfig struct {
Host string
Port int
User string
Password string
DBName string
SSLMode string
}
// SecurityConfig holds all security related configuration
type SecurityConfig struct {
EncryptionKey string
JWTSigningKey string
TokenExpiry time.Duration
RefreshTokenExpiry time.Duration
}
// CORSConfig holds CORS related configuration
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
}
// WechatConfig holds WeChat related configuration
type WechatConfig struct {
AppID string `mapstructure:"app_id"`
AppSecret string `mapstructure:"app_secret"`
}
// LoadConfig loads the configuration from file and environment variables
func LoadConfig() (*Config, error) {
// Set default values
setDefaults()
// LoadConfig reads configuration from file or environment variables
func LoadConfig(configPath string) (*Config, error) {
v := viper.New()
// Read config file
if err := viper.ReadInConfig(); err != nil {
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(configPath)
v.AddConfigPath(".")
// Read environment variables
v.AutomaticEnv()
// Allow environment variables to override config file
v.SetEnvPrefix("OTP")
v.BindEnv("security.encryption_key", "OTP_ENCRYPTION_KEY")
v.BindEnv("security.jwt_signing_key", "OTP_JWT_SIGNING_KEY")
v.BindEnv("database.postgres.password", "OTP_DB_PASSWORD")
v.BindEnv("wechat.app_id", "OTP_WECHAT_APPID")
v.BindEnv("wechat.app_secret", "OTP_WECHAT_SECRET")
if err := v.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
if err := v.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// Validate config
// Validate required configurations
if err := validateConfig(&config); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
return nil, fmt.Errorf("config validation failed: %w", err)
}
return &config, nil
}
// setDefaults sets default values for configuration
func setDefaults() {
// Server defaults
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.read_timeout", "15s")
viper.SetDefault("server.write_timeout", "15s")
viper.SetDefault("server.shutdown_timeout", "5s")
viper.SetDefault("server.timeout", "30s") // Default request processing timeout
// Database defaults
viper.SetDefault("database.driver", "sqlite3")
viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection
viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection
viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling
viper.SetDefault("database.skip_migration", false)
// JWT defaults
viper.SetDefault("jwt.expire_delta", "24h")
viper.SetDefault("jwt.refresh_delta", "168h") // 7 days
viper.SetDefault("jwt.signing_method", "HS256")
viper.SetDefault("jwt.issuer", "otpm")
viper.SetDefault("jwt.audience", "otpm-client")
}
// validateConfig validates the configuration
// validateConfig ensures all required configuration values are provided
func validateConfig(config *Config) error {
if config.Server.Port < 1 || config.Server.Port > 65535 {
return fmt.Errorf("invalid port number: %d", config.Server.Port)
if config.Security.EncryptionKey == "" {
return fmt.Errorf("encryption key is required")
}
if config.Security.JWTSigningKey == "" {
return fmt.Errorf("JWT signing key is required")
}
if config.Database.Driver == "" {
return fmt.Errorf("database driver is required")
}
if config.Database.DSN == "" {
return fmt.Errorf("database DSN is required")
switch config.Database.Driver {
case "sqlite3":
if config.Database.SQLite.Path == "" {
return fmt.Errorf("SQLite database path is required")
}
case "postgres":
if config.Database.Postgres.Host == "" ||
config.Database.Postgres.User == "" ||
config.Database.Postgres.Password == "" ||
config.Database.Postgres.DBName == "" {
return fmt.Errorf("incomplete PostgreSQL configuration")
}
default:
return fmt.Errorf("unsupported database driver: %s", config.Database.Driver)
}
if config.JWT.Secret == "" {
return fmt.Errorf("JWT secret is required")
}
if config.WeChat.AppID == "" {
// Validate WeChat configuration
if config.Wechat.AppID == "" {
return fmt.Errorf("WeChat AppID is required")
}
if config.WeChat.AppSecret == "" {
if config.Wechat.AppSecret == "" {
return fmt.Errorf("WeChat AppSecret is required")
}
return nil
}
// GetDSN returns the appropriate database connection string based on the driver
func (c *DatabaseConfig) GetDSN() string {
switch c.Driver {
case "sqlite3":
return c.SQLite.Path
case "postgres":
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
c.Postgres.Host,
c.Postgres.Port,
c.Postgres.User,
c.Postgres.Password,
c.Postgres.DBName,
c.Postgres.SSLMode)
default:
return ""
}
}

View file

@ -1,212 +0,0 @@
package database
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"time"
"otpm/config"
"github.com/jmoiron/sqlx"
)
// DB wraps sqlx.DB to provide additional functionality
type DB struct {
*sqlx.DB
}
// New creates a new database connection
func New(cfg *config.DatabaseConfig) (*DB, error) {
db, err := sqlx.Open(cfg.Driver, cfg.DSN)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Configure connection pool based on database type
if cfg.Driver == "sqlite3" {
// SQLite is a file-based database - simpler connection settings
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(0) // Connections don't need to be recycled
db.SetConnMaxIdleTime(0)
} else {
// For other databases (MySQL, PostgreSQL etc.)
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
db.SetConnMaxLifetime(30 * time.Minute)
db.SetConnMaxIdleTime(5 * time.Minute)
}
// Verify connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
log.Println("Successfully connected to database")
return &DB{db}, nil
}
// WithTx executes a function within a transaction with retry logic
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
var maxRetries int
var lastErr error
// Adjust retry settings based on database type
if db.DriverName() == "sqlite3" {
maxRetries = 5 // SQLite needs more retries due to busy timeouts
} else {
maxRetries = 3
}
// Default transaction options
opts := &sql.TxOptions{
Isolation: sql.LevelReadCommitted,
}
for attempt := 1; attempt <= maxRetries; attempt++ {
start := time.Now()
tx, err := db.BeginTxx(ctx, opts)
if err != nil {
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
continue
}
return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
}
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
}
// Log long-running transactions
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
log.Printf("Transaction completed in %v", elapsed)
}
if err := tx.Commit(); err != nil {
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
}
return nil
}
return lastErr
}
// isRetryableError checks if an error is likely to succeed on retry
func isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := strings.ToLower(err.Error())
return strings.Contains(errStr, "deadlock") ||
strings.Contains(errStr, "timeout") ||
strings.Contains(errStr, "try again") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "busy") ||
strings.Contains(errStr, "locked")
}
// ExecContext executes a query with adaptive timeout
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
// Set timeout based on query complexity
timeout := 5 * time.Second
if strings.Contains(strings.ToLower(query), "insert") ||
strings.Contains(strings.ToLower(query), "update") ||
strings.Contains(strings.ToLower(query), "delete") {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
_, err := db.DB.ExecContext(ctx, query, args...)
elapsed := time.Since(start)
// Log slow queries
if elapsed > timeout/2 {
log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
}
if err != nil {
return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
}
return nil
}
// QueryRowContext executes a query that returns a single row with timeout
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
return db.DB.QueryRowxContext(ctx, query, args...)
}
// QueryContext executes a query that returns multiple rows with adaptive timeout
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
// Set timeout based on query complexity
timeout := 5 * time.Second
if strings.Contains(strings.ToLower(query), "join") ||
strings.Contains(strings.ToLower(query), "group by") ||
strings.Contains(strings.ToLower(query), "order by") {
timeout = 15 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
rows, err := db.DB.QueryxContext(ctx, query, args...)
elapsed := time.Since(start)
// Log slow queries
if elapsed > timeout/2 {
log.Printf("Slow query detected: %s (took %v)", query, elapsed)
}
if err != nil {
return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
}
return rows, nil
}
// Close closes the database connection
func (db *DB) Close() error {
if err := db.DB.Close(); err != nil {
return fmt.Errorf("failed to close database connection: %w", err)
}
return nil
}

View file

@ -1,26 +0,0 @@
CREATE TABLE IF NOT EXISTS otp (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id VARCHAR(255) NOT NULL,
openid VARCHAR(255) NOT NULL,
name VARCHAR(100) NOT NULL,
issuer VARCHAR(255),
secret VARCHAR(255) NOT NULL,
algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1',
digits INTEGER NOT NULL DEFAULT 6,
period INTEGER NOT NULL DEFAULT 30,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, name),
UNIQUE(openid)
);
-- Add index for faster lookups
CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id);
CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid);
-- Trigger to update the updated_at timestamp
CREATE TRIGGER IF NOT EXISTS update_otp_timestamp
AFTER UPDATE ON otp
BEGIN
UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END;

View file

@ -1,6 +0,0 @@
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
openid VARCHAR(255) UNIQUE NOT NULL,
session_key VARCHAR(255) UNIQUE NOT NULL
);
CREATE UNIQUE INDEX idx_users_openid ON users(openid);

View file

@ -1,160 +0,0 @@
package database
import (
"context"
"fmt"
"log"
"time"
_ "embed"
"github.com/jmoiron/sqlx"
)
var (
//go:embed init/users.sql
userTable string
//go:embed init/otp.sql
otpTable string
)
// Migration represents a database migration
type Migration struct {
Name string
SQL string
Version int
}
// Migrations is a list of all migrations
var Migrations = []Migration{
{
Name: "Create users table",
SQL: userTable,
Version: 1,
},
{
Name: "Create OTP table",
SQL: otpTable,
Version: 2,
},
}
// MigrationRecord represents a record in the migrations table
type MigrationRecord struct {
ID int `db:"id"`
Version int `db:"version"`
Name string `db:"name"`
AppliedAt time.Time `db:"applied_at"`
}
// ensureMigrationsTable ensures that the migrations table exists
func ensureMigrationsTable(db *sqlx.DB) error {
query := `
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version INTEGER NOT NULL,
name TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
_, err := db.Exec(query)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
return nil
}
// getAppliedMigrations gets all applied migrations
func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
var records []MigrationRecord
if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
}
result := make(map[int]MigrationRecord)
for _, record := range records {
result[record.Version] = record
}
return result, nil
}
// Migrate runs all pending migrations
func Migrate(db *sqlx.DB, skipMigration bool) error {
if skipMigration {
log.Println("Skipping database migration as configured")
return nil
}
// Ensure migrations table exists
if err := ensureMigrationsTable(db); err != nil {
return err
}
// Get applied migrations
appliedMigrations, err := getAppliedMigrations(db)
if err != nil {
return err
}
// Apply pending migrations
for _, migration := range Migrations {
if _, ok := appliedMigrations[migration.Version]; ok {
log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
continue
}
log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
// Start a transaction for this migration
tx, err := db.Beginx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
// Execute migration
if _, err := tx.Exec(migration.SQL); err != nil {
tx.Rollback()
return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Record migration
if _, err := tx.Exec(
"INSERT INTO migrations (version, name) VALUES (?, ?)",
migration.Version, migration.Name,
); err != nil {
tx.Rollback()
return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Commit transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
}
log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
}
return nil
}
// MigrateWithContext runs all pending migrations with context
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
// Create a channel to signal completion
done := make(chan error, 1)
// Run migration in a goroutine
go func() {
done <- Migrate(db, skipMigration)
}()
// Wait for migration to complete or context to be canceled
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("migration canceled: %w", ctx.Err())
}
}

View file

@ -1,663 +0,0 @@
// Package docs provides API documentation using Swagger/OpenAPI
package docs
import (
"encoding/json"
"net/http"
)
// SwaggerInfo holds the API information
var SwaggerInfo = struct {
Title string
Description string
Version string
Host string
BasePath string
Schemes []string
}{
Title: "OTPM API",
Description: "API for One-Time Password Manager",
Version: "1.0.0",
Host: "localhost:8080",
BasePath: "/",
Schemes: []string{"http", "https"},
}
// SwaggerJSON returns the Swagger JSON
func SwaggerJSON() []byte {
swagger := map[string]interface{}{
"swagger": "2.0",
"info": map[string]interface{}{
"title": SwaggerInfo.Title,
"description": SwaggerInfo.Description,
"version": SwaggerInfo.Version,
},
"host": SwaggerInfo.Host,
"basePath": SwaggerInfo.BasePath,
"schemes": SwaggerInfo.Schemes,
"paths": getPaths(),
"definitions": map[string]interface{}{
"LoginRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "WeChat authorization code",
},
},
"required": []string{"code"},
},
"LoginResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"token": map[string]interface{}{
"type": "string",
"description": "JWT token",
},
"openid": map[string]interface{}{
"type": "string",
"description": "WeChat OpenID",
},
},
},
"CreateOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"secret": map[string]interface{}{
"type": "string",
"description": "OTP secret",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
"enum": []string{"SHA1", "SHA256", "SHA512"},
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
"enum": []int{6, 8},
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
"enum": []int{30, 60},
},
},
"required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"},
},
"OTP": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"id": map[string]interface{}{
"type": "string",
"description": "OTP ID",
},
"user_id": map[string]interface{}{
"type": "string",
"description": "User ID",
},
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
},
"created_at": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Creation time",
},
"updated_at": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Last update time",
},
},
},
"OTPCodeResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "OTP code",
},
"expires_in": map[string]interface{}{
"type": "integer",
"description": "Seconds until expiration",
},
},
},
"VerifyOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "OTP code to verify",
},
},
"required": []string{"code"},
},
"VerifyOTPResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"valid": map[string]interface{}{
"type": "boolean",
"description": "Whether the code is valid",
},
},
},
"UpdateOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
"enum": []string{"SHA1", "SHA256", "SHA512"},
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
"enum": []int{6, 8},
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
"enum": []int{30, 60},
},
},
},
"ErrorResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "integer",
"description": "Error code",
},
"message": map[string]interface{}{
"type": "string",
"description": "Error message",
},
},
},
},
"securityDefinitions": map[string]interface{}{
"Bearer": map[string]interface{}{
"type": "apiKey",
"name": "Authorization",
"in": "header",
"description": "JWT token with Bearer prefix",
},
},
}
data, _ := json.MarshalIndent(swagger, "", " ")
return data
}
// getPaths returns the API paths
func getPaths() map[string]interface{} {
return map[string]interface{}{
"/login": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Login with WeChat",
"description": "Login with WeChat authorization code",
"tags": []string{"auth"},
"parameters": []map[string]interface{}{
{
"name": "body",
"in": "body",
"description": "Login request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/LoginRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Successful login",
"schema": map[string]interface{}{
"$ref": "#/definitions/LoginResponse",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/verify-token": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Verify token",
"description": "Verify JWT token",
"tags": []string{"auth"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Token is valid",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"valid": map[string]interface{}{
"type": "boolean",
"description": "Whether the token is valid",
},
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp": map[string]interface{}{
"get": map[string]interface{}{
"summary": "List OTPs",
"description": "List all OTPs for the authenticated user",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "List of OTPs",
"schema": map[string]interface{}{
"type": "array",
"items": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
"post": map[string]interface{}{
"summary": "Create OTP",
"description": "Create a new OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "body",
"in": "body",
"description": "OTP creation request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/CreateOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP created",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}": map[string]interface{}{
"put": map[string]interface{}{
"summary": "Update OTP",
"description": "Update an existing OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
{
"name": "body",
"in": "body",
"description": "OTP update request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/UpdateOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP updated",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
"delete": map[string]interface{}{
"summary": "Delete OTP",
"description": "Delete an existing OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP deleted",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "Success message",
},
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}/code": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Get OTP code",
"description": "Get the current OTP code",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP code",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTPCodeResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}/verify": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Verify OTP code",
"description": "Verify an OTP code",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
{
"name": "body",
"in": "body",
"description": "OTP verification request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/VerifyOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP verification result",
"schema": map[string]interface{}{
"$ref": "#/definitions/VerifyOTPResponse",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/health": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Health check",
"description": "Check if the API is healthy",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "API is healthy",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"status": map[string]interface{}{
"type": "string",
"description": "Health status",
},
"time": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Current time",
},
},
},
},
},
},
},
"/metrics": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Metrics",
"description": "Get application metrics",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Application metrics",
},
},
},
},
"/swagger.json": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Swagger JSON",
"description": "Get Swagger JSON",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Swagger JSON",
},
},
},
},
}
}
// Handler returns an HTTP handler for Swagger JSON
func Handler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(SwaggerJSON())
}
}

50
go.mod
View file

@ -1,50 +0,0 @@
module otpm
go 1.23.0
toolchain go1.23.9
require (
github.com/go-playground/validator/v10 v10.26.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/jmoiron/sqlx v1.4.0
github.com/julienschmidt/httprouter v1.3.0
github.com/prometheus/client_golang v1.22.0
github.com/spf13/viper v1.19.0
golang.org/x/crypto v0.38.0
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

124
go.sum
View file

@ -1,124 +0,0 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -1,156 +0,0 @@
package handlers
import (
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"otpm/api"
"otpm/services"
"github.com/golang-jwt/jwt"
"github.com/julienschmidt/httprouter"
)
// AuthHandler handles authentication related requests
type AuthHandler struct {
authService *services.AuthService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code" validate:"required,min=32,max=128"`
}
// LoginResponse represents a login response
type LoginResponse struct {
Token string `json:"token"`
OpenID string `json:"openid"`
}
// TokenRequest represents a token verification request
type TokenRequest struct {
Token string `validate:"required,min=32"`
}
// Login handles WeChat login
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()
// Limit request body size to prevent DOS
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
// Parse and validate request
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request body: %v", err))
log.Printf("Login request parse error: %v", err)
return
}
// Validate using validator
if err := api.Validate.Struct(req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request parameters: %v", err))
log.Printf("Login request validation failed: %v", err)
return
}
// Login with WeChat code
token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
log.Printf("Login failed for code %s: %v", req.Code, err)
return
}
// Log successful login
log.Printf("Login successful for code %s (took %v)",
req.Code, time.Since(start))
// Return token
api.NewResponseWriter(w).WriteSuccess(LoginResponse{
Token: token,
})
}
// VerifyToken handles token verification
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()
// Get token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Authorization header is required")
log.Printf("Token verification failed: missing Authorization header")
return
}
// Validate token format
if len(authHeader) < 7 || authHeader[:7] != "Bearer " {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token format. Expected 'Bearer <token>'")
log.Printf("Token verification failed: invalid token format")
return
}
token := authHeader[7:]
// Validate token using validator
tokenReq := TokenRequest{Token: token}
if err := api.Validate.Struct(tokenReq); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token format")
log.Printf("Token verification failed: %v", err)
return
}
// Validate token
claims, err := h.authService.ValidateToken(token)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("Token verification failed for token %s: %v",
maskToken(token), err) // Mask token in logs
return
}
// Log successful verification
userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string)
if !ok {
log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
} else {
log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
}
// Token is valid
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
"valid": true,
})
}
// maskToken masks sensitive parts of token for logging
func maskToken(token string) string {
if len(token) < 8 {
return "****"
}
return token[:4] + "****" + token[len(token)-4:]
}
// Routes returns all routes for the auth handler
func (h *AuthHandler) Routes() map[string]httprouter.Handle {
return map[string]httprouter.Handle{
"/api/login": h.Login,
"/api/verify-token": h.VerifyToken,
}
}

View file

@ -1,114 +0,0 @@
package handlers
import (
"encoding/json"
"net/http"
"github.com/julienschmidt/httprouter"
"otpm/api"
"otpm/middleware"
"otpm/models"
"otpm/services"
)
// OTPHandler handles OTP-related HTTP requests
type OTPHandler struct {
otpService *services.OTPService
}
// NewOTPHandler creates a new OTPHandler
func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
return &OTPHandler{
otpService: otpService,
}
}
// Routes returns the routes for OTP operations
func (h *OTPHandler) Routes() map[string]httprouter.Handle {
return map[string]httprouter.Handle{
"POST /api/otp": h.CreateOTP,
"GET /api/otps": h.ListOTPs,
"GET /api/otp/:id": h.GetOTP,
}
}
// CreateOTP handles the creation of a new OTP
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// Get user ID from context
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Parse request body
var params models.OTPParams
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body"))
return
}
// Validate request
if err := api.Validate.Struct(params); err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
return
}
// Create OTP
otp, err := h.otpService.CreateOTP(r.Context(), userID, params)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
// Return response
api.NewResponseWriter(w).WriteSuccess(otp)
}
// ListOTPs handles listing all OTPs for a user
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// Get user ID from context
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTPs
otps, err := h.otpService.ListOTPs(r.Context(), userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
// Return response
api.NewResponseWriter(w).WriteSuccess(otps)
}
// GetOTP handles getting a specific OTP
func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Get user ID from context
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := ps.ByName("id")
if otpID == "" {
api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID"))
return
}
// Get OTP
otp, err := h.otpService.GetOTP(r.Context(), otpID, userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
// Return response
api.NewResponseWriter(w).WriteSuccess(otp)
}

51
init/postgresql/init.sql Normal file
View file

@ -0,0 +1,51 @@
-- 创建tokens表
CREATE TABLE IF NOT EXISTS tokens (
id VARCHAR(255) NOT NULL, -- token的唯一标识符
user_id VARCHAR(255) NOT NULL, -- 用户ID
issuer VARCHAR(255) NOT NULL, -- 令牌发行者
account VARCHAR(255) NOT NULL, -- 账户名称
secret TEXT NOT NULL, -- 密钥
type VARCHAR(10) NOT NULL, -- 令牌类型totp/hotp
counter INTEGER, -- HOTP计数器可选
period INTEGER NOT NULL, -- TOTP周期
digits INTEGER NOT NULL, -- 验证码位数
algo VARCHAR(10) NOT NULL, -- 使用的哈希算法
timestamp BIGINT NOT NULL, -- 最后更新时间戳
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, user_id)
);
-- 创建更新时间戳的触发器
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql';
CREATE TRIGGER update_tokens_updated_at
BEFORE UPDATE ON tokens
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
-- 创建索引
CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);
CREATE INDEX IF NOT EXISTS idx_tokens_timestamp ON tokens(timestamp);
-- 添加注释
COMMENT ON TABLE tokens IS 'OTP令牌数据表';
COMMENT ON COLUMN tokens.id IS '令牌的唯一标识符';
COMMENT ON COLUMN tokens.user_id IS '用户ID';
COMMENT ON COLUMN tokens.issuer IS '令牌发行者';
COMMENT ON COLUMN tokens.account IS '账户名称';
COMMENT ON COLUMN tokens.secret IS '密钥';
COMMENT ON COLUMN tokens.type IS '令牌类型totp/hotp';
COMMENT ON COLUMN tokens.counter IS 'HOTP计数器可选';
COMMENT ON COLUMN tokens.period IS 'TOTP周期';
COMMENT ON COLUMN tokens.digits IS '验证码位数';
COMMENT ON COLUMN tokens.algo IS '使用的哈希算法';
COMMENT ON COLUMN tokens.timestamp IS '最后更新时间戳';
COMMENT ON COLUMN tokens.created_at IS '创建时间';
COMMENT ON COLUMN tokens.updated_at IS '最后更新时间';

50
init/sqlite3/init.sql Normal file
View file

@ -0,0 +1,50 @@
-- SQLite3 initialization SQL
-- Enable WAL mode for better concurrency (simple performance boost)
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
-- Enable foreign key support
PRAGMA foreign_keys = ON;
-- 创建tokens表
CREATE TABLE IF NOT EXISTS tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
issuer TEXT NOT NULL,
account TEXT NOT NULL,
secret TEXT NOT NULL CHECK (length(secret) >= 16 AND secret REGEXP '^[A-Z2-7]+=*$'),
type TEXT NOT NULL CHECK (type IN ('HOTP', 'TOTP')),
counter INTEGER CHECK (
(type = 'HOTP' AND counter >= 0) OR
(type = 'TOTP' AND counter IS NULL)
),
period INTEGER DEFAULT 30 CHECK (
(type = 'TOTP' AND period >= 30) OR
(type = 'HOTP' AND period IS NULL)
),
digits INTEGER NOT NULL DEFAULT 6 CHECK (digits IN (6, 8)),
algo TEXT NOT NULL DEFAULT 'SHA1' CHECK (algo IN ('SHA1', 'SHA256', 'SHA512')),
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
updated_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
UNIQUE(user_id, issuer, account)
);
-- 基本索引
CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);
CREATE INDEX IF NOT EXISTS idx_tokens_lookup ON tokens(user_id, issuer, account);
CREATE INDEX IF NOT EXISTS idx_tokens_hotp ON tokens(user_id) WHERE type = 'HOTP';
CREATE INDEX IF NOT EXISTS idx_tokens_totp ON tokens(user_id) WHERE type = 'TOTP';
-- 简化统计视图
CREATE VIEW IF NOT EXISTS v_token_stats AS
SELECT
user_id,
COUNT(*) as total_tokens,
SUM(type = 'HOTP') as hotp_count,
SUM(type = 'TOTP') as totp_count
FROM tokens
GROUP BY user_id;
-- 设置版本号
PRAGMA user_version = 1;

View file

@ -1,204 +0,0 @@
package logger
import (
"context"
"fmt"
"io"
"os"
"runtime"
"strings"
"time"
"github.com/google/uuid"
)
// Level represents a log level
type Level int
const (
// DEBUG level
DEBUG Level = iota
// INFO level
INFO
// WARN level
WARN
// ERROR level
ERROR
// FATAL level
FATAL
)
// String returns the string representation of the log level
func (l Level) String() string {
switch l {
case DEBUG:
return "DEBUG"
case INFO:
return "INFO"
case WARN:
return "WARN"
case ERROR:
return "ERROR"
case FATAL:
return "FATAL"
default:
return "UNKNOWN"
}
}
// Logger represents a logger
type Logger struct {
level Level
output io.Writer
}
// contextKey is a type for context keys
type contextKey string
// requestIDKey is the key for request ID in context
const requestIDKey = contextKey("request_id")
// New creates a new logger
func New(level Level, output io.Writer) *Logger {
if output == nil {
output = os.Stdout
}
return &Logger{
level: level,
output: output,
}
}
// WithLevel creates a new logger with the specified level
func (l *Logger) WithLevel(level Level) *Logger {
return &Logger{
level: level,
output: l.output,
}
}
// WithOutput creates a new logger with the specified output
func (l *Logger) WithOutput(output io.Writer) *Logger {
return &Logger{
level: l.level,
output: output,
}
}
// log logs a message with the specified level
func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) {
if level < l.level {
return
}
// Get request ID from context
requestID := getRequestID(ctx)
// Get caller information
_, file, line, ok := runtime.Caller(2)
if !ok {
file = "unknown"
line = 0
}
// Extract just the filename
if idx := strings.LastIndex(file, "/"); idx >= 0 {
file = file[idx+1:]
}
// Format message
message := fmt.Sprintf(format, args...)
// Format log entry
timestamp := time.Now().Format(time.RFC3339)
logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n",
timestamp, level.String(), file, line, requestID, message)
// Write log entry
_, _ = l.output.Write([]byte(logEntry))
// Exit if fatal
if level == FATAL {
os.Exit(1)
}
}
// Debug logs a debug message
func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, DEBUG, format, args...)
}
// Info logs an info message
func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, INFO, format, args...)
}
// Warn logs a warning message
func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, WARN, format, args...)
}
// Error logs an error message
func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, ERROR, format, args...)
}
// Fatal logs a fatal message and exits
func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, FATAL, format, args...)
}
// WithRequestID adds a request ID to the context
func WithRequestID(ctx context.Context) context.Context {
requestID := uuid.New().String()
return context.WithValue(ctx, requestIDKey, requestID)
}
// GetRequestID gets the request ID from the context
func GetRequestID(ctx context.Context) string {
return getRequestID(ctx)
}
// getRequestID gets the request ID from the context
func getRequestID(ctx context.Context) string {
if ctx == nil {
return "-"
}
requestID, ok := ctx.Value(requestIDKey).(string)
if !ok {
return "-"
}
return requestID
}
// Default logger
var defaultLogger = New(INFO, os.Stdout)
// SetDefaultLogger sets the default logger
func SetDefaultLogger(logger *Logger) {
defaultLogger = logger
}
// Debug logs a debug message using the default logger
func Debug(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Debug(ctx, format, args...)
}
// Info logs an info message using the default logger
func Info(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Info(ctx, format, args...)
}
// Warn logs a warning message using the default logger
func Warn(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Warn(ctx, format, args...)
}
// Error logs an error message using the default logger
func Error(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Error(ctx, format, args...)
}
// Fatal logs a fatal message and exits using the default logger
func Fatal(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Fatal(ctx, format, args...)
}

View file

@ -1,7 +0,0 @@
package main
import "otpm/cmd"
func main() {
cmd.Execute()
}

View file

@ -1,193 +0,0 @@
package metrics
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
var (
// Default metrics
requestDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "Duration of HTTP requests in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path", "status"},
)
requestTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Total number of HTTP requests",
},
[]string{"method", "path", "status"},
)
otpGenerationTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "otp_generation_total",
Help: "Total number of OTP generations",
},
[]string{"user_id", "otp_id"},
)
otpVerificationTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "otp_verification_total",
Help: "Total number of OTP verifications",
},
[]string{"user_id", "otp_id", "success"},
)
activeUsers = prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "active_users",
Help: "Number of active users",
},
)
cacheHits = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_hits_total",
Help: "Total number of cache hits",
},
[]string{"cache"},
)
cacheMisses = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_misses_total",
Help: "Total number of cache misses",
},
[]string{"cache"},
)
)
func init() {
// Register metrics with prometheus
prometheus.MustRegister(
requestDuration,
requestTotal,
otpGenerationTotal,
otpVerificationTotal,
activeUsers,
cacheHits,
cacheMisses,
)
}
// MetricsService provides metrics functionality
type MetricsService struct {
activeUsersMutex sync.RWMutex
activeUserIDs map[string]bool
}
// NewMetricsService creates a new MetricsService
func NewMetricsService() *MetricsService {
return &MetricsService{
activeUserIDs: make(map[string]bool),
}
}
// Handler returns an HTTP handler for metrics
func (s *MetricsService) Handler() http.Handler {
return promhttp.Handler()
}
// RecordRequest records metrics for an HTTP request
func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) {
labels := prometheus.Labels{
"method": method,
"path": path,
"status": fmt.Sprintf("%d", status),
}
requestDuration.With(labels).Observe(duration.Seconds())
requestTotal.With(labels).Inc()
}
// RecordOTPGeneration records metrics for OTP generation
func (s *MetricsService) RecordOTPGeneration(userID, otpID string) {
otpGenerationTotal.With(prometheus.Labels{
"user_id": userID,
"otp_id": otpID,
}).Inc()
}
// RecordOTPVerification records metrics for OTP verification
func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) {
otpVerificationTotal.With(prometheus.Labels{
"user_id": userID,
"otp_id": otpID,
"success": fmt.Sprintf("%t", success),
}).Inc()
}
// RecordUserActivity records user activity
func (s *MetricsService) RecordUserActivity(userID string) {
s.activeUsersMutex.Lock()
defer s.activeUsersMutex.Unlock()
if !s.activeUserIDs[userID] {
s.activeUserIDs[userID] = true
activeUsers.Inc()
}
}
// RecordUserInactivity records user inactivity
func (s *MetricsService) RecordUserInactivity(userID string) {
s.activeUsersMutex.Lock()
defer s.activeUsersMutex.Unlock()
if s.activeUserIDs[userID] {
delete(s.activeUserIDs, userID)
activeUsers.Dec()
}
}
// RecordCacheHit records a cache hit
func (s *MetricsService) RecordCacheHit(cache string) {
cacheHits.With(prometheus.Labels{
"cache": cache,
}).Inc()
}
// RecordCacheMiss records a cache miss
func (s *MetricsService) RecordCacheMiss(cache string) {
cacheMisses.With(prometheus.Labels{
"cache": cache,
}).Inc()
}
// Middleware creates a middleware that records request metrics
func (s *MetricsService) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Create response writer that captures status code
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
// Call next handler
next.ServeHTTP(rw, r)
// Record metrics
s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start))
})
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
status int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}

View file

@ -1,353 +0,0 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/golang-jwt/jwt"
)
// Response represents a standard API response
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// ErrorResponse sends a JSON error response
func ErrorResponse(w http.ResponseWriter, code int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(Response{
Code: code,
Message: message,
})
}
// SuccessResponse sends a JSON success response
func SuccessResponse(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(Response{
Code: http.StatusOK,
Message: "success",
Data: data,
})
}
// Logger is a middleware that logs request details with structured format
func Logger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
r.Header.Set("X-Request-ID", requestID)
}
// Create a custom response writer to capture status code
rw := &responseWriter{
ResponseWriter: w,
status: http.StatusOK,
}
// Process request
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID)))
// Log structured request details
log.Printf(
"method=%s path=%s status=%d duration=%s ip=%s request_id=%s",
r.Method,
r.URL.Path,
rw.status,
time.Since(start).String(),
r.RemoteAddr,
requestID,
)
})
}
// generateRequestID creates a unique request identifier
func generateRequestID() string {
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
}
// randomString generates a random string of given length
func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return string(b)
}
// Recover is a middleware that recovers from panics with detailed logging
func Recover(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
// Get request ID from context
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Log error with stack trace and request context
log.Printf(
"panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s",
err,
requestID,
r.Method,
r.URL.Path,
r.RemoteAddr,
debug.Stack(),
)
// Determine error type
var message string
var status int
switch e := err.(type) {
case error:
message = e.Error()
if isClientError(e) {
status = http.StatusBadRequest
} else {
status = http.StatusInternalServerError
}
case string:
message = e
status = http.StatusInternalServerError
default:
message = "Internal Server Error"
status = http.StatusInternalServerError
}
ErrorResponse(w, status, message)
}
}()
next.ServeHTTP(w, r)
})
}
// isClientError checks if error should be treated as client error
func isClientError(err error) bool {
// Add more client error types as needed
return strings.Contains(err.Error(), "validation") ||
strings.Contains(err.Error(), "invalid") ||
strings.Contains(err.Error(), "missing")
}
// CORS is a middleware that handles CORS
func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
// Timeout is a middleware that safely handles request timeouts
func Timeout(duration time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), duration)
defer cancel()
// Use buffered channels to prevent goroutine leaks
done := make(chan struct{}, 1)
panicChan := make(chan interface{}, 1)
// Track request processing in goroutine
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
next.ServeHTTP(w, r.WithContext(ctx))
done <- struct{}{}
}()
// Wait for completion, timeout or panic
select {
case <-done:
return
case p := <-panicChan:
panic(p) // Re-throw panic to be caught by Recover middleware
case <-ctx.Done():
// Get request context for logging
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Log timeout details
log.Printf(
"request_timeout: request_id=%s method=%s path=%s timeout=%s",
requestID,
r.Method,
r.URL.Path,
duration.String(),
)
// Send timeout response
ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf(
"Request timed out after %s", duration.String(),
))
}
})
}
}
// Auth is a middleware that validates JWT tokens with enhanced security
func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get request ID for logging
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Get token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required")
return
}
// Validate header format
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer <token>'")
return
}
tokenString := parts[1]
// Parse and validate token
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Validate signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(jwtSecret), nil
})
if err != nil {
log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
return
}
if !token.Valid {
log.Printf("auth_failed: request_id=%s error=invalid_token", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
return
}
// Validate claims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims")
return
}
// Check required claims
userID, ok := claims["user_id"].(string)
if !ok || userID == "" {
log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token")
return
}
// Check token expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
log.Printf("auth_failed: request_id=%s error=token_expired", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Token has expired")
return
}
}
// Check required roles if specified
if len(requiredRoles) > 0 {
roles, ok := claims["roles"].([]interface{})
if !ok {
log.Printf("auth_failed: request_id=%s error=missing_roles", requestID)
ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles")
return
}
hasRequiredRole := false
for _, requiredRole := range requiredRoles {
for _, role := range roles {
if r, ok := role.(string); ok && r == requiredRole {
hasRequiredRole = true
break
}
}
}
if !hasRequiredRole {
log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID)
ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions")
return
}
}
// Add claims to context
ctx := r.Context()
ctx = context.WithValue(ctx, "user_id", userID)
ctx = context.WithValue(ctx, "claims", claims)
log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// responseWriter is a custom response writer that captures the status code
type responseWriter struct {
http.ResponseWriter
status int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}
// GetUserID gets the user ID from the request context
func GetUserID(r *http.Request) (string, error) {
userID, ok := r.Context().Value("user_id").(string)
if !ok {
return "", fmt.Errorf("user ID not found in context")
}
return userID, nil
}

View file

@ -1,50 +0,0 @@
// app.js
App({
onLaunch() {
// 检查更新
if (wx.canIUse('getUpdateManager')) {
const updateManager = wx.getUpdateManager();
updateManager.onCheckForUpdate(function (res) {
if (res.hasUpdate) {
updateManager.onUpdateReady(function () {
wx.showModal({
title: '更新提示',
content: '新版本已经准备好,是否重启应用?',
success: function (res) {
if (res.confirm) {
updateManager.applyUpdate();
}
}
});
});
updateManager.onUpdateFailed(function () {
wx.showModal({
title: '更新提示',
content: '新版本下载失败,请检查网络后重试',
showCancel: false
});
});
}
});
}
// 获取系统信息
try {
const systemInfo = wx.getSystemInfoSync();
this.globalData.systemInfo = systemInfo;
// 计算安全区域
const { screenHeight, safeArea } = systemInfo;
this.globalData.safeAreaBottom = screenHeight - safeArea.bottom;
} catch (e) {
console.error('获取系统信息失败', e);
}
},
globalData: {
userInfo: null,
systemInfo: {},
safeAreaBottom: 0
}
});

View file

@ -1,23 +0,0 @@
{
"pages": [
"pages/login/login",
"pages/otp-list/index",
"pages/otp-add/index"
],
"window": {
"backgroundTextStyle": "light",
"navigationBarBackgroundColor": "#fff",
"navigationBarTitleText": "OTPM",
"navigationBarTextStyle": "black",
"backgroundColor": "#F8F8F8"
},
"permission": {
"scope.camera": {
"desc": "需要使用相机扫描二维码"
}
},
"usingComponents": {},
"style": "v2",
"sitemapLocation": "sitemap.json",
"lazyCodeLoading": "requiredComponents"
}

View file

@ -1,238 +0,0 @@
/**app.wxss**/
page {
--primary-color: #1890ff;
--danger-color: #ff4d4f;
--success-color: #52c41a;
--warning-color: #faad14;
--text-color: #333333;
--text-color-secondary: #666666;
--text-color-light: #999999;
--border-color: #e8e8e8;
--background-color: #f8f8f8;
--border-radius: 8rpx;
--safe-area-bottom: env(safe-area-inset-bottom);
font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica,
Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei',
sans-serif;
font-size: 28rpx;
line-height: 1.5;
color: var(--text-color);
background-color: var(--background-color);
}
/* 清除默认样式 */
button {
padding: 0;
margin: 0;
background: none;
border: none;
text-align: left;
line-height: inherit;
overflow: visible;
}
button::after {
border: none;
}
/* 通用样式类 */
.container {
min-height: 100vh;
box-sizing: border-box;
}
.safe-area-bottom {
padding-bottom: var(--safe-area-bottom);
}
.flex-center {
display: flex;
align-items: center;
justify-content: center;
}
.flex-between {
display: flex;
align-items: center;
justify-content: space-between;
}
.flex-column {
display: flex;
flex-direction: column;
}
.text-primary {
color: var(--primary-color);
}
.text-danger {
color: var(--danger-color);
}
.text-success {
color: var(--success-color);
}
.text-warning {
color: var(--warning-color);
}
.text-secondary {
color: var(--text-color-secondary);
}
.text-light {
color: var(--text-color-light);
}
.text-center {
text-align: center;
}
.text-left {
text-align: left;
}
.text-right {
text-align: right;
}
.text-bold {
font-weight: bold;
}
.text-ellipsis {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.bg-white {
background-color: #ffffff;
}
.bg-primary {
background-color: var(--primary-color);
}
.bg-danger {
background-color: var(--danger-color);
}
.bg-success {
background-color: var(--success-color);
}
.bg-warning {
background-color: var(--warning-color);
}
.shadow {
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.rounded {
border-radius: var(--border-radius);
}
.border {
border: 2rpx solid var(--border-color);
}
.border-top {
border-top: 2rpx solid var(--border-color);
}
.border-bottom {
border-bottom: 2rpx solid var(--border-color);
}
/* 动画类 */
.fade-in {
animation: fadeIn 0.3s ease-in-out;
}
.fade-out {
animation: fadeOut 0.3s ease-in-out;
}
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
@keyframes fadeOut {
from {
opacity: 1;
}
to {
opacity: 0;
}
}
/* 间距类 */
.m-0 { margin: 0; }
.m-1 { margin: 10rpx; }
.m-2 { margin: 20rpx; }
.m-3 { margin: 30rpx; }
.m-4 { margin: 40rpx; }
.mt-0 { margin-top: 0; }
.mt-1 { margin-top: 10rpx; }
.mt-2 { margin-top: 20rpx; }
.mt-3 { margin-top: 30rpx; }
.mt-4 { margin-top: 40rpx; }
.mb-0 { margin-bottom: 0; }
.mb-1 { margin-bottom: 10rpx; }
.mb-2 { margin-bottom: 20rpx; }
.mb-3 { margin-bottom: 30rpx; }
.mb-4 { margin-bottom: 40rpx; }
.ml-0 { margin-left: 0; }
.ml-1 { margin-left: 10rpx; }
.ml-2 { margin-left: 20rpx; }
.ml-3 { margin-left: 30rpx; }
.ml-4 { margin-left: 40rpx; }
.mr-0 { margin-right: 0; }
.mr-1 { margin-right: 10rpx; }
.mr-2 { margin-right: 20rpx; }
.mr-3 { margin-right: 30rpx; }
.mr-4 { margin-right: 40rpx; }
.p-0 { padding: 0; }
.p-1 { padding: 10rpx; }
.p-2 { padding: 20rpx; }
.p-3 { padding: 30rpx; }
.p-4 { padding: 40rpx; }
.pt-0 { padding-top: 0; }
.pt-1 { padding-top: 10rpx; }
.pt-2 { padding-top: 20rpx; }
.pt-3 { padding-top: 30rpx; }
.pt-4 { padding-top: 40rpx; }
.pb-0 { padding-bottom: 0; }
.pb-1 { padding-bottom: 10rpx; }
.pb-2 { padding-bottom: 20rpx; }
.pb-3 { padding-bottom: 30rpx; }
.pb-4 { padding-bottom: 40rpx; }
.pl-0 { padding-left: 0; }
.pl-1 { padding-left: 10rpx; }
.pl-2 { padding-left: 20rpx; }
.pl-3 { padding-left: 30rpx; }
.pl-4 { padding-left: 40rpx; }
.pr-0 { padding-right: 0; }
.pr-1 { padding-right: 10rpx; }
.pr-2 { padding-right: 20rpx; }
.pr-3 { padding-right: 30rpx; }
.pr-4 { padding-right: 40rpx; }

View file

@ -1,48 +0,0 @@
// login.js
import { wxLogin } from '../../services/auth';
Page({
data: {
loading: false
},
onLoad() {
// 页面加载时检查是否已经登录
const token = wx.getStorageSync('token');
if (token) {
this.redirectToHome();
}
},
// 处理登录按钮点击
handleLogin() {
if (this.data.loading) return;
this.setData({ loading: true });
wxLogin()
.then(() => {
wx.showToast({
title: '登录成功',
icon: 'success'
});
this.redirectToHome();
})
.catch(err => {
wx.showToast({
title: err.message || '登录失败',
icon: 'none'
});
})
.finally(() => {
this.setData({ loading: false });
});
},
// 跳转到首页
redirectToHome() {
wx.reLaunch({
url: '/pages/otp-list/index'
});
}
});

View file

@ -1,3 +0,0 @@
{
"usingComponents": {}
}

View file

@ -1,30 +0,0 @@
<!-- login.wxml -->
<view class="container">
<view class="logo-container">
<image class="logo" src="/assets/images/logo.png" mode="aspectFit"></image>
<text class="app-name">OTPM 小程序</text>
</view>
<view class="login-container">
<text class="login-title">欢迎使用 OTPM</text>
<text class="login-subtitle">一次性密码管理工具</text>
<button
class="login-button {{loading ? 'loading' : ''}}"
type="primary"
bindtap="handleLogin"
disabled="{{loading}}"
>
<text wx:if="{{!loading}}">微信一键登录</text>
<view wx:else class="loading-container">
<view class="loading-icon"></view>
<text>登录中...</text>
</view>
</button>
<view class="privacy-policy">
<text>登录即表示您同意</text>
<navigator url="/pages/privacy/index" class="policy-link">《隐私政策》</navigator>
</view>
</view>
</view>

View file

@ -1,97 +0,0 @@
/* login.wxss */
.container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: space-between;
height: 100vh;
padding: 60rpx 40rpx;
box-sizing: border-box;
background-color: #f8f8f8;
}
.logo-container {
display: flex;
flex-direction: column;
align-items: center;
margin-top: 80rpx;
}
.logo {
width: 180rpx;
height: 180rpx;
margin-bottom: 20rpx;
}
.app-name {
font-size: 36rpx;
font-weight: bold;
color: #333;
}
.login-container {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
margin-bottom: 100rpx;
}
.login-title {
font-size: 48rpx;
font-weight: bold;
color: #333;
margin-bottom: 20rpx;
}
.login-subtitle {
font-size: 28rpx;
color: #666;
margin-bottom: 80rpx;
}
.login-button {
width: 80%;
height: 88rpx;
border-radius: 44rpx;
font-size: 32rpx;
display: flex;
align-items: center;
justify-content: center;
}
.login-button.loading {
background-color: #8cc4ff;
}
.loading-container {
display: flex;
align-items: center;
justify-content: center;
}
.loading-icon {
width: 36rpx;
height: 36rpx;
margin-right: 10rpx;
border: 4rpx solid #ffffff;
border-radius: 50%;
border-top-color: transparent;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.privacy-policy {
margin-top: 40rpx;
font-size: 24rpx;
color: #999;
}
.policy-link {
color: #1890ff;
display: inline;
}

View file

@ -1,169 +0,0 @@
// otp-add/index.js
import { createOTP } from '../../services/otp';
Page({
data: {
form: {
name: '',
issuer: '',
secret: '',
algorithm: 'SHA1',
digits: 6,
period: 30
},
algorithms: ['SHA1', 'SHA256', 'SHA512'],
digitOptions: [6, 8],
periodOptions: [30, 60],
submitting: false,
scanMode: false
},
// 处理输入变化
handleInputChange(e) {
const { field } = e.currentTarget.dataset;
const { value } = e.detail;
this.setData({
[`form.${field}`]: value
});
},
// 处理选择器变化
handlePickerChange(e) {
const { field } = e.currentTarget.dataset;
const { value } = e.detail;
const options = this.data[`${field}Options`] || this.data[field];
const selectedValue = options[value];
this.setData({
[`form.${field}`]: selectedValue
});
},
// 扫描二维码
handleScanQRCode() {
this.setData({ scanMode: true });
wx.scanCode({
scanType: ['qrCode'],
success: (res) => {
try {
// 解析otpauth://协议的URL
const url = res.result;
if (url.startsWith('otpauth://totp/')) {
const parsedUrl = new URL(url);
const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠
// 解析路径中的issuer和name
let issuer = '';
let name = path;
if (path.includes(':')) {
const parts = path.split(':');
issuer = parts[0];
name = parts[1];
}
// 从查询参数中获取其他信息
const secret = parsedUrl.searchParams.get('secret') || '';
const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1';
const digits = parseInt(parsedUrl.searchParams.get('digits') || '6');
const period = parseInt(parsedUrl.searchParams.get('period') || '30');
// 如果查询参数中有issuer优先使用
if (parsedUrl.searchParams.get('issuer')) {
issuer = parsedUrl.searchParams.get('issuer');
}
this.setData({
form: {
name,
issuer,
secret,
algorithm,
digits,
period
}
});
wx.showToast({
title: '二维码解析成功',
icon: 'success'
});
} else {
wx.showToast({
title: '不支持的二维码格式',
icon: 'none'
});
}
} catch (err) {
wx.showToast({
title: '二维码解析失败',
icon: 'none'
});
}
},
fail: () => {
wx.showToast({
title: '扫描取消',
icon: 'none'
});
},
complete: () => {
this.setData({ scanMode: false });
}
});
},
// 提交表单
handleSubmit() {
const { form } = this.data;
// 表单验证
if (!form.name) {
wx.showToast({
title: '请输入名称',
icon: 'none'
});
return;
}
if (!form.secret) {
wx.showToast({
title: '请输入密钥',
icon: 'none'
});
return;
}
this.setData({ submitting: true });
createOTP(form)
.then(() => {
wx.showToast({
title: '添加成功',
icon: 'success'
});
// 返回上一页
setTimeout(() => {
wx.navigateBack();
}, 1500);
})
.catch(err => {
wx.showToast({
title: err.message || '添加失败',
icon: 'none'
});
})
.finally(() => {
this.setData({ submitting: false });
});
},
// 取消
handleCancel() {
wx.navigateBack();
}
});

View file

@ -1,3 +0,0 @@
{
"usingComponents": {}
}

View file

@ -1,119 +0,0 @@
<!-- otp-add/index.wxml -->
<view class="container">
<view class="header">
<text class="title">添加OTP</text>
</view>
<view class="form-container">
<view class="form-group">
<text class="form-label">名称 <text class="required">*</text></text>
<input
class="form-input"
placeholder="请输入OTP名称"
value="{{form.name}}"
bindinput="handleInputChange"
data-field="name"
/>
</view>
<view class="form-group">
<text class="form-label">发行方</text>
<input
class="form-input"
placeholder="请输入发行方名称"
value="{{form.issuer}}"
bindinput="handleInputChange"
data-field="issuer"
/>
</view>
<view class="form-group">
<text class="form-label">密钥 <text class="required">*</text></text>
<view class="secret-input-container">
<input
class="form-input"
placeholder="请输入密钥或扫描二维码"
value="{{form.secret}}"
bindinput="handleInputChange"
data-field="secret"
/>
<view class="scan-button" bindtap="handleScanQRCode" wx:if="{{!scanMode}}">
<text class="scan-icon">🔍</text>
</view>
<view class="scanning-indicator" wx:else>
<view class="scanning-spinner"></view>
</view>
</view>
</view>
<view class="form-group">
<text class="form-label">算法</text>
<picker
mode="selector"
range="{{algorithms}}"
value="{{algorithms.indexOf(form.algorithm)}}"
bindchange="handlePickerChange"
data-field="algorithm"
>
<view class="picker-view">
<text>{{form.algorithm}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
<view class="form-row">
<view class="form-group half">
<text class="form-label">位数</text>
<picker
mode="selector"
range="{{digitOptions}}"
value="{{digitOptions.indexOf(form.digits)}}"
bindchange="handlePickerChange"
data-field="digits"
>
<view class="picker-view">
<text>{{form.digits}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
<view class="form-group half">
<text class="form-label">周期(秒)</text>
<picker
mode="selector"
range="{{periodOptions}}"
value="{{periodOptions.indexOf(form.period)}}"
bindchange="handlePickerChange"
data-field="period"
>
<view class="picker-view">
<text>{{form.period}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
</view>
</view>
<view class="button-group">
<button
class="cancel-button"
bindtap="handleCancel"
disabled="{{submitting}}"
>取消</button>
<button
class="submit-button {{submitting ? 'loading' : ''}}"
bindtap="handleSubmit"
disabled="{{submitting}}"
>
<text wx:if="{{!submitting}}">保存</text>
<view wx:else class="loading-container">
<view class="loading-icon"></view>
<text>保存中...</text>
</view>
</button>
</view>
</view>

View file

@ -1,176 +0,0 @@
/* otp-add/index.wxss */
.container {
min-height: 100vh;
background-color: #f8f8f8;
padding: 0 0 40rpx 0;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 40rpx 32rpx;
background-color: #ffffff;
position: sticky;
top: 0;
z-index: 100;
box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
}
.title {
font-size: 36rpx;
font-weight: bold;
color: #333333;
}
.form-container {
background-color: #ffffff;
padding: 32rpx;
margin: 32rpx;
border-radius: 16rpx;
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.form-group {
margin-bottom: 32rpx;
}
.form-row {
display: flex;
justify-content: space-between;
}
.form-group.half {
width: 48%;
}
.form-label {
display: block;
font-size: 28rpx;
color: #666666;
margin-bottom: 12rpx;
}
.required {
color: #ff4d4f;
}
.form-input {
width: 100%;
height: 80rpx;
background-color: #f5f5f5;
border-radius: 8rpx;
padding: 0 24rpx;
font-size: 28rpx;
color: #333333;
box-sizing: border-box;
}
.secret-input-container {
position: relative;
}
.scan-button {
position: absolute;
right: 20rpx;
top: 50%;
transform: translateY(-50%);
width: 60rpx;
height: 60rpx;
display: flex;
align-items: center;
justify-content: center;
}
.scan-icon {
font-size: 40rpx;
color: #1890ff;
}
.scanning-indicator {
position: absolute;
right: 20rpx;
top: 50%;
transform: translateY(-50%);
width: 40rpx;
height: 40rpx;
}
.scanning-spinner {
width: 40rpx;
height: 40rpx;
border: 4rpx solid #f3f3f3;
border-top: 4rpx solid #1890ff;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.picker-view {
width: 100%;
height: 80rpx;
background-color: #f5f5f5;
border-radius: 8rpx;
padding: 0 24rpx;
font-size: 28rpx;
color: #333333;
display: flex;
align-items: center;
justify-content: space-between;
box-sizing: border-box;
}
.picker-arrow {
font-size: 24rpx;
color: #999999;
}
.button-group {
display: flex;
justify-content: space-between;
padding: 32rpx;
}
.cancel-button, .submit-button {
width: 48%;
height: 88rpx;
border-radius: 44rpx;
font-size: 32rpx;
display: flex;
align-items: center;
justify-content: center;
}
.cancel-button {
background-color: #f5f5f5;
color: #666666;
}
.submit-button {
background-color: #1890ff;
color: #ffffff;
}
.submit-button.loading {
background-color: #8cc4ff;
}
.loading-container {
display: flex;
align-items: center;
justify-content: center;
}
.loading-icon {
width: 36rpx;
height: 36rpx;
margin-right: 10rpx;
border: 4rpx solid #ffffff;
border-radius: 50%;
border-top-color: transparent;
animation: spin 1s linear infinite;
}

View file

@ -1,213 +0,0 @@
// otp-list/index.js
import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp';
import { checkLoginStatus } from '../../services/auth';
Page({
data: {
otpList: [],
loading: true,
refreshing: false
},
onLoad() {
this.checkLogin();
},
onShow() {
// 每次页面显示时刷新OTP列表
if (!this.data.loading) {
this.fetchOTPList();
}
},
// 下拉刷新
onPullDownRefresh() {
this.setData({ refreshing: true });
this.fetchOTPList().finally(() => {
wx.stopPullDownRefresh();
this.setData({ refreshing: false });
});
},
// 检查登录状态
checkLogin() {
checkLoginStatus().then(isLoggedIn => {
if (isLoggedIn) {
this.fetchOTPList();
} else {
wx.redirectTo({
url: '/pages/login/login'
});
}
});
},
// 获取OTP列表
fetchOTPList() {
this.setData({ loading: true });
return getOTPList()
.then(res => {
if (res.data && Array.isArray(res.data)) {
this.setData({
otpList: res.data,
loading: false
});
// 获取每个OTP的当前验证码
this.refreshOTPCodes();
}
})
.catch(err => {
wx.showToast({
title: '获取OTP列表失败',
icon: 'none'
});
this.setData({ loading: false });
});
},
// 刷新所有OTP的验证码
refreshOTPCodes() {
const { otpList } = this.data;
// 为每个OTP获取当前验证码
const promises = otpList.map(otp => {
return getOTPCode(otp.id)
.then(res => {
if (res.data && res.data.code) {
return {
id: otp.id,
code: res.data.code,
expiresIn: res.data.expires_in || 30
};
}
return null;
})
.catch(() => null);
});
Promise.all(promises).then(results => {
const updatedList = [...this.data.otpList];
results.forEach(result => {
if (result) {
const index = updatedList.findIndex(otp => otp.id === result.id);
if (index !== -1) {
updatedList[index] = {
...updatedList[index],
currentCode: result.code,
expiresIn: result.expiresIn
};
}
}
});
this.setData({ otpList: updatedList });
// 设置定时器,每秒更新倒计时
this.startCountdown();
});
},
// 开始倒计时
startCountdown() {
// 清除之前的定时器
if (this.countdownTimer) {
clearInterval(this.countdownTimer);
}
// 创建新的定时器,每秒更新一次
this.countdownTimer = setInterval(() => {
const { otpList } = this.data;
let needRefresh = false;
const updatedList = otpList.map(otp => {
if (!otp.countdown) {
otp.countdown = otp.expiresIn || 30;
}
otp.countdown -= 1;
// 如果倒计时结束,标记需要刷新
if (otp.countdown <= 0) {
needRefresh = true;
}
return otp;
});
this.setData({ otpList: updatedList });
// 如果有OTP需要刷新重新获取验证码
if (needRefresh) {
this.refreshOTPCodes();
}
}, 1000);
},
// 添加新的OTP
handleAddOTP() {
wx.navigateTo({
url: '/pages/otp-add/index'
});
},
// 编辑OTP
handleEditOTP(e) {
const { id } = e.currentTarget.dataset;
wx.navigateTo({
url: `/pages/otp-edit/index?id=${id}`
});
},
// 删除OTP
handleDeleteOTP(e) {
const { id, name } = e.currentTarget.dataset;
wx.showModal({
title: '确认删除',
content: `确定要删除 ${name} 吗?`,
confirmColor: '#ff4d4f',
success: (res) => {
if (res.confirm) {
deleteOTP(id)
.then(() => {
wx.showToast({
title: '删除成功',
icon: 'success'
});
this.fetchOTPList();
})
.catch(err => {
wx.showToast({
title: '删除失败',
icon: 'none'
});
});
}
}
});
},
// 复制验证码
handleCopyCode(e) {
const { code } = e.currentTarget.dataset;
wx.setClipboardData({
data: code,
success: () => {
wx.showToast({
title: '验证码已复制',
icon: 'success'
});
}
});
},
onUnload() {
// 页面卸载时清除定时器
if (this.countdownTimer) {
clearInterval(this.countdownTimer);
}
}
});

View file

@ -1,3 +0,0 @@
{
"usingComponents": {}
}

View file

@ -1,59 +0,0 @@
<!-- otp-list/index.wxml -->
<view class="container">
<view class="header">
<text class="title">我的OTP列表</text>
<view class="add-button" bindtap="handleAddOTP">
<text class="add-icon">+</text>
</view>
</view>
<!-- 加载中 -->
<view class="loading-container" wx:if="{{loading}}">
<view class="loading-spinner"></view>
<text class="loading-text">加载中...</text>
</view>
<!-- OTP列表 -->
<view class="otp-list" wx:else>
<block wx:if="{{otpList.length > 0}}">
<view class="otp-item" wx:for="{{otpList}}" wx:key="id">
<view class="otp-info">
<view class="otp-name-row">
<text class="otp-name">{{item.name}}</text>
<text class="otp-issuer">{{item.issuer}}</text>
</view>
<view class="otp-code-row" bindtap="handleCopyCode" data-code="{{item.currentCode}}">
<text class="otp-code">{{item.currentCode || '******'}}</text>
<text class="copy-hint">点击复制</text>
</view>
<view class="otp-countdown">
<progress
percent="{{(item.countdown / item.expiresIn) * 100}}"
stroke-width="3"
activeColor="{{item.countdown < 10 ? '#ff4d4f' : '#1890ff'}}"
backgroundColor="#e9e9e9"
/>
<text class="countdown-text">{{item.countdown || 0}}s</text>
</view>
</view>
<view class="otp-actions">
<view class="action-button edit" bindtap="handleEditOTP" data-id="{{item.id}}">
<text class="action-icon">✎</text>
</view>
<view class="action-button delete" bindtap="handleDeleteOTP" data-id="{{item.id}}" data-name="{{item.name}}">
<text class="action-icon">✕</text>
</view>
</view>
</view>
</block>
<!-- 空状态 -->
<view class="empty-state" wx:else>
<image class="empty-image" src="/assets/images/empty.png" mode="aspectFit"></image>
<text class="empty-text">暂无OTP点击右上角添加</text>
</view>
</view>
</view>

View file

@ -1,201 +0,0 @@
/* otp-list/index.wxss */
.container {
min-height: 100vh;
background-color: #f8f8f8;
padding: 0 0 40rpx 0;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 40rpx 32rpx;
background-color: #ffffff;
position: sticky;
top: 0;
z-index: 100;
box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
}
.title {
font-size: 36rpx;
font-weight: bold;
color: #333333;
}
.add-button {
width: 64rpx;
height: 64rpx;
border-radius: 32rpx;
background-color: #1890ff;
display: flex;
align-items: center;
justify-content: center;
box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3);
}
.add-icon {
color: #ffffff;
font-size: 40rpx;
line-height: 1;
}
/* 加载状态 */
.loading-container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 120rpx 0;
}
.loading-spinner {
width: 64rpx;
height: 64rpx;
border: 6rpx solid #f3f3f3;
border-top: 6rpx solid #1890ff;
border-radius: 50%;
animation: spin 1s linear infinite;
}
.loading-text {
margin-top: 20rpx;
font-size: 28rpx;
color: #999999;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
/* OTP列表 */
.otp-list {
padding: 20rpx 32rpx;
}
.otp-item {
background-color: #ffffff;
border-radius: 16rpx;
padding: 32rpx;
margin-bottom: 20rpx;
display: flex;
justify-content: space-between;
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.otp-info {
flex: 1;
margin-right: 20rpx;
}
.otp-name-row {
display: flex;
align-items: center;
margin-bottom: 16rpx;
}
.otp-name {
font-size: 32rpx;
font-weight: bold;
color: #333333;
margin-right: 16rpx;
}
.otp-issuer {
font-size: 24rpx;
color: #666666;
background-color: #f5f5f5;
padding: 4rpx 12rpx;
border-radius: 8rpx;
}
.otp-code-row {
display: flex;
align-items: center;
margin-bottom: 20rpx;
}
.otp-code {
font-size: 44rpx;
font-family: monospace;
font-weight: bold;
color: #1890ff;
letter-spacing: 4rpx;
margin-right: 16rpx;
}
.copy-hint {
font-size: 24rpx;
color: #999999;
}
.otp-countdown {
position: relative;
width: 100%;
}
.countdown-text {
position: absolute;
right: 0;
top: -30rpx;
font-size: 24rpx;
color: #999999;
}
/* OTP操作按钮 */
.otp-actions {
display: flex;
flex-direction: column;
justify-content: space-between;
}
.action-button {
width: 56rpx;
height: 56rpx;
border-radius: 28rpx;
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 16rpx;
}
.action-button.edit {
background-color: #f0f7ff;
}
.action-button.delete {
background-color: #fff1f0;
}
.action-icon {
font-size: 32rpx;
}
.edit .action-icon {
color: #1890ff;
}
.delete .action-icon {
color: #ff4d4f;
}
/* 空状态 */
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 120rpx 0;
}
.empty-image {
width: 240rpx;
height: 240rpx;
margin-bottom: 32rpx;
}
.empty-text {
font-size: 28rpx;
color: #999999;
}

View file

@ -1,47 +0,0 @@
{
"description": "项目配置文件",
"packOptions": {
"ignore": [],
"include": []
},
"miniprogramRoot": "",
"compileType": "miniprogram",
"projectname": "OTPM",
"setting": {
"useCompilerPlugins": [
"sass"
],
"babelSetting": {
"ignore": [],
"disablePlugins": [],
"outputPath": ""
},
"es6": true,
"enhance": true,
"minified": true,
"postcss": true,
"minifyWXSS": true,
"minifyWXML": true,
"uglifyFileName": true,
"packNpmManually": false,
"packNpmRelationList": [],
"ignoreUploadUnusedFiles": true,
"compileWorklet": false,
"uploadWithSourceMap": true,
"localPlugins": false,
"disableUseStrict": false,
"condition": false,
"swc": false,
"disableSWC": true
},
"simulatorType": "wechat",
"simulatorPluginLibVersion": {},
"condition": {},
"srcMiniprogramRoot": "",
"appid": "wxb6599459668b6b55",
"libVersion": "2.30.2",
"editorSetting": {
"tabIndent": "insertSpaces",
"tabSize": 2
}
}

View file

@ -1,23 +0,0 @@
{
"libVersion": "3.8.5",
"projectname": "OTPM",
"condition": {},
"setting": {
"urlCheck": true,
"coverView": false,
"lazyloadPlaceholderEnable": false,
"skylineRenderEnable": false,
"preloadBackgroundData": false,
"autoAudits": false,
"useApiHook": true,
"useApiHostProcess": true,
"showShadowRootInWxmlPanel": false,
"useStaticServer": false,
"useLanDebug": false,
"showES6CompileOption": false,
"compileHotReLoad": true,
"checkInvalidKey": true,
"ignoreDevUnusedFiles": true,
"bigPackageSizeSupport": false
}
}

View file

@ -1,84 +0,0 @@
// auth.js - 认证相关服务
import request from '../utils/request';
/**
* 微信登录
* 1. 调用wx.login获取code
* 2. 发送code到服务端换取token和openid
* 3. 保存token和openid到本地存储
*/
export const wxLogin = () => {
return new Promise((resolve, reject) => {
wx.login({
success: (res) => {
if (res.code) {
// 发送code到服务端
request({
url: '/login',
method: 'POST',
data: {
code: res.code
}
}).then(response => {
// 保存token和openid
if (response.data && response.data.token && response.data.openid) {
wx.setStorageSync('token', response.data.token);
wx.setStorageSync('openid', response.data.openid);
resolve(response.data);
} else {
reject(new Error('登录失败,服务器返回数据格式错误'));
}
}).catch(err => {
reject(err);
});
} else {
reject(new Error('登录失败获取code失败: ' + res.errMsg));
}
},
fail: (err) => {
reject(new Error('微信登录失败: ' + err.errMsg));
}
});
});
};
/**
* 检查登录状态
* 1. 检查本地是否有token和openid
* 2. 如果有验证token是否有效
* 3. 如果无效清除本地存储并返回false
*/
export const checkLoginStatus = () => {
return new Promise((resolve, reject) => {
const token = wx.getStorageSync('token');
const openid = wx.getStorageSync('openid');
if (!token || !openid) {
resolve(false);
return;
}
// 验证token有效性
request({
url: '/verify-token',
method: 'POST'
}).then(() => {
resolve(true);
}).catch(() => {
// token无效清除本地存储
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
resolve(false);
});
});
};
/**
* 退出登录
*/
export const logout = () => {
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
return Promise.resolve();
};

View file

@ -1,119 +0,0 @@
// otp.js - OTP相关服务
import request from '../utils/request';
/**
* 创建新的OTP
* @param {Object} params - 创建OTP的参数
* @param {string} params.name - OTP名称
* @param {string} params.issuer - 发行方
* @param {string} params.secret - 密钥
* @param {string} params.algorithm - 算法默认为SHA1
* @param {number} params.digits - 位数默认为6
* @param {number} params.period - 周期默认为30秒
* @returns {Promise} - 返回创建结果
*/
export const createOTP = (params) => {
if (!params || !params.secret) {
return Promise.reject(new Error('缺少必要的参数: secret'));
}
return request({
url: '/otp',
method: 'POST',
data: {
name: params.name || '',
issuer: params.issuer || '',
secret: params.secret,
algorithm: params.algorithm || 'SHA1',
digits: params.digits || 6,
period: params.period || 30
}
}).catch(err => {
console.error('创建OTP失败:', err);
throw new Error('创建OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 获取用户所有OTP列表
* @returns {Promise} - 返回OTP列表
*/
export const getOTPList = () => {
return request({
url: '/otp',
method: 'GET'
}).catch(err => {
console.error('获取OTP列表失败:', err);
throw new Error('获取OTP列表失败: ' + (err.message || '未知错误'));
});
};
/**
* 获取指定OTP的当前验证码
* @param {string} id - OTP的ID
* @returns {Promise} - 返回当前验证码
*/
export const getOTPCode = (id) => {
if (!id) {
return Promise.reject(new Error('缺少必要的参数: id'));
}
return request({
url: `/otp/${id}/code`,
method: 'GET'
}).catch(err => {
console.error('获取OTP代码失败:', err);
throw new Error('获取OTP代码失败: ' + (err.message || '未知错误'));
});
};
/**
* 验证OTP
* @param {string} id - OTP的ID
* @param {string} code - 用户输入的验证码
* @returns {Promise} - 返回验证结果
*/
export const verifyOTP = (id, code) => {
if (!id || !code) {
return Promise.reject(new Error('缺少必要的参数: id或code'));
}
return request({
url: `/otp/${id}/verify`,
method: 'POST',
data: { code }
}).catch(err => {
console.error('验证OTP失败:', err);
throw new Error('验证OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 更新OTP信息
* @param {string} id - OTP的ID
* @param {Object} params - 更新的参数
* @returns {Promise} - 返回更新结果
*/
export const updateOTP = (id, params) => {
if (!id || !params) {
return Promise.reject(new Error('缺少必要的参数: id或params'));
}
return request({
url: `/otp/${id}`,
method: 'PUT',
data: params
}).catch(err => {
console.error('更新OTP失败:', err);
throw new Error('更新OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 删除OTP
* @param {string} id - OTP的ID
* @returns {Promise} - 返回删除结果
*/
export const deleteOTP = (id) => {
return request({
url: `/otp/${id}`,
method: 'DELETE'
});
};

View file

@ -1,58 +0,0 @@
// request.js - 网络请求工具类
const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名
// 请求拦截器
const request = (options) => {
return new Promise((resolve, reject) => {
const token = wx.getStorageSync('token');
const header = {
'Content-Type': 'application/json',
...options.header
};
// 如果有token添加到请求头
if (token) {
header['Authorization'] = `Bearer ${token}`;
}
wx.request({
url: `${BASE_URL}${options.url}`,
method: options.method || 'GET',
data: options.data,
header: header,
success: (res) => {
// 处理业务错误
if (res.data.code !== 0) {
// token过期直接清除并跳转登录
if (res.statusCode === 401) {
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
reject(new Error('登录已过期,请重新登录'));
return;
}
reject(new Error(res.data.message || '请求失败'));
return;
}
resolve(res.data);
},
fail: reject
});
});
};
// 刷新token
const refreshToken = () => {
return request({
url: '/refresh-token',
method: 'POST'
}).then(res => {
if (res.data && res.data.token) {
wx.setStorageSync('token', res.data.token);
return res.data.token;
}
throw new Error('Failed to refresh token');
});
};
export default request;

View file

@ -1,66 +0,0 @@
package models
import (
"context"
"time"
)
// OTP represents a TOTP configuration
type OTP struct {
ID int64 `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id" validate:"required"`
OpenID string `json:"openid" db:"openid" validate:"required"`
Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"`
Secret string `json:"secret" db:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"`
Period int `json:"period" db:"period" validate:"required,min=30,max=60"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// OTPParams represents common OTP parameters used in creation and update
type OTPParams struct {
Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"omitempty,issuer"`
Secret string `json:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"omitempty,min=6,max=8"`
Period int `json:"period" validate:"omitempty,min=30,max=60"`
}
// OTPRepository handles OTP data storage
type OTPRepository struct {
// Add your database connection or ORM here
}
// Create creates a new OTP record
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
// Implement database creation logic
return nil
}
// FindByID finds an OTP by ID and user ID
func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
// Implement database lookup logic
return nil, nil
}
// FindAllByUserID finds all OTPs for a user
func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
// Implement database query logic
return nil, nil
}
// Update updates an existing OTP record
func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
// Implement database update logic
return nil
}
// Delete deletes an OTP record
func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
// Implement database deletion logic
return nil
}

View file

@ -1,114 +0,0 @@
package models
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/jmoiron/sqlx"
)
// User represents a user in the system
type User struct {
ID string `db:"id" json:"id"`
OpenID string `db:"openid" json:"openid"`
SessionKey string `db:"session_key" json:"-"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// UserRepository handles user data operations
type UserRepository struct {
db *sqlx.DB
}
// NewUserRepository creates a new UserRepository
func NewUserRepository(db *sqlx.DB) *UserRepository {
return &UserRepository{db: db}
}
// FindByID finds a user by ID
func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) {
var user User
query := `SELECT * FROM users WHERE id = ?`
err := r.db.GetContext(ctx, &user, query, id)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found: %w", err)
}
return nil, fmt.Errorf("failed to find user: %w", err)
}
return &user, nil
}
// FindByOpenID finds a user by OpenID
func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) {
var user User
query := `SELECT * FROM users WHERE openid = ?`
err := r.db.GetContext(ctx, &user, query, openID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil // User not found, but not an error
}
return nil, fmt.Errorf("failed to find user: %w", err)
}
return &user, nil
}
// Create creates a new user
func (r *UserRepository) Create(ctx context.Context, user *User) error {
query := `
INSERT INTO users (id, openid, session_key, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
`
now := time.Now()
user.CreatedAt = now
user.UpdatedAt = now
_, err := r.db.ExecContext(
ctx,
query,
user.ID,
user.OpenID,
user.SessionKey,
user.CreatedAt,
user.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// Update updates an existing user
func (r *UserRepository) Update(ctx context.Context, user *User) error {
query := `
UPDATE users
SET session_key = ?, updated_at = ?
WHERE id = ?
`
user.UpdatedAt = time.Now()
_, err := r.db.ExecContext(
ctx,
query,
user.SessionKey,
user.UpdatedAt,
user.ID,
)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// Delete deletes a user
func (r *UserRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM users WHERE id = ?`
_, err := r.db.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}

417
otp_api.go Normal file
View file

@ -0,0 +1,417 @@
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
_ "github.com/lib/pq"
)
// 加密密钥(32字节AES-256)
var encryptionKey = []byte("example-key-32-bytes-long!1234") // 实际应用中应从安全配置获取
// encryptTokenSecret 加密令牌密钥
func encryptTokenSecret(secret string) (string, error) {
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
ciphertext := make([]byte, aes.BlockSize+len(secret))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", err
}
stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(secret))
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// decryptTokenSecret 解密令牌密钥
func decryptTokenSecret(encrypted string) (string, error) {
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
if err != nil {
return "", err
}
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
if len(ciphertext) < aes.BlockSize {
return "", fmt.Errorf("ciphertext too short")
}
iv := ciphertext[:aes.BlockSize]
ciphertext = ciphertext[aes.BlockSize:]
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(ciphertext, ciphertext)
return string(ciphertext), nil
}
// SaveRequest 保存请求的数据结构
type SaveRequest struct {
Tokens []TokenData `json:"tokens"`
UserID string `json:"userId"`
Timestamp int64 `json:"timestamp"`
}
// TokenData token数据结构
type TokenData struct {
ID string `json:"id"`
Issuer string `json:"issuer"`
Account string `json:"account"`
Secret string `json:"secret"`
Type string `json:"type"`
Counter int `json:"counter,omitempty"`
Period int `json:"period"`
Digits int `json:"digits"`
Algo string `json:"algo"`
}
// SaveResponse 保存响应的数据结构
type SaveResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Data struct {
ID string `json:"id"`
} `json:"data"`
}
// RecoverResponse 恢复响应的数据结构
type RecoverResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Data struct {
Tokens []TokenData `json:"tokens"`
Timestamp int64 `json:"timestamp"`
} `json:"data"`
}
var db *sql.DB
// InitDB 初始化数据库连接
func InitDB() error {
connStr := "postgres://postgres:postgres@localhost/otp_db?sslmode=disable"
var err error
db, err = sql.Open("postgres", connStr)
if err != nil {
return fmt.Errorf("error opening database: %v", err)
}
if err = db.Ping(); err != nil {
return fmt.Errorf("error connecting to the database: %v", err)
}
return nil
}
// SaveHandler 保存token的接口处理函数
func SaveHandler(w http.ResponseWriter, r *http.Request) {
// 设置CORS头
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
// 处理OPTIONS请求
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 检查请求方法
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 解析请求
var req SaveRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Printf("Error decoding request: %v", err)
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
return
}
// 验证请求数据
if req.UserID == "" {
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
return
}
if len(req.Tokens) == 0 {
sendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
return
}
// 开始数据库事务
tx, err := db.Begin()
if err != nil {
log.Printf("Error starting transaction: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
defer tx.Rollback()
// 删除用户现有的tokens
_, err = tx.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID)
if err != nil {
log.Printf("Error deleting existing tokens: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 插入新的tokens
stmt, err := tx.Prepare(`
INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`)
if err != nil {
log.Printf("Error preparing statement: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
defer stmt.Close()
for _, token := range req.Tokens {
// 加密secret
encryptedSecret, err := encryptTokenSecret(token.Secret)
if err != nil {
log.Printf("Error encrypting token secret: %v", err)
sendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
return
}
_, err = stmt.Exec(
token.ID,
req.UserID,
token.Issuer,
token.Account,
encryptedSecret,
token.Type,
token.Counter,
token.Period,
token.Digits,
token.Algo,
req.Timestamp,
)
if err != nil {
log.Printf("Error inserting token: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
}
// 提交事务
if err = tx.Commit(); err != nil {
log.Printf("Error committing transaction: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 返回成功响应
resp := SaveResponse{
Success: true,
Message: "Tokens saved successfully",
}
resp.Data.ID = req.UserID
sendJSONResponse(w, resp, http.StatusOK)
}
// RecoverHandler 恢复token的接口处理函数
func RecoverHandler(w http.ResponseWriter, r *http.Request) {
// 设置CORS头
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
// 处理OPTIONS请求
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 检查请求方法
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 解析请求体
var req struct {
UserID string `json:"userId"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Printf("Error decoding request: %v", err)
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
return
}
// 验证用户ID
if req.UserID == "" {
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
return
}
// 查询数据库
rows, err := db.Query(`
SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp
FROM tokens
WHERE user_id = $1
ORDER BY timestamp DESC
`, req.UserID)
if err != nil {
log.Printf("Error querying database: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
defer rows.Close()
// 读取查询结果
var tokens []TokenData
var timestamp int64
for rows.Next() {
var token TokenData
var encryptedSecret string
err := rows.Scan(
&token.ID,
&token.Issuer,
&token.Account,
&encryptedSecret,
&token.Type,
&token.Counter,
&token.Period,
&token.Digits,
&token.Algo,
&timestamp,
)
if err != nil {
log.Printf("Error scanning row: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 解密secret
token.Secret, err = decryptTokenSecret(encryptedSecret)
if err != nil {
log.Printf("Error decrypting token secret: %v", err)
sendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
return
}
tokens = append(tokens, token)
}
if err = rows.Err(); err != nil {
log.Printf("Error iterating rows: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 返回响应
resp := RecoverResponse{
Success: true,
Message: "Tokens recovered successfully",
}
resp.Data.Tokens = tokens
resp.Data.Timestamp = timestamp
sendJSONResponse(w, resp, http.StatusOK)
}
// sendErrorResponse 发送错误响应
func sendErrorResponse(w http.ResponseWriter, message string, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(map[string]interface{}{
"success": false,
"message": message,
})
}
// DeleteTokenHandler 删除单个token的接口处理函数
func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
// 设置CORS头
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
// 处理OPTIONS请求
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 检查请求方法
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 解析请求
var req struct {
UserID string `json:"userId"`
TokenID string `json:"tokenId"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Printf("Error decoding request: %v", err)
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
return
}
// 验证请求数据
if req.UserID == "" || req.TokenID == "" {
sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
return
}
// 执行删除操作
result, err := db.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
if err != nil {
log.Printf("Error deleting token: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 检查是否真的删除了记录
rowsAffected, err := result.RowsAffected()
if err != nil {
log.Printf("Error getting rows affected: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
if rowsAffected == 0 {
sendErrorResponse(w, "Token not found", http.StatusNotFound)
return
}
// 返回成功响应
sendJSONResponse(w, map[string]interface{}{
"success": true,
"message": "Token deleted successfully",
}, http.StatusOK)
}
// sendJSONResponse 发送JSON响应
func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Printf("Error encoding response: %v", err)
sendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
}
}

111
otp_api_test.go Normal file
View file

@ -0,0 +1,111 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func TestSaveHandler(t *testing.T) {
// 创建测试服务器
srv := httptest.NewServer(http.HandlerFunc(SaveHandler))
defer srv.Close()
// 准备测试数据
testData := SaveRequest{
UserID: "test_user_123",
Tokens: []TokenData{
{
Issuer: "TestOrg",
Account: "user@test.com",
Secret: "JBSWY3DPEHPK3PXP",
Type: "totp",
Period: 30,
Digits: 6,
Algo: "SHA1",
},
},
}
// 序列化请求体
body, _ := json.Marshal(testData)
// 发送请求
resp, err := http.Post(srv.URL, "application/json", bytes.NewBuffer(body))
if err != nil {
t.Fatalf("Error making request to server: %v\n", err)
}
defer resp.Body.Close()
// 检查响应状态码
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status: %d, got: %d\n", http.StatusOK, resp.StatusCode)
}
// 解析响应
var saveResp SaveResponse
err = json.NewDecoder(resp.Body).Decode(&saveResp)
if err != nil {
t.Errorf("Error decoding response: %v\n", err)
}
// 验证响应数据
if !saveResp.Success {
t.Errorf("Expected success to be true, got false\n")
}
if saveResp.Message != "Tokens saved successfully" {
t.Errorf("Expected message to be 'Tokens saved successfully', got '%s'\n", saveResp.Message)
}
}
func TestRecoverHandler(t *testing.T) {
// 创建测试服务器
srv := httptest.NewServer(http.HandlerFunc(RecoverHandler))
defer srv.Close()
// 发送请求没有user_id参数
resp, err := http.Get(srv.URL)
if err != nil {
t.Fatalf("Error making request to server: %v\n", err)
}
defer resp.Body.Close()
// 检查响应状态码(应该返回错误)
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status: %d, got: %d\n", http.StatusBadRequest, resp.StatusCode)
}
// 发送带user_id的请求
urlWithID := fmt.Sprintf("%s?user_id=test_user_123", srv.URL)
respWithID, err := http.Get(urlWithID)
if err != nil {
t.Fatalf("Error making request to server: %v\n", err)
}
defer respWithID.Body.Close()
// 检查响应状态码
if respWithID.StatusCode != http.StatusOK {
t.Errorf("Expected status: %d, got: %d\n", http.StatusOK, respWithID.StatusCode)
}
// 解析响应
var recoverResp RecoverResponse
err = json.NewDecoder(respWithID.Body).Decode(&recoverResp)
if err != nil {
t.Errorf("Error decoding response: %v\n", err)
}
// 验证响应数据
if !recoverResp.Success {
t.Errorf("Expected success to be true, got false\n")
}
if recoverResp.Message != "Tokens recovered successfully" {
t.Errorf("Expected message to be 'Tokens recovered successfully', got '%s'\n", recoverResp.Message)
}
if len(recoverResp.Tokens) != 1 {
t.Errorf("Expected 1 token, got %d\n", len(recoverResp.Tokens))
}
}

View file

@ -1,332 +0,0 @@
package security
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"net/http"
"strings"
"time"
"golang.org/x/crypto/argon2"
)
// SecurityService provides security functionality
type SecurityService struct {
config *Config
}
// Config represents security configuration
type Config struct {
// CSRF protection
CSRFTokenLength int
CSRFTokenExpiry time.Duration
CSRFCookieName string
CSRFHeaderName string
CSRFCookieSecure bool
CSRFCookieHTTPOnly bool
CSRFCookieSameSite http.SameSite
// Rate limiting
RateLimitRequests int
RateLimitWindow time.Duration
// Password hashing
Argon2Time uint32
Argon2Memory uint32
Argon2Threads uint8
Argon2KeyLen uint32
Argon2SaltLen uint32
}
// DefaultConfig returns the default security configuration
func DefaultConfig() *Config {
return &Config{
// CSRF protection
CSRFTokenLength: 32,
CSRFTokenExpiry: 24 * time.Hour,
CSRFCookieName: "csrf_token",
CSRFHeaderName: "X-CSRF-Token",
CSRFCookieSecure: true,
CSRFCookieHTTPOnly: true,
CSRFCookieSameSite: http.SameSiteStrictMode,
// Rate limiting
RateLimitRequests: 100,
RateLimitWindow: time.Minute,
// Password hashing
Argon2Time: 1,
Argon2Memory: 64 * 1024,
Argon2Threads: 4,
Argon2KeyLen: 32,
Argon2SaltLen: 16,
}
}
// NewSecurityService creates a new SecurityService
func NewSecurityService(config *Config) *SecurityService {
if config == nil {
config = DefaultConfig()
}
return &SecurityService{
config: config,
}
}
// GenerateCSRFToken generates a CSRF token
func (s *SecurityService) GenerateCSRFToken() (string, error) {
// Generate random bytes
bytes := make([]byte, s.config.CSRFTokenLength)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
// Encode as base64
token := base64.StdEncoding.EncodeToString(bytes)
return token, nil
}
// SetCSRFCookie sets a CSRF cookie
func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
http.SetCookie(w, &http.Cookie{
Name: s.config.CSRFCookieName,
Value: token,
Path: "/",
Expires: time.Now().Add(s.config.CSRFTokenExpiry),
Secure: s.config.CSRFCookieSecure,
HttpOnly: s.config.CSRFCookieHTTPOnly,
SameSite: s.config.CSRFCookieSameSite,
})
}
// ValidateCSRFToken validates a CSRF token
func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
// Get token from cookie
cookie, err := r.Cookie(s.config.CSRFCookieName)
if err != nil {
return false
}
cookieToken := cookie.Value
// Get token from header
headerToken := r.Header.Get(s.config.CSRFHeaderName)
// Compare tokens
return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1
}
// CSRFMiddleware creates a middleware that validates CSRF tokens
func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip for GET, HEAD, OPTIONS, TRACE
if r.Method == http.MethodGet ||
r.Method == http.MethodHead ||
r.Method == http.MethodOptions ||
r.Method == http.MethodTrace {
next.ServeHTTP(w, r)
return
}
// Validate CSRF token
if !s.ValidateCSRFToken(r) {
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
// RateLimiter represents a rate limiter
type RateLimiter struct {
requests map[string][]time.Time
config *Config
}
// NewRateLimiter creates a new RateLimiter
func NewRateLimiter(config *Config) *RateLimiter {
return &RateLimiter{
requests: make(map[string][]time.Time),
config: config,
}
}
// Allow checks if a request is allowed
func (r *RateLimiter) Allow(key string) bool {
now := time.Now()
windowStart := now.Add(-r.config.RateLimitWindow)
// Get requests for key
requests := r.requests[key]
// Filter out old requests
var newRequests []time.Time
for _, t := range requests {
if t.After(windowStart) {
newRequests = append(newRequests, t)
}
}
// Check if rate limit is exceeded
if len(newRequests) >= r.config.RateLimitRequests {
return false
}
// Add current request
newRequests = append(newRequests, now)
r.requests[key] = newRequests
return true
}
// RateLimitMiddleware creates a middleware that limits request rate
func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get client IP
ip := getClientIP(r)
// Check if request is allowed
if !limiter.Allow(ip) {
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// getClientIP gets the client IP address
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header
xForwardedFor := r.Header.Get("X-Forwarded-For")
if xForwardedFor != "" {
// X-Forwarded-For can contain multiple IPs, use the first one
ips := strings.Split(xForwardedFor, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
xRealIP := r.Header.Get("X-Real-IP")
if xRealIP != "" {
return xRealIP
}
// Use RemoteAddr
return r.RemoteAddr
}
// HashPassword hashes a password using Argon2
func (s *SecurityService) HashPassword(password string) (string, error) {
// Generate salt
salt := make([]byte, s.config.Argon2SaltLen)
if _, err := rand.Read(salt); err != nil {
return "", fmt.Errorf("failed to generate salt: %w", err)
}
// Hash password
hash := argon2.IDKey(
[]byte(password),
salt,
s.config.Argon2Time,
s.config.Argon2Memory,
s.config.Argon2Threads,
s.config.Argon2KeyLen,
)
// Encode as base64
saltBase64 := base64.StdEncoding.EncodeToString(salt)
hashBase64 := base64.StdEncoding.EncodeToString(hash)
// Format as $argon2id$v=19$m=65536,t=1,p=4$<salt>$<hash>
return fmt.Sprintf(
"$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
s.config.Argon2Memory,
s.config.Argon2Time,
s.config.Argon2Threads,
saltBase64,
hashBase64,
), nil
}
// VerifyPassword verifies a password against a hash
func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
// Parse encoded hash
parts := strings.Split(encodedHash, "$")
if len(parts) != 6 {
return false, fmt.Errorf("invalid hash format")
}
// Extract parameters
if parts[1] != "argon2id" {
return false, fmt.Errorf("unsupported hash algorithm")
}
var memory uint32
var time uint32
var threads uint8
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
if err != nil {
return false, fmt.Errorf("failed to parse hash parameters: %w", err)
}
// Decode salt and hash
salt, err := base64.StdEncoding.DecodeString(parts[4])
if err != nil {
return false, fmt.Errorf("failed to decode salt: %w", err)
}
hash, err := base64.StdEncoding.DecodeString(parts[5])
if err != nil {
return false, fmt.Errorf("failed to decode hash: %w", err)
}
// Hash password with same parameters
newHash := argon2.IDKey(
[]byte(password),
salt,
time,
memory,
threads,
uint32(len(hash)),
)
// Compare hashes
return subtle.ConstantTimeCompare(hash, newHash) == 1, nil
}
// SecureHeadersMiddleware adds security headers to responses
func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add security headers
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Content-Security-Policy", "default-src 'self'")
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
next.ServeHTTP(w, r)
})
}
// contextKey is a type for context keys
type contextKey string
// userIDKey is the key for user ID in context
const userIDKey = contextKey("user_id")
// WithUserID adds a user ID to the context
func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, userIDKey, userID)
}
// GetUserID gets the user ID from the context
func GetUserID(ctx context.Context) (string, bool) {
userID, ok := ctx.Value(userIDKey).(string)
return userID, ok
}

View file

@ -1,190 +0,0 @@
package server
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"otpm/config"
"otpm/middleware"
"github.com/julienschmidt/httprouter"
)
// Server represents the HTTP server
type Server struct {
server *http.Server
router *httprouter.Router
config *config.Config
}
// New creates a new server
func New(cfg *config.Config) *Server {
router := httprouter.New()
server := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
Handler: router,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: 120 * time.Second,
}
return &Server{
server: server,
router: router,
config: cfg,
}
}
// Start starts the server
func (s *Server) Start() error {
// Apply global middleware in correct order with enhanced error handling
var handler http.Handler = s.router
// Logger should be first to capture all request details
handler = middleware.Logger(handler)
// CORS next to handle pre-flight requests
handler = middleware.CORS(handler)
// Then Timeout to enforce request deadlines
handler = middleware.Timeout(s.config.Server.Timeout)(handler)
// Recover should be outermost to catch any panics
handler = middleware.Recover(handler)
s.server.Handler = handler
// Log server configuration at startup
log.Printf("Server configuration:\n"+
"Address: %s\n"+
"Read Timeout: %v\n"+
"Write Timeout: %v\n"+
"Idle Timeout: %v\n"+
"Request Timeout: %v",
s.server.Addr,
s.server.ReadTimeout,
s.server.WriteTimeout,
s.server.IdleTimeout,
s.config.Server.Timeout,
)
// Start server in a goroutine
serverErr := make(chan error, 1)
go func() {
log.Printf("Server starting on %s", s.server.Addr)
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
serverErr <- fmt.Errorf("server error: %w", err)
}
}()
// Wait for interrupt signal or server error
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
select {
case err := <-serverErr:
return err
case <-quit:
return s.Shutdown()
}
}
// Shutdown gracefully stops the server
func (s *Server) Shutdown() error {
log.Println("Shutting down server...")
ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout)
defer cancel()
if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("graceful shutdown failed: %w", err)
}
log.Println("Server stopped gracefully")
return nil
}
// Router returns the router
func (s *Server) Router() *httprouter.Router {
return s.router
}
// RegisterRoutes registers all routes
func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) {
for pattern, handler := range routes {
s.router.Handle("GET", pattern, handler)
s.router.Handle("POST", pattern, handler)
s.router.Handle("PUT", pattern, handler)
s.router.Handle("DELETE", pattern, handler)
}
}
// RegisterAuthRoutes registers routes that require authentication
func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) {
for pattern, handler := range routes {
// Apply authentication middleware
authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Convert httprouter.Handle to http.HandlerFunc for middleware
wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Store params in request context
ctx := context.WithValue(r.Context(), "params", ps)
handler(w, r.WithContext(ctx), ps)
})
// Apply auth middleware
middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r)
}
s.router.Handle("GET", pattern, authHandler)
s.router.Handle("POST", pattern, authHandler)
s.router.Handle("PUT", pattern, authHandler)
s.router.Handle("DELETE", pattern, authHandler)
}
}
// RegisterHealthCheck registers an enhanced health check endpoint
func (s *Server) RegisterHealthCheck() {
s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
response := map[string]interface{}{
"status": "ok",
"timestamp": time.Now().Format(time.RFC3339),
"version": "1.0.0", // Hardcoded version instead of from config
"system": map[string]interface{}{
"goroutines": runtime.NumGoroutine(),
"memory": getMemoryUsage(),
},
}
// Add database status if configured
if s.config.Database.DSN != "" {
dbStatus := "ok"
response["database"] = dbStatus
}
middleware.SuccessResponse(w, response)
})
}
// getMemoryUsage returns current memory usage in MB
func getMemoryUsage() map[string]interface{} {
var m runtime.MemStats
runtime.ReadMemStats(&m)
return map[string]interface{}{
"alloc_mb": bToMb(m.Alloc),
"total_alloc_mb": bToMb(m.TotalAlloc),
"sys_mb": bToMb(m.Sys),
"num_gc": m.NumGC,
}
}
func bToMb(b uint64) float64 {
return float64(b) / 1024 / 1024
}

View file

@ -1,230 +0,0 @@
package services
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"otpm/config"
"otpm/models"
)
// WeChatCode2SessionResponse represents the response from WeChat code2session API
type WeChatCode2SessionResponse struct {
OpenID string `json:"openid"`
SessionKey string `json:"session_key"`
UnionID string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
// AuthService handles authentication related operations
type AuthService struct {
config *config.Config
userRepo *models.UserRepository
httpClient *http.Client
}
// NewAuthService creates a new AuthService
func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
return &AuthService{
config: cfg,
userRepo: userRepo,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// LoginWithWeChatCode handles WeChat login
func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
start := time.Now()
// Get OpenID and SessionKey from WeChat
sessionInfo, err := s.getWeChatSession(code)
if err != nil {
log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
return "", fmt.Errorf("failed to get WeChat session: %w", err)
}
log.Printf("WeChat session obtained for code %s (took %v)",
maskCode(code), time.Since(start))
// Find or create user
user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
if err != nil {
log.Printf("User lookup failed for OpenID %s: %v",
maskOpenID(sessionInfo.OpenID), err)
return "", fmt.Errorf("failed to find user: %w", err)
}
if user == nil {
// Create new user
user = &models.User{
ID: uuid.New().String(),
OpenID: sessionInfo.OpenID,
SessionKey: sessionInfo.SessionKey,
}
if err := s.userRepo.Create(ctx, user); err != nil {
log.Printf("User creation failed for OpenID %s: %v",
maskOpenID(sessionInfo.OpenID), err)
return "", fmt.Errorf("failed to create user: %w", err)
}
log.Printf("New user created with ID %s for OpenID %s",
user.ID, maskOpenID(sessionInfo.OpenID))
} else {
// Update session key
user.SessionKey = sessionInfo.SessionKey
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("User update failed for ID %s: %v", user.ID, err)
return "", fmt.Errorf("failed to update user: %w", err)
}
log.Printf("User %s session key updated", user.ID)
}
// Generate JWT token
token, err := s.generateToken(user)
if err != nil {
log.Printf("Token generation failed for user %s: %v", user.ID, err)
return "", fmt.Errorf("failed to generate token: %w", err)
}
log.Printf("WeChat login completed for user %s (total time %v)",
user.ID, time.Since(start))
return token, nil
}
// maskCode masks sensitive parts of WeChat code for logging
func maskCode(code string) string {
if len(code) < 8 {
return "****"
}
return code[:2] + "****" + code[len(code)-2:]
}
// maskOpenID masks sensitive parts of OpenID for logging
func maskOpenID(openID string) string {
if len(openID) < 8 {
return "****"
}
return openID[:2] + "****" + openID[len(openID)-2:]
}
// getWeChatSession calls WeChat's code2session API
func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
url := fmt.Sprintf(
"https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
s.config.WeChat.AppID,
s.config.WeChat.AppSecret,
code,
)
resp, err := s.httpClient.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to call WeChat API: %w", err)
}
defer resp.Body.Close()
var result WeChatCode2SessionResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
}
if result.ErrCode != 0 {
return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
}
return &result, nil
}
// generateToken generates a JWT token for a user
func (s *AuthService) generateToken(user *models.User) (string, error) {
now := time.Now()
claims := jwt.MapClaims{
"user_id": user.ID,
"exp": now.Add(s.config.JWT.ExpireDelta).Unix(),
"iat": now.Unix(),
"iss": s.config.JWT.Issuer,
"aud": s.config.JWT.Audience,
"token_id": uuid.New().String(), // Unique token ID for tracking
}
// Use stronger signing method
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
log.Printf("Token generated for user %s (expires at %v)",
user.ID, now.Add(s.config.JWT.ExpireDelta))
return signedToken, nil
}
// ValidateToken validates a JWT token with additional checks
func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Verify signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.config.JWT.Secret), nil
})
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
switch {
case ve.Errors&jwt.ValidationErrorMalformed != 0:
return nil, fmt.Errorf("malformed token")
case ve.Errors&jwt.ValidationErrorExpired != 0:
return nil, fmt.Errorf("token expired")
case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
return nil, fmt.Errorf("token not active yet")
default:
return nil, fmt.Errorf("token validation error: %w", err)
}
}
return nil, fmt.Errorf("failed to parse token: %w", err)
}
// Additional claims validation
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// Check issuer
if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
return nil, fmt.Errorf("invalid token issuer")
}
// Check audience
if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
return nil, fmt.Errorf("invalid token audience")
}
} else {
return nil, fmt.Errorf("invalid token claims")
}
return token, nil
}
// GetUserFromToken gets user information from a JWT token
func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid token claims")
}
userID, ok := claims["user_id"].(string)
if !ok {
return nil, fmt.Errorf("user_id not found in token")
}
user, err := s.userRepo.FindByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to find user: %w", err)
}
return user, nil
}

View file

@ -1,358 +0,0 @@
package services
import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base32"
"encoding/binary"
"fmt"
"hash"
"log"
"strings"
"time"
"otpm/models"
"github.com/google/uuid"
)
// OTPService handles OTP related operations
type OTPService struct {
otpRepo *models.OTPRepository
}
// NewOTPService creates a new OTPService
func NewOTPService(otpRepo *models.OTPRepository) *OTPService {
return &OTPService{
otpRepo: otpRepo,
}
}
// CreateOTP creates a new OTP with performance monitoring and logging
func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) {
start := time.Now()
// Validate input
if err := s.validateOTPInput(input); err != nil {
log.Printf("OTP validation failed for user %s: %v", userID, err)
return nil, err
}
// Clean and standardize secret
secret := cleanSecret(input.Secret)
// Set defaults for optional fields
algorithm := strings.ToUpper(input.Algorithm)
if algorithm == "" {
algorithm = "SHA1"
}
digits := input.Digits
if digits == 0 {
digits = 6
}
period := input.Period
if period == 0 {
period = 30
}
// Create OTP
otp := &models.OTP{
ID: uuid.New().String(),
UserID: userID,
Name: input.Name,
Issuer: input.Issuer,
Secret: secret,
Algorithm: algorithm,
Digits: digits,
Period: period,
}
if err := s.otpRepo.Create(ctx, otp); err != nil {
log.Printf("Failed to create OTP for user %s: %v", userID, err)
return nil, fmt.Errorf("failed to create OTP: %w", err)
}
// Log successful creation (without exposing secret)
log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)",
otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period)
return otp, nil
}
// GetOTP gets an OTP by ID
func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) {
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
return nil, fmt.Errorf("failed to get OTP: %w", err)
}
return otp, nil
}
// ListOTPs lists all OTPs for a user
func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) {
otps, err := s.otpRepo.FindAllByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to list OTPs: %w", err)
}
return otps, nil
}
// UpdateOTP updates an OTP
func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) {
// Get existing OTP
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
return nil, fmt.Errorf("failed to get OTP: %w", err)
}
// Update fields
if input.Name != "" {
otp.Name = input.Name
}
if input.Issuer != "" {
otp.Issuer = input.Issuer
}
if input.Algorithm != "" {
otp.Algorithm = strings.ToUpper(input.Algorithm)
}
if input.Digits > 0 {
otp.Digits = input.Digits
}
if input.Period > 0 {
otp.Period = input.Period
}
// Validate updated OTP
if err := s.validateOTPInput(models.OTPParams{
Name: otp.Name,
Issuer: otp.Issuer,
Secret: otp.Secret,
Algorithm: otp.Algorithm,
Digits: otp.Digits,
Period: otp.Period,
}); err != nil {
return nil, err
}
if err := s.otpRepo.Update(ctx, otp); err != nil {
return nil, fmt.Errorf("failed to update OTP: %w", err)
}
return otp, nil
}
// DeleteOTP deletes an OTP
func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error {
if err := s.otpRepo.Delete(ctx, id, userID); err != nil {
return fmt.Errorf("failed to delete OTP: %w", err)
}
return nil
}
// GenerateCode generates a TOTP code with enhanced logging and error handling
func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) {
start := time.Now()
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err)
return "", 0, fmt.Errorf("failed to get OTP: %w", err)
}
// Get current time step
now := time.Now().Unix()
timeStep := now / int64(otp.Period)
// Generate code
code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits)
if err != nil {
log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err)
return "", 0, fmt.Errorf("failed to generate code: %w", err)
}
// Calculate remaining seconds
remainingSeconds := otp.Period - int(now%int64(otp.Period))
// Log successful generation (without actual code)
log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)",
id, userID, time.Since(start), remainingSeconds)
return code, remainingSeconds, nil
}
// VerifyCode verifies a TOTP code with enhanced security and logging
func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) {
start := time.Now()
// Basic input validation
if len(code) == 0 {
log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID)
return false, fmt.Errorf("code is required")
}
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
log.Printf("Failed to find OTP %s for user %s during verification: %v",
id, userID, err)
return false, fmt.Errorf("failed to get OTP: %w", err)
}
// Get current and adjacent time steps
now := time.Now().Unix()
timeSteps := []int64{
(now - int64(otp.Period)) / int64(otp.Period),
now / int64(otp.Period),
(now + int64(otp.Period)) / int64(otp.Period),
}
// Check code against all time steps
for _, ts := range timeSteps {
expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits)
if err != nil {
log.Printf("Code generation failed for time step %d: %v", ts, err)
continue
}
if expectedCode == code {
// Log successful verification
log.Printf("Code verified successfully for OTP %s (user %s) in %v",
id, userID, time.Since(start))
return true, nil
}
}
// Log failed verification attempt
log.Printf("Invalid code provided for OTP %s (user %s) in %v",
id, userID, time.Since(start))
return false, nil
}
// validateOTPInput validates OTP input with detailed error messages
func (s *OTPService) validateOTPInput(input models.OTPParams) error {
if input.Name == "" {
return fmt.Errorf("name is required")
}
if len(input.Name) > 100 {
return fmt.Errorf("name is too long (maximum 100 characters)")
}
if input.Secret == "" {
return fmt.Errorf("secret is required")
}
if !isValidBase32(input.Secret) {
return fmt.Errorf("invalid secret format: must be a valid base32 string")
}
// Secret length check (after base32 decoding)
secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "="))
if len(secretBytes) < 10 {
return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)")
}
if input.Algorithm != "" {
if !isValidAlgorithm(input.Algorithm) {
return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm)
}
}
if input.Digits != 0 {
if input.Digits < 6 || input.Digits > 8 {
return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits)
}
}
if input.Period != 0 {
if input.Period < 30 || input.Period > 60 {
return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period)
}
}
return nil
}
// Helper functions
func cleanSecret(secret string) string {
// Remove spaces and convert to upper case
secret = strings.TrimSpace(strings.ToUpper(secret))
// Remove any padding characters
return strings.TrimRight(secret, "=")
}
func isValidBase32(s string) bool {
// Try to decode the secret
_, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "="))
return err == nil
}
func isValidAlgorithm(algorithm string) bool {
switch strings.ToUpper(algorithm) {
case "SHA1", "SHA256", "SHA512":
return true
default:
return false
}
}
func getHasher(algorithm string, key []byte) (hash.Hash, error) {
switch strings.ToUpper(algorithm) {
case "SHA1":
return hmac.New(sha1.New, key), nil
case "SHA256":
return hmac.New(sha256.New, key), nil
case "SHA512":
return hmac.New(sha512.New, key), nil
default:
return nil, fmt.Errorf("unsupported algorithm: %s", algorithm)
}
}
func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) {
// Decode secret
secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "="))
if err != nil {
return "", fmt.Errorf("invalid secret: %w", err)
}
// Get initialized HMAC hasher with secret
hasher, err := getHasher(algorithm, secretBytes)
if err != nil {
return "", err
}
// Convert time step to bytes
timeBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timeBytes, uint64(timeStep))
// Calculate HMAC
hasher.Write(timeBytes)
hash := hasher.Sum(nil)
// Get offset
offset := hash[len(hash)-1] & 0xf
// Generate 4-byte code
code := binary.BigEndian.Uint32(hash[offset : offset+4])
code = code & 0x7fffffff
// Get the specified number of digits
code = code % uint32(pow10(digits))
// Format code with leading zeros
return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil
}
func pow10(n int) uint32 {
result := uint32(1)
for i := 0; i < n; i++ {
result *= 10
}
return result
}

View file

@ -1,105 +0,0 @@
package utils
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
"net/http"
"strings"
"github.com/golang-jwt/jwt"
"github.com/julienschmidt/httprouter"
"github.com/spf13/viper"
)
// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
// 返回一个httprouter.Handle函数该函数接受http.ResponseWriter和*http.Request作为参数
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// 调用传入的http.Handler函数将http.ResponseWriter和*http.Request作为参数传递
h(w, r)
}
}
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
return
}
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
secret := viper.GetString("auth.secret")
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method")
}
return []byte(secret), nil
})
if err != nil || !token.Valid {
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
return
}
type contextKey string
// 将 openid 存入上下文
ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
next.ServeHTTP(w, r.WithContext(ctx))
}
}
// AesDecrypt 函数用于AES解密
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
//Base64解码
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
if err != nil {
return nil, err
}
ivBytes, err := base64.StdEncoding.DecodeString(iv)
if err != nil {
return nil, err
}
cryptData, err := base64.StdEncoding.DecodeString(encryptedData)
if err != nil {
return nil, err
}
origData := make([]byte, len(cryptData))
//AES
block, err := aes.NewCipher(keyBytes)
if err != nil {
return nil, err
}
//CBC
mode := cipher.NewCBCDecrypter(block, ivBytes)
//解密
mode.CryptBlocks(origData, cryptData)
//去除填充位
origData = PKCS7UnPadding(origData)
return origData, nil
}
// PKCS7UnPadding 函数用于去除PKCS7填充的密文
func PKCS7UnPadding(plantText []byte) []byte {
// 获取密文的长度
length := len(plantText)
// 如果密文长度大于0
if length > 0 {
// 获取最后一个字节的值,即填充的位数
unPadding := int(plantText[length-1])
// 返回去除填充后的密文
return plantText[:(length - unPadding)]
}
// 如果密文长度为0则返回原密文
return plantText
}

View file

@ -1,316 +0,0 @@
package validator
import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
var (
validate *validator.Validate
// 自定义验证规则
customValidations = map[string]validator.Func{
"otpsecret": validateOTPSecret,
"password": validatePassword,
"issuer": validateIssuer,
"otpauth_uri": validateOTPAuthURI,
"no_xss": validateNoXSS,
}
// 常见的弱密码列表(实际使用时应该使用更完整的列表)
commonPasswords = map[string]bool{
"password123": true,
"12345678": true,
"qwerty123": true,
"admin123": true,
"letmein": true,
"welcome": true,
"password": true,
"admin": true,
}
// 预编译的XSS检测正则表达式
xssPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)data:text/html`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
regexp.MustCompile(`(?i)<\s*iframe`),
regexp.MustCompile(`(?i)<\s*object`),
regexp.MustCompile(`(?i)<\s*embed`),
regexp.MustCompile(`(?i)<\s*style`),
regexp.MustCompile(`(?i)<\s*form`),
regexp.MustCompile(`(?i)<\s*applet`),
regexp.MustCompile(`(?i)<\s*meta`),
regexp.MustCompile(`(?i)expression\s*\(`),
regexp.MustCompile(`(?i)url\s*\(`),
}
// 预编译的正则表达式
base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`)
upperRegex = regexp.MustCompile(`[A-Z]`)
lowerRegex = regexp.MustCompile(`[a-z]`)
numberRegex = regexp.MustCompile(`[0-9]`)
specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
)
func init() {
validate = validator.New()
// 注册自定义验证规则
for tag, fn := range customValidations {
if err := validate.RegisterValidation(tag, fn); err != nil {
panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
}
}
// 使用json tag作为字段名
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
}
// ValidateRequest validates a request body against a struct
func ValidateRequest(r *http.Request, v interface{}) error {
if err := json.NewDecoder(r.Body).Decode(v); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
if err := validate.Struct(v); err != nil {
if validationErrors, ok := err.(validator.ValidationErrors); ok {
return NewValidationError(validationErrors)
}
return fmt.Errorf("validation error: %w", err)
}
return nil
}
// ValidationError represents a validation error
type ValidationError struct {
Fields map[string]string `json:"fields"`
}
// Error implements the error interface
func (e *ValidationError) Error() string {
var errors []string
for field, msg := range e.Fields {
errors = append(errors, fmt.Sprintf("%s: %s", field, msg))
}
return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; "))
}
// NewValidationError creates a new ValidationError from validator.ValidationErrors
func NewValidationError(errors validator.ValidationErrors) *ValidationError {
fields := make(map[string]string)
for _, err := range errors {
fields[err.Field()] = getErrorMessage(err)
}
return &ValidationError{Fields: fields}
}
// getErrorMessage returns a human-readable error message for a validation error
func getErrorMessage(err validator.FieldError) string {
switch err.Tag() {
case "required":
return "此字段为必填项"
case "email":
return "请输入有效的电子邮件地址"
case "min":
if err.Type().Kind() == reflect.String {
return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param())
}
return fmt.Sprintf("必须大于或等于 %s", err.Param())
case "max":
if err.Type().Kind() == reflect.String {
return fmt.Sprintf("长度不能超过 %s 个字符", err.Param())
}
return fmt.Sprintf("必须小于或等于 %s", err.Param())
case "len":
return fmt.Sprintf("长度必须为 %s 个字符", err.Param())
case "oneof":
return fmt.Sprintf("必须是以下值之一: %s", err.Param())
case "otpsecret":
return "OTP密钥格式无效必须是有效的Base32编码"
case "password":
return "密码必须至少10个字符并包含大写字母、小写字母以及数字或特殊字符"
case "issuer":
return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
case "otpauth_uri":
return "OTP认证URI格式无效"
case "no_xss":
return "输入包含潜在的不安全内容"
case "numeric":
return "必须是数字"
default:
return fmt.Sprintf("验证失败: %s", err.Tag())
}
}
// Custom validation functions
// validateOTPSecret validates an OTP secret
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
if secret == "" {
return false
}
// OTP secret should be base32 encoded
if !base32Regex.MatchString(secret) {
return false
}
// Check length (typical OTP secrets are 16-64 characters)
validLength := len(secret) >= 16 && len(secret) <= 128
return validLength
}
// validatePassword validates a password
func validatePassword(fl validator.FieldLevel) bool {
password := fl.Field().String()
// At least 10 characters long
if len(password) < 10 {
return false
}
// Check if it's a common password
if commonPasswords[strings.ToLower(password)] {
return false
}
// Check character types
hasUpper := upperRegex.MatchString(password)
hasLower := lowerRegex.MatchString(password)
hasNumber := numberRegex.MatchString(password)
hasSpecial := specialRegex.MatchString(password)
// Ensure password has enough complexity
complexity := 0
if hasUpper {
complexity++
}
if hasLower {
complexity++
}
if hasNumber {
complexity++
}
if hasSpecial {
complexity++
}
return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial)
}
// validateIssuer validates an issuer name
func validateIssuer(fl validator.FieldLevel) bool {
issuer := fl.Field().String()
if issuer == "" {
return false
}
// Issuer should not contain special characters that could cause problems in URLs
if !issuerRegex.MatchString(issuer) {
return false
}
// Check length
validLength := len(issuer) >= 1 && len(issuer) <= 100
return validLength
}
// validateOTPAuthURI validates an otpauth URI
func validateOTPAuthURI(fl validator.FieldLevel) bool {
uri := fl.Field().String()
if uri == "" {
return false
}
// Basic format check for otpauth URI
// Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
return otpauthRegex.MatchString(uri)
}
// validateNoXSS checks if a string contains potential XSS payloads
func validateNoXSS(fl validator.FieldLevel) bool {
value := fl.Field().String()
// 检查基本的HTML编码
if strings.Contains(value, "&#") ||
strings.Contains(value, "&lt;") ||
strings.Contains(value, "&gt;") {
return false
}
// 检查十六进制编码
if strings.Contains(strings.ToLower(value), "\\x3c") || // <
strings.Contains(strings.ToLower(value), "\\x3e") { // >
return false
}
// 检查Unicode编码
if strings.Contains(strings.ToLower(value), "\\u003c") || // <
strings.Contains(strings.ToLower(value), "\\u003e") { // >
return false
}
// 使用预编译的正则表达式检查XSS模式
for _, pattern := range xssPatterns {
if pattern.MatchString(value) {
return false
}
}
return true
}
// Request validation structs
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}
// CreateOTPRequest represents a request to create an OTP
type CreateOTPRequest struct {
Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"required,issuer,no_xss"`
Secret string `json:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"required,oneof=6 8"`
Period int `json:"period" validate:"required,oneof=30 60"`
}
// UpdateOTPRequest represents a request to update an OTP
type UpdateOTPRequest struct {
Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
Period int `json:"period" validate:"omitempty,oneof=30 60"`
}
// VerifyOTPRequest represents a request to verify an OTP code
type VerifyOTPRequest struct {
Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}