Compare commits

...
Sign in to create a new pull request.

7 commits
master ... beta

Author SHA1 Message Date
“xHuPo”
44500afd3f beta 2025-05-23 19:09:06 +08:00
“xHuPo”
bcd986e3f7 beta 2025-05-23 18:57:11 +08:00
“xHuPo”
a45ddf13d5 alpha 2025-05-23 14:15:48 +08:00
“xHuPo”
a6461a9a0e alpha 2025-05-23 13:45:53 +08:00
“xHuPo”
2d3698716e alpha 2025-05-23 13:45:37 +08:00
“xHuPo”
25c5f530b8 alpha 2025-05-22 16:07:55 +08:00
“xHuPo”
079542e431 alpha 2025-05-22 12:06:34 +08:00
49 changed files with 6260 additions and 284 deletions

149
api/response.go Normal file
View file

@ -0,0 +1,149 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"net/http"
)
// Common error codes
const (
CodeSuccess = 0
CodeInvalidParams = 400
CodeUnauthorized = 401
CodeForbidden = 403
CodeNotFound = 404
CodeInternalError = 500
CodeServiceUnavail = 503
)
// Error represents an API error
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
}
// Error implements the error interface
func (e *Error) Error() string {
return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message)
}
// NewError creates a new API error
func NewError(code int, message string) *Error {
return &Error{
Code: code,
Message: message,
}
}
// Response represents a standard API response
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// ResponseWriter wraps common response writing functions
type ResponseWriter struct {
http.ResponseWriter
}
// NewResponseWriter creates a new ResponseWriter
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
return &ResponseWriter{w}
}
// WriteJSON writes a JSON response
func (w *ResponseWriter) WriteJSON(code int, data interface{}) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
return json.NewEncoder(w).Encode(data)
}
// WriteSuccess writes a success response
func (w *ResponseWriter) WriteSuccess(data interface{}) error {
return w.WriteJSON(http.StatusOK, Response{
Code: CodeSuccess,
Message: "success",
Data: data,
})
}
// WriteError writes an error response
func (w *ResponseWriter) WriteError(err error) error {
var apiErr *Error
if errors.As(err, &apiErr) {
return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{
Code: apiErr.Code,
Message: apiErr.Message,
})
}
// Handle unknown errors
return w.WriteJSON(http.StatusInternalServerError, Response{
Code: CodeInternalError,
Message: "Internal Server Error",
})
}
// WriteErrorWithCode writes an error response with a specific code
func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error {
return w.WriteJSON(getHTTPStatus(code), Response{
Code: code,
Message: message,
})
}
// getHTTPStatus maps API error codes to HTTP status codes
func getHTTPStatus(code int) int {
switch code {
case CodeSuccess:
return http.StatusOK
case CodeInvalidParams:
return http.StatusBadRequest
case CodeUnauthorized:
return http.StatusUnauthorized
case CodeForbidden:
return http.StatusForbidden
case CodeNotFound:
return http.StatusNotFound
case CodeServiceUnavail:
return http.StatusServiceUnavailable
default:
return http.StatusInternalServerError
}
}
// Common errors
var (
ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters")
ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized")
ErrForbidden = NewError(CodeForbidden, "Forbidden")
ErrNotFound = NewError(CodeNotFound, "Resource not found")
ErrInternalError = NewError(CodeInternalError, "Internal server error")
ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable")
)
// ValidationError creates an error for invalid parameters
func ValidationError(message string) *Error {
return NewError(CodeInvalidParams, message)
}
// NotFoundError creates an error for not found resources
func NotFoundError(resource string) *Error {
return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource))
}
// ForbiddenError creates an error for forbidden actions
func ForbiddenError(message string) *Error {
return NewError(CodeForbidden, message)
}
// InternalError creates an error for internal server errors
func InternalError(err error) *Error {
if err == nil {
return ErrInternalError
}
return NewError(CodeInternalError, err.Error())
}

206
cache/cache.go vendored Normal file
View file

@ -0,0 +1,206 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
)
// Item represents a cache item
type Item struct {
Value []byte
Expiration int64
}
// Expired returns true if the item has expired
func (item Item) Expired() bool {
if item.Expiration == 0 {
return false
}
return time.Now().UnixNano() > item.Expiration
}
// Cache represents an in-memory cache
type Cache struct {
items map[string]Item
mu sync.RWMutex
defaultExpiration time.Duration
cleanupInterval time.Duration
stopCleanup chan bool
}
// New creates a new cache with the given default expiration and cleanup interval
func New(defaultExpiration, cleanupInterval time.Duration) *Cache {
cache := &Cache{
items: make(map[string]Item),
defaultExpiration: defaultExpiration,
cleanupInterval: cleanupInterval,
stopCleanup: make(chan bool),
}
// Start cleanup goroutine if cleanup interval > 0
if cleanupInterval > 0 {
go cache.startCleanup()
}
return cache
}
// startCleanup starts the cleanup process
func (c *Cache) startCleanup() {
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.DeleteExpired()
case <-c.stopCleanup:
return
}
}
}
// Set adds an item to the cache with the given key and expiration
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error {
// Convert value to bytes
var valueBytes []byte
var err error
switch v := value.(type) {
case []byte:
valueBytes = v
case string:
valueBytes = []byte(v)
default:
valueBytes, err = json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value: %w", err)
}
}
// Calculate expiration
var exp int64
if expiration == 0 {
if c.defaultExpiration > 0 {
exp = time.Now().Add(c.defaultExpiration).UnixNano()
}
} else if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
}
c.mu.Lock()
c.items[key] = Item{
Value: valueBytes,
Expiration: exp,
}
c.mu.Unlock()
return nil
}
// Get gets an item from the cache
func (c *Cache) Get(key string, value interface{}) (bool, error) {
c.mu.RLock()
item, found := c.items[key]
c.mu.RUnlock()
if !found {
return false, nil
}
// Check if item has expired
if item.Expired() {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return false, nil
}
// Unmarshal value
switch v := value.(type) {
case *[]byte:
*v = item.Value
case *string:
*v = string(item.Value)
default:
if err := json.Unmarshal(item.Value, value); err != nil {
return true, fmt.Errorf("failed to unmarshal value: %w", err)
}
}
return true, nil
}
// Delete deletes an item from the cache
func (c *Cache) Delete(key string) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
}
// DeleteExpired deletes all expired items from the cache
func (c *Cache) DeleteExpired() {
now := time.Now().UnixNano()
c.mu.Lock()
for k, v := range c.items {
if v.Expiration > 0 && now > v.Expiration {
delete(c.items, k)
}
}
c.mu.Unlock()
}
// Clear deletes all items from the cache
func (c *Cache) Clear() {
c.mu.Lock()
c.items = make(map[string]Item)
c.mu.Unlock()
}
// Close stops the cleanup goroutine
func (c *Cache) Close() {
if c.cleanupInterval > 0 {
c.stopCleanup <- true
}
}
// CacheService provides caching functionality
type CacheService struct {
cache *Cache
}
// NewCacheService creates a new CacheService
func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService {
return &CacheService{
cache: New(defaultExpiration, cleanupInterval),
}
}
// Set adds an item to the cache
func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
return s.cache.Set(key, value, expiration)
}
// Get gets an item from the cache
func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) {
return s.cache.Get(key, value)
}
// Delete deletes an item from the cache
func (s *CacheService) Delete(ctx context.Context, key string) {
s.cache.Delete(key)
}
// Clear deletes all items from the cache
func (s *CacheService) Clear(ctx context.Context) {
s.cache.Clear()
}
// Close closes the cache
func (s *CacheService) Close() {
s.cache.Close()
}

View file

@ -1,83 +1,140 @@
package cmd
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/spf13/viper"
"otpm/config"
"otpm/database"
"otpm/handlers"
"otpm/utils"
"github.com/jmoiron/sqlx"
"github.com/julienschmidt/httprouter"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"otpm/models"
"otpm/server"
"otpm/services"
)
var rootCmd = &cobra.Command{
Use: "otpm",
Short: "otp backend for microapp on wechat",
Run: func(cmd *cobra.Command, args []string) {
startApp()
},
}
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}
func init() {
cobra.OnInitialize(initConfig)
// Set config file with multi-environment support
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
rootCmd.PersistentFlags().StringP("config", "c", "", "config file (default is $HOME/config.yaml)")
rootCmd.PersistentFlags().StringP("driver", "d", "sqlite3", "database driver (sqlite3, postgres, mysql)")
rootCmd.PersistentFlags().StringP("dsn", "s", "", "database connection string")
rootCmd.PersistentFlags().StringP("port", "p", "8080", "port to listen on")
viper.BindPFlag("database.driver", rootCmd.PersistentFlags().Lookup("driver"))
viper.BindPFlag("database.dsn", rootCmd.PersistentFlags().Lookup("dsn"))
viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port"))
}
func initConfig() {
if cfgFile := viper.GetString("config"); cfgFile != "" {
viper.SetConfigFile(cfgFile)
} else {
viper.AddConfigPath(".")
viper.SetConfigName("config")
viper.SetConfigType("yaml")
// Set environment specific config (e.g. config.production.yaml)
env := os.Getenv("OTPM_ENV")
if env != "" {
viper.SetConfigName(fmt.Sprintf("config.%s", env))
}
if err := viper.ReadInConfig(); err != nil {
log.Fatalf("Error reading config file: %v", err)
}
// Set default values
viper.SetDefault("server.port", "8080")
viper.SetDefault("server.timeout.read", "15s")
viper.SetDefault("server.timeout.write", "15s")
viper.SetDefault("server.timeout.idle", "60s")
viper.SetDefault("database.max_open_conns", 25)
viper.SetDefault("database.max_idle_conns", 5)
viper.SetDefault("database.conn_max_lifetime", "5m")
// Set environment variable prefix
viper.SetEnvPrefix("OTPM")
viper.AutomaticEnv()
// Bind environment variables
viper.BindEnv("database.url", "OTPM_DB_URL")
viper.BindEnv("database.password", "OTPM_DB_PASSWORD")
}
func initApp(db *sqlx.DB) {
if err := database.MigrateDB(db); err != nil {
log.Fatalf("Error migrating the database: %v", err)
}
}
func startApp() {
port := viper.GetInt("port")
db, err := database.InitDB()
// Execute is the entry point for the application
func Execute() error {
// Load configuration
cfg, err := config.LoadConfig()
if err != nil {
log.Fatalf("Error connecting to the database: %v", err)
return fmt.Errorf("failed to load config: %w", err)
}
// Create context with cancellation
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Printf("Received signal: %v", sig)
cancel()
}()
// Initialize database
db, err := database.New(&cfg.Database)
if err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
defer db.Close()
initApp(db)
handler := &handlers.Handler{DB: db}
router := httprouter.New()
router.POST("/login", utils.AdaptHandler(handler.Login))
router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp))
router.GET("/get", utils.AdaptHandler(handler.GetOtp))
// Run database migrations
if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil {
return fmt.Errorf("failed to run migrations: %w", err)
}
log.Println("Starting server on :8080")
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), router))
// Initialize repositories
userRepo := models.NewUserRepository(db.DB)
otpRepo := models.NewOTPRepository(db.DB)
// Initialize services
authService := services.NewAuthService(cfg, userRepo)
otpService := services.NewOTPService(otpRepo)
// Initialize handlers
authHandler := handlers.NewAuthHandler(authService)
otpHandler := handlers.NewOTPHandler(otpService)
// Create and configure server
srv := server.New(cfg)
// Register health check endpoint
srv.RegisterHealthCheck()
// Register public routes with type conversion
authRoutes := make(map[string]http.Handler)
for path, handler := range authHandler.Routes() {
authRoutes[path] = http.HandlerFunc(handler)
}
srv.RegisterRoutes(authRoutes)
// Register authenticated routes with type conversion
otpRoutes := make(map[string]http.Handler)
for path, handler := range otpHandler.Routes() {
otpRoutes[path] = http.HandlerFunc(handler)
}
srv.RegisterAuthRoutes(otpRoutes)
// Start server in goroutine
serverErr := make(chan error, 1)
go func() {
log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port)
if err := srv.Start(); err != nil {
serverErr <- fmt.Errorf("server error: %w", err)
}
}()
// Wait for shutdown signal or server error
select {
case err := <-serverErr:
return err
case <-ctx.Done():
// Graceful shutdown with timeout
log.Println("Shutting down server...")
if err := srv.Shutdown(); err != nil {
return fmt.Errorf("server shutdown error: %w", err)
}
log.Println("Server stopped gracefully")
}
return nil
}

View file

@ -1,8 +1,23 @@
server:
port: 8080
read_timeout: 15s
write_timeout: 15s
shutdown_timeout: 5s
database:
driver: sqlite
driver: sqlite3
dsn: otpm.sqlite
port: 8080
max_open_conns: 25
max_idle_conns: 25
max_lifetime: 5m
skip_migration: false
jwt:
secret: "your-jwt-secret-key-change-this-in-production"
expire_delta: 24h
refresh_delta: 168h
signing_method: HS256
wechat:
appid: "wx57d1033974eb5250"
secret: "be494c2a81df685a40b9a74e1736b15d"
app_id: "your-wechat-app-id"
app_secret: "your-wechat-app-secret"

128
config/config.go Normal file
View file

@ -0,0 +1,128 @@
package config
import (
"fmt"
"time"
"github.com/spf13/viper"
)
// Config holds all configuration for the application
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
JWT JWTConfig `mapstructure:"jwt"`
WeChat WeChatConfig `mapstructure:"wechat"`
}
// ServerConfig holds all server related configuration
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
ReadTimeout time.Duration `mapstructure:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout"`
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout
}
// DatabaseConfig holds all database related configuration
type DatabaseConfig struct {
Driver string `mapstructure:"driver"`
DSN string `mapstructure:"dsn"`
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
MaxLifetime time.Duration `mapstructure:"max_lifetime"`
SkipMigration bool `mapstructure:"skip_migration"`
}
// JWTConfig holds all JWT related configuration
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireDelta time.Duration `mapstructure:"expire_delta"`
RefreshDelta time.Duration `mapstructure:"refresh_delta"`
SigningMethod string `mapstructure:"signing_method"`
Issuer string `mapstructure:"issuer"`
Audience string `mapstructure:"audience"`
}
// WeChatConfig holds all WeChat related configuration
type WeChatConfig struct {
AppID string `mapstructure:"app_id"`
AppSecret string `mapstructure:"app_secret"`
}
// LoadConfig loads the configuration from file and environment variables
func LoadConfig() (*Config, error) {
// Set default values
setDefaults()
// Read config file
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// Validate config
if err := validateConfig(&config); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
return &config, nil
}
// setDefaults sets default values for configuration
func setDefaults() {
// Server defaults
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.read_timeout", "15s")
viper.SetDefault("server.write_timeout", "15s")
viper.SetDefault("server.shutdown_timeout", "5s")
viper.SetDefault("server.timeout", "30s") // Default request processing timeout
// Database defaults
viper.SetDefault("database.driver", "sqlite3")
viper.SetDefault("database.max_open_conns", 25)
viper.SetDefault("database.max_idle_conns", 25)
viper.SetDefault("database.max_lifetime", "5m")
viper.SetDefault("database.skip_migration", false)
// JWT defaults
viper.SetDefault("jwt.expire_delta", "24h")
viper.SetDefault("jwt.refresh_delta", "168h") // 7 days
viper.SetDefault("jwt.signing_method", "HS256")
viper.SetDefault("jwt.issuer", "otpm")
viper.SetDefault("jwt.audience", "otpm-client")
}
// validateConfig validates the configuration
func validateConfig(config *Config) error {
if config.Server.Port < 1 || config.Server.Port > 65535 {
return fmt.Errorf("invalid port number: %d", config.Server.Port)
}
if config.Database.Driver == "" {
return fmt.Errorf("database driver is required")
}
if config.Database.DSN == "" {
return fmt.Errorf("database DSN is required")
}
if config.JWT.Secret == "" {
return fmt.Errorf("JWT secret is required")
}
if config.WeChat.AppID == "" {
return fmt.Errorf("WeChat AppID is required")
}
if config.WeChat.AppSecret == "" {
return fmt.Errorf("WeChat AppSecret is required")
}
return nil
}

View file

@ -1,49 +0,0 @@
package database
import (
_ "embed"
"log"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/spf13/viper"
_ "modernc.org/sqlite"
)
var (
//go:embed init/users.sql
userTable string
//go:embed init/otp.sql
otpTable string
)
func InitDB() (*sqlx.DB, error) {
driver := viper.GetString("database.driver")
dsn := viper.GetString("database.dsn")
db, err := sqlx.Open(driver, dsn)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, err
}
log.Println("Connected to database!")
return db, nil
}
func MigrateDB(db *sqlx.DB) error {
_, err := db.Exec(userTable)
if err != nil {
return err
}
_, err = db.Exec(otpTable)
if err != nil {
return err
}
return nil
}

196
database/db.go Normal file
View file

@ -0,0 +1,196 @@
package database
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"time"
"otpm/config"
"github.com/jmoiron/sqlx"
)
// DB wraps sqlx.DB to provide additional functionality
type DB struct {
*sqlx.DB
}
// New creates a new database connection
func New(cfg *config.DatabaseConfig) (*DB, error) {
db, err := sqlx.Open(cfg.Driver, cfg.DSN)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Configure connection pool with optimized settings
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn
db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes
// Verify connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
log.Println("Successfully connected to database")
return &DB{db}, nil
}
// WithTx executes a function within a transaction with retry logic
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
const maxRetries = 3
var lastErr error
// Default transaction options
opts := &sql.TxOptions{
Isolation: sql.LevelReadCommitted,
}
for attempt := 1; attempt <= maxRetries; attempt++ {
start := time.Now()
tx, err := db.BeginTxx(ctx, opts)
if err != nil {
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
continue
}
return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
}
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
}
// Log long-running transactions
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
log.Printf("Transaction completed in %v", elapsed)
}
if err := tx.Commit(); err != nil {
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
}
return nil
}
return lastErr
}
// isRetryableError checks if an error is likely to succeed on retry
func isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := strings.ToLower(err.Error())
return strings.Contains(errStr, "deadlock") ||
strings.Contains(errStr, "timeout") ||
strings.Contains(errStr, "try again") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "busy") ||
strings.Contains(errStr, "locked")
}
// ExecContext executes a query with adaptive timeout
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
// Set timeout based on query complexity
timeout := 5 * time.Second
if strings.Contains(strings.ToLower(query), "insert") ||
strings.Contains(strings.ToLower(query), "update") ||
strings.Contains(strings.ToLower(query), "delete") {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
_, err := db.DB.ExecContext(ctx, query, args...)
elapsed := time.Since(start)
// Log slow queries
if elapsed > timeout/2 {
log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
}
if err != nil {
return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
}
return nil
}
// QueryRowContext executes a query that returns a single row with timeout
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
return db.DB.QueryRowxContext(ctx, query, args...)
}
// QueryContext executes a query that returns multiple rows with adaptive timeout
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
// Set timeout based on query complexity
timeout := 5 * time.Second
if strings.Contains(strings.ToLower(query), "join") ||
strings.Contains(strings.ToLower(query), "group by") ||
strings.Contains(strings.ToLower(query), "order by") {
timeout = 15 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
rows, err := db.DB.QueryxContext(ctx, query, args...)
elapsed := time.Since(start)
// Log slow queries
if elapsed > timeout/2 {
log.Printf("Slow query detected: %s (took %v)", query, elapsed)
}
if err != nil {
return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
}
return rows, nil
}
// Close closes the database connection
func (db *DB) Close() error {
if err := db.DB.Close(); err != nil {
return fmt.Errorf("failed to close database connection: %w", err)
}
return nil
}

View file

@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS otp (
id SERIAL PRIMARY KEY,
openid VARCHAR(255),
num INTEGER,
token VARCHAR(255)
id SERIAL PRIMARY KEY,
openid VARCHAR(255) UNIQUE NOT NULL,
token VARCHAR(255),
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

View file

@ -2,4 +2,5 @@ CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
openid VARCHAR(255) UNIQUE NOT NULL,
session_key VARCHAR(255) UNIQUE NOT NULL
);
);
CREATE UNIQUE INDEX idx_users_openid ON users(openid);

160
database/migration.go Normal file
View file

@ -0,0 +1,160 @@
package database
import (
"context"
"fmt"
"log"
"time"
_ "embed"
"github.com/jmoiron/sqlx"
)
var (
//go:embed init/users.sql
userTable string
//go:embed init/otp.sql
otpTable string
)
// Migration represents a database migration
type Migration struct {
Name string
SQL string
Version int
}
// Migrations is a list of all migrations
var Migrations = []Migration{
{
Name: "Create users table",
SQL: userTable,
Version: 1,
},
{
Name: "Create OTP table",
SQL: otpTable,
Version: 2,
},
}
// MigrationRecord represents a record in the migrations table
type MigrationRecord struct {
ID int `db:"id"`
Version int `db:"version"`
Name string `db:"name"`
AppliedAt time.Time `db:"applied_at"`
}
// ensureMigrationsTable ensures that the migrations table exists
func ensureMigrationsTable(db *sqlx.DB) error {
query := `
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version INTEGER NOT NULL,
name TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
_, err := db.Exec(query)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
return nil
}
// getAppliedMigrations gets all applied migrations
func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
var records []MigrationRecord
if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
}
result := make(map[int]MigrationRecord)
for _, record := range records {
result[record.Version] = record
}
return result, nil
}
// Migrate runs all pending migrations
func Migrate(db *sqlx.DB, skipMigration bool) error {
if skipMigration {
log.Println("Skipping database migration as configured")
return nil
}
// Ensure migrations table exists
if err := ensureMigrationsTable(db); err != nil {
return err
}
// Get applied migrations
appliedMigrations, err := getAppliedMigrations(db)
if err != nil {
return err
}
// Apply pending migrations
for _, migration := range Migrations {
if _, ok := appliedMigrations[migration.Version]; ok {
log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
continue
}
log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
// Start a transaction for this migration
tx, err := db.Beginx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
// Execute migration
if _, err := tx.Exec(migration.SQL); err != nil {
tx.Rollback()
return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Record migration
if _, err := tx.Exec(
"INSERT INTO migrations (version, name) VALUES (?, ?)",
migration.Version, migration.Name,
); err != nil {
tx.Rollback()
return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Commit transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
}
log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
}
return nil
}
// MigrateWithContext runs all pending migrations with context
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
// Create a channel to signal completion
done := make(chan error, 1)
// Run migration in a goroutine
go func() {
done <- Migrate(db, skipMigration)
}()
// Wait for migration to complete or context to be canceled
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("migration canceled: %w", ctx.Err())
}
}

663
docs/swagger.go Normal file
View file

@ -0,0 +1,663 @@
// Package docs provides API documentation using Swagger/OpenAPI
package docs
import (
"encoding/json"
"net/http"
)
// SwaggerInfo holds the API information
var SwaggerInfo = struct {
Title string
Description string
Version string
Host string
BasePath string
Schemes []string
}{
Title: "OTPM API",
Description: "API for One-Time Password Manager",
Version: "1.0.0",
Host: "localhost:8080",
BasePath: "/",
Schemes: []string{"http", "https"},
}
// SwaggerJSON returns the Swagger JSON
func SwaggerJSON() []byte {
swagger := map[string]interface{}{
"swagger": "2.0",
"info": map[string]interface{}{
"title": SwaggerInfo.Title,
"description": SwaggerInfo.Description,
"version": SwaggerInfo.Version,
},
"host": SwaggerInfo.Host,
"basePath": SwaggerInfo.BasePath,
"schemes": SwaggerInfo.Schemes,
"paths": getPaths(),
"definitions": map[string]interface{}{
"LoginRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "WeChat authorization code",
},
},
"required": []string{"code"},
},
"LoginResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"token": map[string]interface{}{
"type": "string",
"description": "JWT token",
},
"openid": map[string]interface{}{
"type": "string",
"description": "WeChat OpenID",
},
},
},
"CreateOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"secret": map[string]interface{}{
"type": "string",
"description": "OTP secret",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
"enum": []string{"SHA1", "SHA256", "SHA512"},
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
"enum": []int{6, 8},
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
"enum": []int{30, 60},
},
},
"required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"},
},
"OTP": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"id": map[string]interface{}{
"type": "string",
"description": "OTP ID",
},
"user_id": map[string]interface{}{
"type": "string",
"description": "User ID",
},
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
},
"created_at": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Creation time",
},
"updated_at": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Last update time",
},
},
},
"OTPCodeResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "OTP code",
},
"expires_in": map[string]interface{}{
"type": "integer",
"description": "Seconds until expiration",
},
},
},
"VerifyOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "string",
"description": "OTP code to verify",
},
},
"required": []string{"code"},
},
"VerifyOTPResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"valid": map[string]interface{}{
"type": "boolean",
"description": "Whether the code is valid",
},
},
},
"UpdateOTPRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
"description": "OTP name",
},
"issuer": map[string]interface{}{
"type": "string",
"description": "OTP issuer",
},
"algorithm": map[string]interface{}{
"type": "string",
"description": "OTP algorithm",
"enum": []string{"SHA1", "SHA256", "SHA512"},
},
"digits": map[string]interface{}{
"type": "integer",
"description": "OTP digits",
"enum": []int{6, 8},
},
"period": map[string]interface{}{
"type": "integer",
"description": "OTP period in seconds",
"enum": []int{30, 60},
},
},
},
"ErrorResponse": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"code": map[string]interface{}{
"type": "integer",
"description": "Error code",
},
"message": map[string]interface{}{
"type": "string",
"description": "Error message",
},
},
},
},
"securityDefinitions": map[string]interface{}{
"Bearer": map[string]interface{}{
"type": "apiKey",
"name": "Authorization",
"in": "header",
"description": "JWT token with Bearer prefix",
},
},
}
data, _ := json.MarshalIndent(swagger, "", " ")
return data
}
// getPaths returns the API paths
func getPaths() map[string]interface{} {
return map[string]interface{}{
"/login": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Login with WeChat",
"description": "Login with WeChat authorization code",
"tags": []string{"auth"},
"parameters": []map[string]interface{}{
{
"name": "body",
"in": "body",
"description": "Login request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/LoginRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Successful login",
"schema": map[string]interface{}{
"$ref": "#/definitions/LoginResponse",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/verify-token": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Verify token",
"description": "Verify JWT token",
"tags": []string{"auth"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Token is valid",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"valid": map[string]interface{}{
"type": "boolean",
"description": "Whether the token is valid",
},
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp": map[string]interface{}{
"get": map[string]interface{}{
"summary": "List OTPs",
"description": "List all OTPs for the authenticated user",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "List of OTPs",
"schema": map[string]interface{}{
"type": "array",
"items": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
"post": map[string]interface{}{
"summary": "Create OTP",
"description": "Create a new OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "body",
"in": "body",
"description": "OTP creation request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/CreateOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP created",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}": map[string]interface{}{
"put": map[string]interface{}{
"summary": "Update OTP",
"description": "Update an existing OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
{
"name": "body",
"in": "body",
"description": "OTP update request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/UpdateOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP updated",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTP",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
"delete": map[string]interface{}{
"summary": "Delete OTP",
"description": "Delete an existing OTP",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP deleted",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "Success message",
},
},
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}/code": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Get OTP code",
"description": "Get the current OTP code",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP code",
"schema": map[string]interface{}{
"$ref": "#/definitions/OTPCodeResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/otp/{id}/verify": map[string]interface{}{
"post": map[string]interface{}{
"summary": "Verify OTP code",
"description": "Verify an OTP code",
"tags": []string{"otp"},
"security": []map[string]interface{}{
{
"Bearer": []string{},
},
},
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"description": "OTP ID",
"required": true,
"type": "string",
},
{
"name": "body",
"in": "body",
"description": "OTP verification request",
"required": true,
"schema": map[string]interface{}{
"$ref": "#/definitions/VerifyOTPRequest",
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "OTP verification result",
"schema": map[string]interface{}{
"$ref": "#/definitions/VerifyOTPResponse",
},
},
"400": map[string]interface{}{
"description": "Invalid request",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"401": map[string]interface{}{
"description": "Unauthorized",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"404": map[string]interface{}{
"description": "OTP not found",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
"500": map[string]interface{}{
"description": "Internal server error",
"schema": map[string]interface{}{
"$ref": "#/definitions/ErrorResponse",
},
},
},
},
},
"/health": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Health check",
"description": "Check if the API is healthy",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "API is healthy",
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"status": map[string]interface{}{
"type": "string",
"description": "Health status",
},
"time": map[string]interface{}{
"type": "string",
"format": "date-time",
"description": "Current time",
},
},
},
},
},
},
},
"/metrics": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Metrics",
"description": "Get application metrics",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Application metrics",
},
},
},
},
"/swagger.json": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Swagger JSON",
"description": "Get Swagger JSON",
"tags": []string{"system"},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "Swagger JSON",
},
},
},
},
}
}
// Handler returns an HTTP handler for Swagger JSON
func Handler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(SwaggerJSON())
}
}

24
go.mod
View file

@ -1,6 +1,8 @@
module otpm
go 1.21.1
go 1.23.0
toolchain go1.23.9
require (
github.com/go-sql-driver/mysql v1.8.1
@ -14,17 +16,30 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.26.0 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/prometheus/client_golang v1.22.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
@ -35,9 +50,12 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/crypto v0.38.0 // indirect
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
golang.org/x/sys v0.22.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect

38
go.sum
View file

@ -1,5 +1,9 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -11,10 +15,21 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@ -33,6 +48,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
@ -43,6 +60,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
@ -50,6 +69,14 @@ github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
@ -87,17 +114,28 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

147
handlers/auth_handler.go Normal file
View file

@ -0,0 +1,147 @@
package handlers
import (
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"otpm/api"
"otpm/services"
"github.com/golang-jwt/jwt"
)
// AuthHandler handles authentication related requests
type AuthHandler struct {
authService *services.AuthService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code"`
}
// LoginResponse represents a login response
type LoginResponse struct {
Token string `json:"token"`
OpenID string `json:"openid"`
}
// Login handles WeChat login
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Limit request body size to prevent DOS
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
// Parse request
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request body: %v", err))
log.Printf("Login request parse error: %v", err)
return
}
// Validate request
if req.Code == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Code is required")
log.Printf("Login request validation failed: empty code")
return
}
// Login with WeChat code
token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
log.Printf("Login failed for code %s: %v", req.Code, err)
return
}
// Log successful login
log.Printf("Login successful for code %s (took %v)",
req.Code, time.Since(start))
// Return token
api.NewResponseWriter(w).WriteSuccess(LoginResponse{
Token: token,
})
}
// VerifyToken handles token verification
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Get token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Authorization header is required")
log.Printf("Token verification failed: missing Authorization header")
return
}
// Validate token format
if len(authHeader) < 7 || authHeader[:7] != "Bearer " {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token format. Expected 'Bearer <token>'")
log.Printf("Token verification failed: invalid token format")
return
}
token := authHeader[7:]
if len(token) < 32 { // Basic length check
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token length")
log.Printf("Token verification failed: token too short")
return
}
// Validate token
claims, err := h.authService.ValidateToken(token)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("Token verification failed for token %s: %v",
maskToken(token), err) // Mask token in logs
return
}
// Log successful verification
userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string)
if !ok {
log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
} else {
log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
}
// Token is valid
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
"valid": true,
})
}
// maskToken masks sensitive parts of token for logging
func maskToken(token string) string {
if len(token) < 8 {
return "****"
}
return token[:4] + "****" + token[len(token)-4:]
}
// Routes returns all routes for the auth handler
func (h *AuthHandler) Routes() map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/login": h.Login,
"/verify-token": h.VerifyToken,
}
}

View file

@ -1,9 +0,0 @@
package handlers
import (
"github.com/jmoiron/sqlx"
)
type Handler struct {
DB *sqlx.DB
}

View file

@ -1,93 +0,0 @@
package handlers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/spf13/viper"
)
var code2Session = "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code"
type LoginRequest struct {
Code string `json:"code"`
}
// 封装code2session接口返回数据
type LoginResponse struct {
OpenId string `json:"openid"`
SessionKey string `json:"session_key"`
UnionId string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
func getLoginResponse(code string) (*LoginResponse, error) {
appid := viper.GetString("wechat.appid")
secret := viper.GetString("wechat.secret")
url := fmt.Sprintf(code2Session, appid, secret, code)
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var loginResponse LoginResponse
if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
return nil, err
}
if loginResponse.ErrCode != 0 {
return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
}
return &loginResponse, nil
}
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
var req LoginRequest
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
if err := json.Unmarshal(body, &req); err != nil {
http.Error(w, "Failed to parse request body", http.StatusBadRequest)
return
}
loginResponse, err := getLoginResponse(req.Code)
if err != nil {
http.Error(w, "Failed to get session key", http.StatusInternalServerError)
return
}
// // 插入或更新用户的openid和sessionid
// query := `
// INSERT INTO users (openid, sessionid)
// VALUES ($1, $2)
// ON CONFLICT (openid) DO UPDATE SET sessionid = $2
// RETURNING id;
// `
// var ID int
// if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil {
// http.Error(w, "Failed to log in user", http.StatusInternalServerError)
// return
// }
data := map[string]interface{}{
"openid": loginResponse.OpenId,
"session_key": loginResponse.SessionKey,
}
respData, err := json.Marshal(data)
if err != nil {
http.Error(w, "Failed to marshal response data", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(respData))
}

View file

@ -1,61 +0,0 @@
package handlers
import (
"encoding/json"
"net/http"
)
type OtpRequest struct {
OpenID string `json:"openid"`
Num int `json:"num"`
Token *[]OTP `json:"token"`
}
type OTP struct {
Issuer string `json:"issuer"`
Remark string `json:"remark"`
Secret string `json:"secret"`
}
func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) {
var req OtpRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request payload", http.StatusBadRequest)
return
}
num := len(*req.Token)
// 插入或更新 OTP 记录
query := `
INSERT INTO otp (openid, num, token)
VALUES ($1, $2, $3)
`
_, err := h.DB.Exec(query, req.OpenID, req.Token, num)
if err != nil {
http.Error(w, "Failed to update or create OTP", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("OTP updated or created successfully"))
}
func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
openid := r.URL.Query().Get("openid")
if openid == "" {
http.Error(w, "未登录", http.StatusBadRequest)
return
}
var otp OtpRequest
err := h.DB.Get(&otp, "SELECT token, num, openid FROM otp WHERE openid=$1", openid)
if err != nil {
http.Error(w, "OTP not found", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(otp)
}

286
handlers/otp_handler.go Normal file
View file

@ -0,0 +1,286 @@
package handlers
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
"otpm/api"
"otpm/middleware"
"otpm/models"
"otpm/services"
)
// OTPHandler handles OTP related requests
type OTPHandler struct {
otpService *services.OTPService
}
// NewOTPHandler creates a new OTPHandler
func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
return &OTPHandler{
otpService: otpService,
}
}
// CreateOTPRequest represents a request to create an OTP
type CreateOTPRequest struct {
Name string `json:"name"`
Issuer string `json:"issuer"`
Secret string `json:"secret"`
Algorithm string `json:"algorithm"`
Digits int `json:"digits"`
Period int `json:"period"`
}
// CreateOTP handles OTP creation
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Limit request body size
r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("CreateOTP unauthorized attempt")
return
}
// Parse request
var req CreateOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request body: %v", err))
log.Printf("CreateOTP request parse error for user %s: %v", userID, err)
return
}
// Validate OTP parameters
if req.Secret == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Secret is required")
log.Printf("CreateOTP validation failed for user %s: empty secret", userID)
return
}
// Validate algorithm
supportedAlgos := map[string]bool{
"SHA1": true,
"SHA256": true,
"SHA512": true,
}
if !supportedAlgos[strings.ToUpper(req.Algorithm)] {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Unsupported algorithm. Supported: SHA1, SHA256, SHA512")
log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s",
userID, req.Algorithm)
return
}
// Create OTP
otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{
Name: req.Name,
Issuer: req.Issuer,
Secret: req.Secret,
Algorithm: req.Algorithm,
Digits: req.Digits,
Period: req.Period,
})
if err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
log.Printf("CreateOTP failed for user %s: %v", userID, err)
return
}
// Log successful creation (mask secret in logs)
log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d",
userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period)
api.NewResponseWriter(w).WriteSuccess(otp)
}
// ListOTPs handles listing all OTPs for a user
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTPs
otps, err := h.otpService.ListOTPs(r.Context(), userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(otps)
}
// GetOTPCode handles generating OTP code
func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
otpID = strings.TrimSuffix(otpID, "/code")
// Validate OTP ID format
if len(otpID) != 36 { // Assuming UUID format
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid OTP ID format")
log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID)
return
}
// Rate limiting check could be added here
// (would require redis or similar rate limiter)
// Generate code
code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err)
return
}
// Log successful generation (without actual code)
log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)",
userID, otpID, time.Since(start), expiresIn)
api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{
"code": code,
"expires_in": expiresIn,
})
}
// VerifyOTPRequest represents a request to verify an OTP code
type VerifyOTPRequest struct {
Code string `json:"code"`
}
// VerifyOTP handles OTP code verification
func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
otpID = strings.TrimSuffix(otpID, "/verify")
// Parse request
var req VerifyOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
return
}
// Verify code
valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
"valid": valid,
})
}
// UpdateOTPRequest represents a request to update an OTP
type UpdateOTPRequest struct {
Name string `json:"name"`
Issuer string `json:"issuer"`
Algorithm string `json:"algorithm"`
Digits int `json:"digits"`
Period int `json:"period"`
}
// UpdateOTP handles OTP update
func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
// Parse request
var req UpdateOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
return
}
// Update OTP
otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{
Name: req.Name,
Issuer: req.Issuer,
Algorithm: req.Algorithm,
Digits: req.Digits,
Period: req.Period,
})
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(otp)
}
// DeleteOTP handles OTP deletion
func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
// Delete OTP
if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(map[string]string{
"message": "OTP deleted successfully",
})
}
// Routes returns all routes for the OTP handler
func (h *OTPHandler) Routes() map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/otp": h.CreateOTP,
"/otp/": h.ListOTPs,
"/otp/{id}": h.UpdateOTP,
"/otp/{id}/code": h.GetOTPCode,
"/otp/{id}/verify": h.VerifyOTP,
}
}

204
logger/logger.go Normal file
View file

@ -0,0 +1,204 @@
package logger
import (
"context"
"fmt"
"io"
"os"
"runtime"
"strings"
"time"
"github.com/google/uuid"
)
// Level represents a log level
type Level int
const (
// DEBUG level
DEBUG Level = iota
// INFO level
INFO
// WARN level
WARN
// ERROR level
ERROR
// FATAL level
FATAL
)
// String returns the string representation of the log level
func (l Level) String() string {
switch l {
case DEBUG:
return "DEBUG"
case INFO:
return "INFO"
case WARN:
return "WARN"
case ERROR:
return "ERROR"
case FATAL:
return "FATAL"
default:
return "UNKNOWN"
}
}
// Logger represents a logger
type Logger struct {
level Level
output io.Writer
}
// contextKey is a type for context keys
type contextKey string
// requestIDKey is the key for request ID in context
const requestIDKey = contextKey("request_id")
// New creates a new logger
func New(level Level, output io.Writer) *Logger {
if output == nil {
output = os.Stdout
}
return &Logger{
level: level,
output: output,
}
}
// WithLevel creates a new logger with the specified level
func (l *Logger) WithLevel(level Level) *Logger {
return &Logger{
level: level,
output: l.output,
}
}
// WithOutput creates a new logger with the specified output
func (l *Logger) WithOutput(output io.Writer) *Logger {
return &Logger{
level: l.level,
output: output,
}
}
// log logs a message with the specified level
func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) {
if level < l.level {
return
}
// Get request ID from context
requestID := getRequestID(ctx)
// Get caller information
_, file, line, ok := runtime.Caller(2)
if !ok {
file = "unknown"
line = 0
}
// Extract just the filename
if idx := strings.LastIndex(file, "/"); idx >= 0 {
file = file[idx+1:]
}
// Format message
message := fmt.Sprintf(format, args...)
// Format log entry
timestamp := time.Now().Format(time.RFC3339)
logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n",
timestamp, level.String(), file, line, requestID, message)
// Write log entry
_, _ = l.output.Write([]byte(logEntry))
// Exit if fatal
if level == FATAL {
os.Exit(1)
}
}
// Debug logs a debug message
func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, DEBUG, format, args...)
}
// Info logs an info message
func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, INFO, format, args...)
}
// Warn logs a warning message
func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, WARN, format, args...)
}
// Error logs an error message
func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, ERROR, format, args...)
}
// Fatal logs a fatal message and exits
func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) {
l.log(ctx, FATAL, format, args...)
}
// WithRequestID adds a request ID to the context
func WithRequestID(ctx context.Context) context.Context {
requestID := uuid.New().String()
return context.WithValue(ctx, requestIDKey, requestID)
}
// GetRequestID gets the request ID from the context
func GetRequestID(ctx context.Context) string {
return getRequestID(ctx)
}
// getRequestID gets the request ID from the context
func getRequestID(ctx context.Context) string {
if ctx == nil {
return "-"
}
requestID, ok := ctx.Value(requestIDKey).(string)
if !ok {
return "-"
}
return requestID
}
// Default logger
var defaultLogger = New(INFO, os.Stdout)
// SetDefaultLogger sets the default logger
func SetDefaultLogger(logger *Logger) {
defaultLogger = logger
}
// Debug logs a debug message using the default logger
func Debug(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Debug(ctx, format, args...)
}
// Info logs an info message using the default logger
func Info(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Info(ctx, format, args...)
}
// Warn logs a warning message using the default logger
func Warn(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Warn(ctx, format, args...)
}
// Error logs an error message using the default logger
func Error(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Error(ctx, format, args...)
}
// Fatal logs a fatal message and exits using the default logger
func Fatal(ctx context.Context, format string, args ...interface{}) {
defaultLogger.Fatal(ctx, format, args...)
}

193
metrics/metrics.go Normal file
View file

@ -0,0 +1,193 @@
package metrics
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
var (
// Default metrics
requestDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "Duration of HTTP requests in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path", "status"},
)
requestTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Total number of HTTP requests",
},
[]string{"method", "path", "status"},
)
otpGenerationTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "otp_generation_total",
Help: "Total number of OTP generations",
},
[]string{"user_id", "otp_id"},
)
otpVerificationTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "otp_verification_total",
Help: "Total number of OTP verifications",
},
[]string{"user_id", "otp_id", "success"},
)
activeUsers = prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "active_users",
Help: "Number of active users",
},
)
cacheHits = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_hits_total",
Help: "Total number of cache hits",
},
[]string{"cache"},
)
cacheMisses = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_misses_total",
Help: "Total number of cache misses",
},
[]string{"cache"},
)
)
func init() {
// Register metrics with prometheus
prometheus.MustRegister(
requestDuration,
requestTotal,
otpGenerationTotal,
otpVerificationTotal,
activeUsers,
cacheHits,
cacheMisses,
)
}
// MetricsService provides metrics functionality
type MetricsService struct {
activeUsersMutex sync.RWMutex
activeUserIDs map[string]bool
}
// NewMetricsService creates a new MetricsService
func NewMetricsService() *MetricsService {
return &MetricsService{
activeUserIDs: make(map[string]bool),
}
}
// Handler returns an HTTP handler for metrics
func (s *MetricsService) Handler() http.Handler {
return promhttp.Handler()
}
// RecordRequest records metrics for an HTTP request
func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) {
labels := prometheus.Labels{
"method": method,
"path": path,
"status": fmt.Sprintf("%d", status),
}
requestDuration.With(labels).Observe(duration.Seconds())
requestTotal.With(labels).Inc()
}
// RecordOTPGeneration records metrics for OTP generation
func (s *MetricsService) RecordOTPGeneration(userID, otpID string) {
otpGenerationTotal.With(prometheus.Labels{
"user_id": userID,
"otp_id": otpID,
}).Inc()
}
// RecordOTPVerification records metrics for OTP verification
func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) {
otpVerificationTotal.With(prometheus.Labels{
"user_id": userID,
"otp_id": otpID,
"success": fmt.Sprintf("%t", success),
}).Inc()
}
// RecordUserActivity records user activity
func (s *MetricsService) RecordUserActivity(userID string) {
s.activeUsersMutex.Lock()
defer s.activeUsersMutex.Unlock()
if !s.activeUserIDs[userID] {
s.activeUserIDs[userID] = true
activeUsers.Inc()
}
}
// RecordUserInactivity records user inactivity
func (s *MetricsService) RecordUserInactivity(userID string) {
s.activeUsersMutex.Lock()
defer s.activeUsersMutex.Unlock()
if s.activeUserIDs[userID] {
delete(s.activeUserIDs, userID)
activeUsers.Dec()
}
}
// RecordCacheHit records a cache hit
func (s *MetricsService) RecordCacheHit(cache string) {
cacheHits.With(prometheus.Labels{
"cache": cache,
}).Inc()
}
// RecordCacheMiss records a cache miss
func (s *MetricsService) RecordCacheMiss(cache string) {
cacheMisses.With(prometheus.Labels{
"cache": cache,
}).Inc()
}
// Middleware creates a middleware that records request metrics
func (s *MetricsService) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Create response writer that captures status code
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
// Call next handler
next.ServeHTTP(rw, r)
// Record metrics
s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start))
})
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
status int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}

353
middleware/middleware.go Normal file
View file

@ -0,0 +1,353 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/golang-jwt/jwt"
)
// Response represents a standard API response
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// ErrorResponse sends a JSON error response
func ErrorResponse(w http.ResponseWriter, code int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(Response{
Code: code,
Message: message,
})
}
// SuccessResponse sends a JSON success response
func SuccessResponse(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(Response{
Code: http.StatusOK,
Message: "success",
Data: data,
})
}
// Logger is a middleware that logs request details with structured format
func Logger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
r.Header.Set("X-Request-ID", requestID)
}
// Create a custom response writer to capture status code
rw := &responseWriter{
ResponseWriter: w,
status: http.StatusOK,
}
// Process request
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID)))
// Log structured request details
log.Printf(
"method=%s path=%s status=%d duration=%s ip=%s request_id=%s",
r.Method,
r.URL.Path,
rw.status,
time.Since(start).String(),
r.RemoteAddr,
requestID,
)
})
}
// generateRequestID creates a unique request identifier
func generateRequestID() string {
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
}
// randomString generates a random string of given length
func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return string(b)
}
// Recover is a middleware that recovers from panics with detailed logging
func Recover(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
// Get request ID from context
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Log error with stack trace and request context
log.Printf(
"panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s",
err,
requestID,
r.Method,
r.URL.Path,
r.RemoteAddr,
debug.Stack(),
)
// Determine error type
var message string
var status int
switch e := err.(type) {
case error:
message = e.Error()
if isClientError(e) {
status = http.StatusBadRequest
} else {
status = http.StatusInternalServerError
}
case string:
message = e
status = http.StatusInternalServerError
default:
message = "Internal Server Error"
status = http.StatusInternalServerError
}
ErrorResponse(w, status, message)
}
}()
next.ServeHTTP(w, r)
})
}
// isClientError checks if error should be treated as client error
func isClientError(err error) bool {
// Add more client error types as needed
return strings.Contains(err.Error(), "validation") ||
strings.Contains(err.Error(), "invalid") ||
strings.Contains(err.Error(), "missing")
}
// CORS is a middleware that handles CORS
func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
// Timeout is a middleware that safely handles request timeouts
func Timeout(duration time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), duration)
defer cancel()
// Use buffered channels to prevent goroutine leaks
done := make(chan struct{}, 1)
panicChan := make(chan interface{}, 1)
// Track request processing in goroutine
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
next.ServeHTTP(w, r.WithContext(ctx))
done <- struct{}{}
}()
// Wait for completion, timeout or panic
select {
case <-done:
return
case p := <-panicChan:
panic(p) // Re-throw panic to be caught by Recover middleware
case <-ctx.Done():
// Get request context for logging
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Log timeout details
log.Printf(
"request_timeout: request_id=%s method=%s path=%s timeout=%s",
requestID,
r.Method,
r.URL.Path,
duration.String(),
)
// Send timeout response
ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf(
"Request timed out after %s", duration.String(),
))
}
})
}
}
// Auth is a middleware that validates JWT tokens with enhanced security
func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get request ID for logging
requestID := ""
if ctx := r.Context(); ctx != nil {
if id, ok := ctx.Value("request_id").(string); ok {
requestID = id
}
}
// Get token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required")
return
}
// Validate header format
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer <token>'")
return
}
tokenString := parts[1]
// Parse and validate token
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Validate signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(jwtSecret), nil
})
if err != nil {
log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
return
}
if !token.Valid {
log.Printf("auth_failed: request_id=%s error=invalid_token", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
return
}
// Validate claims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims")
return
}
// Check required claims
userID, ok := claims["user_id"].(string)
if !ok || userID == "" {
log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token")
return
}
// Check token expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
log.Printf("auth_failed: request_id=%s error=token_expired", requestID)
ErrorResponse(w, http.StatusUnauthorized, "Token has expired")
return
}
}
// Check required roles if specified
if len(requiredRoles) > 0 {
roles, ok := claims["roles"].([]interface{})
if !ok {
log.Printf("auth_failed: request_id=%s error=missing_roles", requestID)
ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles")
return
}
hasRequiredRole := false
for _, requiredRole := range requiredRoles {
for _, role := range roles {
if r, ok := role.(string); ok && r == requiredRole {
hasRequiredRole = true
break
}
}
}
if !hasRequiredRole {
log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID)
ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions")
return
}
}
// Add claims to context
ctx := r.Context()
ctx = context.WithValue(ctx, "user_id", userID)
ctx = context.WithValue(ctx, "claims", claims)
log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// responseWriter is a custom response writer that captures the status code
type responseWriter struct {
http.ResponseWriter
status int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}
// GetUserID gets the user ID from the request context
func GetUserID(r *http.Request) (string, error) {
userID, ok := r.Context().Value("user_id").(string)
if !ok {
return "", fmt.Errorf("user ID not found in context")
}
return userID, nil
}

View file

@ -0,0 +1,50 @@
// app.js
App({
onLaunch() {
// 检查更新
if (wx.canIUse('getUpdateManager')) {
const updateManager = wx.getUpdateManager();
updateManager.onCheckForUpdate(function (res) {
if (res.hasUpdate) {
updateManager.onUpdateReady(function () {
wx.showModal({
title: '更新提示',
content: '新版本已经准备好,是否重启应用?',
success: function (res) {
if (res.confirm) {
updateManager.applyUpdate();
}
}
});
});
updateManager.onUpdateFailed(function () {
wx.showModal({
title: '更新提示',
content: '新版本下载失败,请检查网络后重试',
showCancel: false
});
});
}
});
}
// 获取系统信息
try {
const systemInfo = wx.getSystemInfoSync();
this.globalData.systemInfo = systemInfo;
// 计算安全区域
const { screenHeight, safeArea } = systemInfo;
this.globalData.safeAreaBottom = screenHeight - safeArea.bottom;
} catch (e) {
console.error('获取系统信息失败', e);
}
},
globalData: {
userInfo: null,
systemInfo: {},
safeAreaBottom: 0
}
});

View file

@ -0,0 +1,23 @@
{
"pages": [
"pages/login/login",
"pages/otp-list/index",
"pages/otp-add/index"
],
"window": {
"backgroundTextStyle": "light",
"navigationBarBackgroundColor": "#fff",
"navigationBarTitleText": "OTPM",
"navigationBarTextStyle": "black",
"backgroundColor": "#F8F8F8"
},
"permission": {
"scope.camera": {
"desc": "需要使用相机扫描二维码"
}
},
"usingComponents": {},
"style": "v2",
"sitemapLocation": "sitemap.json",
"lazyCodeLoading": "requiredComponents"
}

View file

@ -0,0 +1,238 @@
/**app.wxss**/
page {
--primary-color: #1890ff;
--danger-color: #ff4d4f;
--success-color: #52c41a;
--warning-color: #faad14;
--text-color: #333333;
--text-color-secondary: #666666;
--text-color-light: #999999;
--border-color: #e8e8e8;
--background-color: #f8f8f8;
--border-radius: 8rpx;
--safe-area-bottom: env(safe-area-inset-bottom);
font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica,
Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei',
sans-serif;
font-size: 28rpx;
line-height: 1.5;
color: var(--text-color);
background-color: var(--background-color);
}
/* 清除默认样式 */
button {
padding: 0;
margin: 0;
background: none;
border: none;
text-align: left;
line-height: inherit;
overflow: visible;
}
button::after {
border: none;
}
/* 通用样式类 */
.container {
min-height: 100vh;
box-sizing: border-box;
}
.safe-area-bottom {
padding-bottom: var(--safe-area-bottom);
}
.flex-center {
display: flex;
align-items: center;
justify-content: center;
}
.flex-between {
display: flex;
align-items: center;
justify-content: space-between;
}
.flex-column {
display: flex;
flex-direction: column;
}
.text-primary {
color: var(--primary-color);
}
.text-danger {
color: var(--danger-color);
}
.text-success {
color: var(--success-color);
}
.text-warning {
color: var(--warning-color);
}
.text-secondary {
color: var(--text-color-secondary);
}
.text-light {
color: var(--text-color-light);
}
.text-center {
text-align: center;
}
.text-left {
text-align: left;
}
.text-right {
text-align: right;
}
.text-bold {
font-weight: bold;
}
.text-ellipsis {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.bg-white {
background-color: #ffffff;
}
.bg-primary {
background-color: var(--primary-color);
}
.bg-danger {
background-color: var(--danger-color);
}
.bg-success {
background-color: var(--success-color);
}
.bg-warning {
background-color: var(--warning-color);
}
.shadow {
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.rounded {
border-radius: var(--border-radius);
}
.border {
border: 2rpx solid var(--border-color);
}
.border-top {
border-top: 2rpx solid var(--border-color);
}
.border-bottom {
border-bottom: 2rpx solid var(--border-color);
}
/* 动画类 */
.fade-in {
animation: fadeIn 0.3s ease-in-out;
}
.fade-out {
animation: fadeOut 0.3s ease-in-out;
}
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
@keyframes fadeOut {
from {
opacity: 1;
}
to {
opacity: 0;
}
}
/* 间距类 */
.m-0 { margin: 0; }
.m-1 { margin: 10rpx; }
.m-2 { margin: 20rpx; }
.m-3 { margin: 30rpx; }
.m-4 { margin: 40rpx; }
.mt-0 { margin-top: 0; }
.mt-1 { margin-top: 10rpx; }
.mt-2 { margin-top: 20rpx; }
.mt-3 { margin-top: 30rpx; }
.mt-4 { margin-top: 40rpx; }
.mb-0 { margin-bottom: 0; }
.mb-1 { margin-bottom: 10rpx; }
.mb-2 { margin-bottom: 20rpx; }
.mb-3 { margin-bottom: 30rpx; }
.mb-4 { margin-bottom: 40rpx; }
.ml-0 { margin-left: 0; }
.ml-1 { margin-left: 10rpx; }
.ml-2 { margin-left: 20rpx; }
.ml-3 { margin-left: 30rpx; }
.ml-4 { margin-left: 40rpx; }
.mr-0 { margin-right: 0; }
.mr-1 { margin-right: 10rpx; }
.mr-2 { margin-right: 20rpx; }
.mr-3 { margin-right: 30rpx; }
.mr-4 { margin-right: 40rpx; }
.p-0 { padding: 0; }
.p-1 { padding: 10rpx; }
.p-2 { padding: 20rpx; }
.p-3 { padding: 30rpx; }
.p-4 { padding: 40rpx; }
.pt-0 { padding-top: 0; }
.pt-1 { padding-top: 10rpx; }
.pt-2 { padding-top: 20rpx; }
.pt-3 { padding-top: 30rpx; }
.pt-4 { padding-top: 40rpx; }
.pb-0 { padding-bottom: 0; }
.pb-1 { padding-bottom: 10rpx; }
.pb-2 { padding-bottom: 20rpx; }
.pb-3 { padding-bottom: 30rpx; }
.pb-4 { padding-bottom: 40rpx; }
.pl-0 { padding-left: 0; }
.pl-1 { padding-left: 10rpx; }
.pl-2 { padding-left: 20rpx; }
.pl-3 { padding-left: 30rpx; }
.pl-4 { padding-left: 40rpx; }
.pr-0 { padding-right: 0; }
.pr-1 { padding-right: 10rpx; }
.pr-2 { padding-right: 20rpx; }
.pr-3 { padding-right: 30rpx; }
.pr-4 { padding-right: 40rpx; }

View file

@ -0,0 +1,48 @@
// login.js
import { wxLogin } from '../../services/auth';
Page({
data: {
loading: false
},
onLoad() {
// 页面加载时检查是否已经登录
const token = wx.getStorageSync('token');
if (token) {
this.redirectToHome();
}
},
// 处理登录按钮点击
handleLogin() {
if (this.data.loading) return;
this.setData({ loading: true });
wxLogin()
.then(() => {
wx.showToast({
title: '登录成功',
icon: 'success'
});
this.redirectToHome();
})
.catch(err => {
wx.showToast({
title: err.message || '登录失败',
icon: 'none'
});
})
.finally(() => {
this.setData({ loading: false });
});
},
// 跳转到首页
redirectToHome() {
wx.reLaunch({
url: '/pages/otp-list/index'
});
}
});

View file

@ -0,0 +1,3 @@
{
"usingComponents": {}
}

View file

@ -0,0 +1,30 @@
<!-- login.wxml -->
<view class="container">
<view class="logo-container">
<image class="logo" src="/assets/images/logo.png" mode="aspectFit"></image>
<text class="app-name">OTPM 小程序</text>
</view>
<view class="login-container">
<text class="login-title">欢迎使用 OTPM</text>
<text class="login-subtitle">一次性密码管理工具</text>
<button
class="login-button {{loading ? 'loading' : ''}}"
type="primary"
bindtap="handleLogin"
disabled="{{loading}}"
>
<text wx:if="{{!loading}}">微信一键登录</text>
<view wx:else class="loading-container">
<view class="loading-icon"></view>
<text>登录中...</text>
</view>
</button>
<view class="privacy-policy">
<text>登录即表示您同意</text>
<navigator url="/pages/privacy/index" class="policy-link">《隐私政策》</navigator>
</view>
</view>
</view>

View file

@ -0,0 +1,97 @@
/* login.wxss */
.container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: space-between;
height: 100vh;
padding: 60rpx 40rpx;
box-sizing: border-box;
background-color: #f8f8f8;
}
.logo-container {
display: flex;
flex-direction: column;
align-items: center;
margin-top: 80rpx;
}
.logo {
width: 180rpx;
height: 180rpx;
margin-bottom: 20rpx;
}
.app-name {
font-size: 36rpx;
font-weight: bold;
color: #333;
}
.login-container {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
margin-bottom: 100rpx;
}
.login-title {
font-size: 48rpx;
font-weight: bold;
color: #333;
margin-bottom: 20rpx;
}
.login-subtitle {
font-size: 28rpx;
color: #666;
margin-bottom: 80rpx;
}
.login-button {
width: 80%;
height: 88rpx;
border-radius: 44rpx;
font-size: 32rpx;
display: flex;
align-items: center;
justify-content: center;
}
.login-button.loading {
background-color: #8cc4ff;
}
.loading-container {
display: flex;
align-items: center;
justify-content: center;
}
.loading-icon {
width: 36rpx;
height: 36rpx;
margin-right: 10rpx;
border: 4rpx solid #ffffff;
border-radius: 50%;
border-top-color: transparent;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.privacy-policy {
margin-top: 40rpx;
font-size: 24rpx;
color: #999;
}
.policy-link {
color: #1890ff;
display: inline;
}

View file

@ -0,0 +1,169 @@
// otp-add/index.js
import { createOTP } from '../../services/otp';
Page({
data: {
form: {
name: '',
issuer: '',
secret: '',
algorithm: 'SHA1',
digits: 6,
period: 30
},
algorithms: ['SHA1', 'SHA256', 'SHA512'],
digitOptions: [6, 8],
periodOptions: [30, 60],
submitting: false,
scanMode: false
},
// 处理输入变化
handleInputChange(e) {
const { field } = e.currentTarget.dataset;
const { value } = e.detail;
this.setData({
[`form.${field}`]: value
});
},
// 处理选择器变化
handlePickerChange(e) {
const { field } = e.currentTarget.dataset;
const { value } = e.detail;
const options = this.data[`${field}Options`] || this.data[field];
const selectedValue = options[value];
this.setData({
[`form.${field}`]: selectedValue
});
},
// 扫描二维码
handleScanQRCode() {
this.setData({ scanMode: true });
wx.scanCode({
scanType: ['qrCode'],
success: (res) => {
try {
// 解析otpauth://协议的URL
const url = res.result;
if (url.startsWith('otpauth://totp/')) {
const parsedUrl = new URL(url);
const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠
// 解析路径中的issuer和name
let issuer = '';
let name = path;
if (path.includes(':')) {
const parts = path.split(':');
issuer = parts[0];
name = parts[1];
}
// 从查询参数中获取其他信息
const secret = parsedUrl.searchParams.get('secret') || '';
const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1';
const digits = parseInt(parsedUrl.searchParams.get('digits') || '6');
const period = parseInt(parsedUrl.searchParams.get('period') || '30');
// 如果查询参数中有issuer优先使用
if (parsedUrl.searchParams.get('issuer')) {
issuer = parsedUrl.searchParams.get('issuer');
}
this.setData({
form: {
name,
issuer,
secret,
algorithm,
digits,
period
}
});
wx.showToast({
title: '二维码解析成功',
icon: 'success'
});
} else {
wx.showToast({
title: '不支持的二维码格式',
icon: 'none'
});
}
} catch (err) {
wx.showToast({
title: '二维码解析失败',
icon: 'none'
});
}
},
fail: () => {
wx.showToast({
title: '扫描取消',
icon: 'none'
});
},
complete: () => {
this.setData({ scanMode: false });
}
});
},
// 提交表单
handleSubmit() {
const { form } = this.data;
// 表单验证
if (!form.name) {
wx.showToast({
title: '请输入名称',
icon: 'none'
});
return;
}
if (!form.secret) {
wx.showToast({
title: '请输入密钥',
icon: 'none'
});
return;
}
this.setData({ submitting: true });
createOTP(form)
.then(() => {
wx.showToast({
title: '添加成功',
icon: 'success'
});
// 返回上一页
setTimeout(() => {
wx.navigateBack();
}, 1500);
})
.catch(err => {
wx.showToast({
title: err.message || '添加失败',
icon: 'none'
});
})
.finally(() => {
this.setData({ submitting: false });
});
},
// 取消
handleCancel() {
wx.navigateBack();
}
});

View file

@ -0,0 +1,3 @@
{
"usingComponents": {}
}

View file

@ -0,0 +1,119 @@
<!-- otp-add/index.wxml -->
<view class="container">
<view class="header">
<text class="title">添加OTP</text>
</view>
<view class="form-container">
<view class="form-group">
<text class="form-label">名称 <text class="required">*</text></text>
<input
class="form-input"
placeholder="请输入OTP名称"
value="{{form.name}}"
bindinput="handleInputChange"
data-field="name"
/>
</view>
<view class="form-group">
<text class="form-label">发行方</text>
<input
class="form-input"
placeholder="请输入发行方名称"
value="{{form.issuer}}"
bindinput="handleInputChange"
data-field="issuer"
/>
</view>
<view class="form-group">
<text class="form-label">密钥 <text class="required">*</text></text>
<view class="secret-input-container">
<input
class="form-input"
placeholder="请输入密钥或扫描二维码"
value="{{form.secret}}"
bindinput="handleInputChange"
data-field="secret"
/>
<view class="scan-button" bindtap="handleScanQRCode" wx:if="{{!scanMode}}">
<text class="scan-icon">🔍</text>
</view>
<view class="scanning-indicator" wx:else>
<view class="scanning-spinner"></view>
</view>
</view>
</view>
<view class="form-group">
<text class="form-label">算法</text>
<picker
mode="selector"
range="{{algorithms}}"
value="{{algorithms.indexOf(form.algorithm)}}"
bindchange="handlePickerChange"
data-field="algorithm"
>
<view class="picker-view">
<text>{{form.algorithm}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
<view class="form-row">
<view class="form-group half">
<text class="form-label">位数</text>
<picker
mode="selector"
range="{{digitOptions}}"
value="{{digitOptions.indexOf(form.digits)}}"
bindchange="handlePickerChange"
data-field="digits"
>
<view class="picker-view">
<text>{{form.digits}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
<view class="form-group half">
<text class="form-label">周期(秒)</text>
<picker
mode="selector"
range="{{periodOptions}}"
value="{{periodOptions.indexOf(form.period)}}"
bindchange="handlePickerChange"
data-field="period"
>
<view class="picker-view">
<text>{{form.period}}</text>
<text class="picker-arrow">▼</text>
</view>
</picker>
</view>
</view>
</view>
<view class="button-group">
<button
class="cancel-button"
bindtap="handleCancel"
disabled="{{submitting}}"
>取消</button>
<button
class="submit-button {{submitting ? 'loading' : ''}}"
bindtap="handleSubmit"
disabled="{{submitting}}"
>
<text wx:if="{{!submitting}}">保存</text>
<view wx:else class="loading-container">
<view class="loading-icon"></view>
<text>保存中...</text>
</view>
</button>
</view>
</view>

View file

@ -0,0 +1,176 @@
/* otp-add/index.wxss */
.container {
min-height: 100vh;
background-color: #f8f8f8;
padding: 0 0 40rpx 0;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 40rpx 32rpx;
background-color: #ffffff;
position: sticky;
top: 0;
z-index: 100;
box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
}
.title {
font-size: 36rpx;
font-weight: bold;
color: #333333;
}
.form-container {
background-color: #ffffff;
padding: 32rpx;
margin: 32rpx;
border-radius: 16rpx;
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.form-group {
margin-bottom: 32rpx;
}
.form-row {
display: flex;
justify-content: space-between;
}
.form-group.half {
width: 48%;
}
.form-label {
display: block;
font-size: 28rpx;
color: #666666;
margin-bottom: 12rpx;
}
.required {
color: #ff4d4f;
}
.form-input {
width: 100%;
height: 80rpx;
background-color: #f5f5f5;
border-radius: 8rpx;
padding: 0 24rpx;
font-size: 28rpx;
color: #333333;
box-sizing: border-box;
}
.secret-input-container {
position: relative;
}
.scan-button {
position: absolute;
right: 20rpx;
top: 50%;
transform: translateY(-50%);
width: 60rpx;
height: 60rpx;
display: flex;
align-items: center;
justify-content: center;
}
.scan-icon {
font-size: 40rpx;
color: #1890ff;
}
.scanning-indicator {
position: absolute;
right: 20rpx;
top: 50%;
transform: translateY(-50%);
width: 40rpx;
height: 40rpx;
}
.scanning-spinner {
width: 40rpx;
height: 40rpx;
border: 4rpx solid #f3f3f3;
border-top: 4rpx solid #1890ff;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.picker-view {
width: 100%;
height: 80rpx;
background-color: #f5f5f5;
border-radius: 8rpx;
padding: 0 24rpx;
font-size: 28rpx;
color: #333333;
display: flex;
align-items: center;
justify-content: space-between;
box-sizing: border-box;
}
.picker-arrow {
font-size: 24rpx;
color: #999999;
}
.button-group {
display: flex;
justify-content: space-between;
padding: 32rpx;
}
.cancel-button, .submit-button {
width: 48%;
height: 88rpx;
border-radius: 44rpx;
font-size: 32rpx;
display: flex;
align-items: center;
justify-content: center;
}
.cancel-button {
background-color: #f5f5f5;
color: #666666;
}
.submit-button {
background-color: #1890ff;
color: #ffffff;
}
.submit-button.loading {
background-color: #8cc4ff;
}
.loading-container {
display: flex;
align-items: center;
justify-content: center;
}
.loading-icon {
width: 36rpx;
height: 36rpx;
margin-right: 10rpx;
border: 4rpx solid #ffffff;
border-radius: 50%;
border-top-color: transparent;
animation: spin 1s linear infinite;
}

View file

@ -0,0 +1,213 @@
// otp-list/index.js
import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp';
import { checkLoginStatus } from '../../services/auth';
Page({
data: {
otpList: [],
loading: true,
refreshing: false
},
onLoad() {
this.checkLogin();
},
onShow() {
// 每次页面显示时刷新OTP列表
if (!this.data.loading) {
this.fetchOTPList();
}
},
// 下拉刷新
onPullDownRefresh() {
this.setData({ refreshing: true });
this.fetchOTPList().finally(() => {
wx.stopPullDownRefresh();
this.setData({ refreshing: false });
});
},
// 检查登录状态
checkLogin() {
checkLoginStatus().then(isLoggedIn => {
if (isLoggedIn) {
this.fetchOTPList();
} else {
wx.redirectTo({
url: '/pages/login/login'
});
}
});
},
// 获取OTP列表
fetchOTPList() {
this.setData({ loading: true });
return getOTPList()
.then(res => {
if (res.data && Array.isArray(res.data)) {
this.setData({
otpList: res.data,
loading: false
});
// 获取每个OTP的当前验证码
this.refreshOTPCodes();
}
})
.catch(err => {
wx.showToast({
title: '获取OTP列表失败',
icon: 'none'
});
this.setData({ loading: false });
});
},
// 刷新所有OTP的验证码
refreshOTPCodes() {
const { otpList } = this.data;
// 为每个OTP获取当前验证码
const promises = otpList.map(otp => {
return getOTPCode(otp.id)
.then(res => {
if (res.data && res.data.code) {
return {
id: otp.id,
code: res.data.code,
expiresIn: res.data.expires_in || 30
};
}
return null;
})
.catch(() => null);
});
Promise.all(promises).then(results => {
const updatedList = [...this.data.otpList];
results.forEach(result => {
if (result) {
const index = updatedList.findIndex(otp => otp.id === result.id);
if (index !== -1) {
updatedList[index] = {
...updatedList[index],
currentCode: result.code,
expiresIn: result.expiresIn
};
}
}
});
this.setData({ otpList: updatedList });
// 设置定时器,每秒更新倒计时
this.startCountdown();
});
},
// 开始倒计时
startCountdown() {
// 清除之前的定时器
if (this.countdownTimer) {
clearInterval(this.countdownTimer);
}
// 创建新的定时器,每秒更新一次
this.countdownTimer = setInterval(() => {
const { otpList } = this.data;
let needRefresh = false;
const updatedList = otpList.map(otp => {
if (!otp.countdown) {
otp.countdown = otp.expiresIn || 30;
}
otp.countdown -= 1;
// 如果倒计时结束,标记需要刷新
if (otp.countdown <= 0) {
needRefresh = true;
}
return otp;
});
this.setData({ otpList: updatedList });
// 如果有OTP需要刷新重新获取验证码
if (needRefresh) {
this.refreshOTPCodes();
}
}, 1000);
},
// 添加新的OTP
handleAddOTP() {
wx.navigateTo({
url: '/pages/otp-add/index'
});
},
// 编辑OTP
handleEditOTP(e) {
const { id } = e.currentTarget.dataset;
wx.navigateTo({
url: `/pages/otp-edit/index?id=${id}`
});
},
// 删除OTP
handleDeleteOTP(e) {
const { id, name } = e.currentTarget.dataset;
wx.showModal({
title: '确认删除',
content: `确定要删除 ${name} 吗?`,
confirmColor: '#ff4d4f',
success: (res) => {
if (res.confirm) {
deleteOTP(id)
.then(() => {
wx.showToast({
title: '删除成功',
icon: 'success'
});
this.fetchOTPList();
})
.catch(err => {
wx.showToast({
title: '删除失败',
icon: 'none'
});
});
}
}
});
},
// 复制验证码
handleCopyCode(e) {
const { code } = e.currentTarget.dataset;
wx.setClipboardData({
data: code,
success: () => {
wx.showToast({
title: '验证码已复制',
icon: 'success'
});
}
});
},
onUnload() {
// 页面卸载时清除定时器
if (this.countdownTimer) {
clearInterval(this.countdownTimer);
}
}
});

View file

@ -0,0 +1,3 @@
{
"usingComponents": {}
}

View file

@ -0,0 +1,59 @@
<!-- otp-list/index.wxml -->
<view class="container">
<view class="header">
<text class="title">我的OTP列表</text>
<view class="add-button" bindtap="handleAddOTP">
<text class="add-icon">+</text>
</view>
</view>
<!-- 加载中 -->
<view class="loading-container" wx:if="{{loading}}">
<view class="loading-spinner"></view>
<text class="loading-text">加载中...</text>
</view>
<!-- OTP列表 -->
<view class="otp-list" wx:else>
<block wx:if="{{otpList.length > 0}}">
<view class="otp-item" wx:for="{{otpList}}" wx:key="id">
<view class="otp-info">
<view class="otp-name-row">
<text class="otp-name">{{item.name}}</text>
<text class="otp-issuer">{{item.issuer}}</text>
</view>
<view class="otp-code-row" bindtap="handleCopyCode" data-code="{{item.currentCode}}">
<text class="otp-code">{{item.currentCode || '******'}}</text>
<text class="copy-hint">点击复制</text>
</view>
<view class="otp-countdown">
<progress
percent="{{(item.countdown / item.expiresIn) * 100}}"
stroke-width="3"
activeColor="{{item.countdown < 10 ? '#ff4d4f' : '#1890ff'}}"
backgroundColor="#e9e9e9"
/>
<text class="countdown-text">{{item.countdown || 0}}s</text>
</view>
</view>
<view class="otp-actions">
<view class="action-button edit" bindtap="handleEditOTP" data-id="{{item.id}}">
<text class="action-icon">✎</text>
</view>
<view class="action-button delete" bindtap="handleDeleteOTP" data-id="{{item.id}}" data-name="{{item.name}}">
<text class="action-icon">✕</text>
</view>
</view>
</view>
</block>
<!-- 空状态 -->
<view class="empty-state" wx:else>
<image class="empty-image" src="/assets/images/empty.png" mode="aspectFit"></image>
<text class="empty-text">暂无OTP点击右上角添加</text>
</view>
</view>
</view>

View file

@ -0,0 +1,201 @@
/* otp-list/index.wxss */
.container {
min-height: 100vh;
background-color: #f8f8f8;
padding: 0 0 40rpx 0;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 40rpx 32rpx;
background-color: #ffffff;
position: sticky;
top: 0;
z-index: 100;
box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
}
.title {
font-size: 36rpx;
font-weight: bold;
color: #333333;
}
.add-button {
width: 64rpx;
height: 64rpx;
border-radius: 32rpx;
background-color: #1890ff;
display: flex;
align-items: center;
justify-content: center;
box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3);
}
.add-icon {
color: #ffffff;
font-size: 40rpx;
line-height: 1;
}
/* 加载状态 */
.loading-container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 120rpx 0;
}
.loading-spinner {
width: 64rpx;
height: 64rpx;
border: 6rpx solid #f3f3f3;
border-top: 6rpx solid #1890ff;
border-radius: 50%;
animation: spin 1s linear infinite;
}
.loading-text {
margin-top: 20rpx;
font-size: 28rpx;
color: #999999;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
/* OTP列表 */
.otp-list {
padding: 20rpx 32rpx;
}
.otp-item {
background-color: #ffffff;
border-radius: 16rpx;
padding: 32rpx;
margin-bottom: 20rpx;
display: flex;
justify-content: space-between;
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
}
.otp-info {
flex: 1;
margin-right: 20rpx;
}
.otp-name-row {
display: flex;
align-items: center;
margin-bottom: 16rpx;
}
.otp-name {
font-size: 32rpx;
font-weight: bold;
color: #333333;
margin-right: 16rpx;
}
.otp-issuer {
font-size: 24rpx;
color: #666666;
background-color: #f5f5f5;
padding: 4rpx 12rpx;
border-radius: 8rpx;
}
.otp-code-row {
display: flex;
align-items: center;
margin-bottom: 20rpx;
}
.otp-code {
font-size: 44rpx;
font-family: monospace;
font-weight: bold;
color: #1890ff;
letter-spacing: 4rpx;
margin-right: 16rpx;
}
.copy-hint {
font-size: 24rpx;
color: #999999;
}
.otp-countdown {
position: relative;
width: 100%;
}
.countdown-text {
position: absolute;
right: 0;
top: -30rpx;
font-size: 24rpx;
color: #999999;
}
/* OTP操作按钮 */
.otp-actions {
display: flex;
flex-direction: column;
justify-content: space-between;
}
.action-button {
width: 56rpx;
height: 56rpx;
border-radius: 28rpx;
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 16rpx;
}
.action-button.edit {
background-color: #f0f7ff;
}
.action-button.delete {
background-color: #fff1f0;
}
.action-icon {
font-size: 32rpx;
}
.edit .action-icon {
color: #1890ff;
}
.delete .action-icon {
color: #ff4d4f;
}
/* 空状态 */
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 120rpx 0;
}
.empty-image {
width: 240rpx;
height: 240rpx;
margin-bottom: 32rpx;
}
.empty-text {
font-size: 28rpx;
color: #999999;
}

View file

@ -0,0 +1,47 @@
{
"description": "项目配置文件",
"packOptions": {
"ignore": [],
"include": []
},
"miniprogramRoot": "",
"compileType": "miniprogram",
"projectname": "OTPM",
"setting": {
"useCompilerPlugins": [
"sass"
],
"babelSetting": {
"ignore": [],
"disablePlugins": [],
"outputPath": ""
},
"es6": true,
"enhance": true,
"minified": true,
"postcss": true,
"minifyWXSS": true,
"minifyWXML": true,
"uglifyFileName": true,
"packNpmManually": false,
"packNpmRelationList": [],
"ignoreUploadUnusedFiles": true,
"compileWorklet": false,
"uploadWithSourceMap": true,
"localPlugins": false,
"disableUseStrict": false,
"condition": false,
"swc": false,
"disableSWC": true
},
"simulatorType": "wechat",
"simulatorPluginLibVersion": {},
"condition": {},
"srcMiniprogramRoot": "",
"appid": "wxb6599459668b6b55",
"libVersion": "2.30.2",
"editorSetting": {
"tabIndent": "insertSpaces",
"tabSize": 2
}
}

View file

@ -0,0 +1,23 @@
{
"libVersion": "3.8.5",
"projectname": "OTPM",
"condition": {},
"setting": {
"urlCheck": true,
"coverView": false,
"lazyloadPlaceholderEnable": false,
"skylineRenderEnable": false,
"preloadBackgroundData": false,
"autoAudits": false,
"useApiHook": true,
"useApiHostProcess": true,
"showShadowRootInWxmlPanel": false,
"useStaticServer": false,
"useLanDebug": false,
"showES6CompileOption": false,
"compileHotReLoad": true,
"checkInvalidKey": true,
"ignoreDevUnusedFiles": true,
"bigPackageSizeSupport": false
}
}

View file

@ -0,0 +1,84 @@
// auth.js - 认证相关服务
import request from '../utils/request';
/**
* 微信登录
* 1. 调用wx.login获取code
* 2. 发送code到服务端换取token和openid
* 3. 保存token和openid到本地存储
*/
export const wxLogin = () => {
return new Promise((resolve, reject) => {
wx.login({
success: (res) => {
if (res.code) {
// 发送code到服务端
request({
url: '/login',
method: 'POST',
data: {
code: res.code
}
}).then(response => {
// 保存token和openid
if (response.data && response.data.token && response.data.openid) {
wx.setStorageSync('token', response.data.token);
wx.setStorageSync('openid', response.data.openid);
resolve(response.data);
} else {
reject(new Error('登录失败,服务器返回数据格式错误'));
}
}).catch(err => {
reject(err);
});
} else {
reject(new Error('登录失败获取code失败: ' + res.errMsg));
}
},
fail: (err) => {
reject(new Error('微信登录失败: ' + err.errMsg));
}
});
});
};
/**
* 检查登录状态
* 1. 检查本地是否有token和openid
* 2. 如果有验证token是否有效
* 3. 如果无效清除本地存储并返回false
*/
export const checkLoginStatus = () => {
return new Promise((resolve, reject) => {
const token = wx.getStorageSync('token');
const openid = wx.getStorageSync('openid');
if (!token || !openid) {
resolve(false);
return;
}
// 验证token有效性
request({
url: '/verify-token',
method: 'POST'
}).then(() => {
resolve(true);
}).catch(() => {
// token无效清除本地存储
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
resolve(false);
});
});
};
/**
* 退出登录
*/
export const logout = () => {
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
return Promise.resolve();
};

View file

@ -0,0 +1,119 @@
// otp.js - OTP相关服务
import request from '../utils/request';
/**
* 创建新的OTP
* @param {Object} params - 创建OTP的参数
* @param {string} params.name - OTP名称
* @param {string} params.issuer - 发行方
* @param {string} params.secret - 密钥
* @param {string} params.algorithm - 算法默认为SHA1
* @param {number} params.digits - 位数默认为6
* @param {number} params.period - 周期默认为30秒
* @returns {Promise} - 返回创建结果
*/
export const createOTP = (params) => {
if (!params || !params.secret) {
return Promise.reject(new Error('缺少必要的参数: secret'));
}
return request({
url: '/otp',
method: 'POST',
data: {
name: params.name || '',
issuer: params.issuer || '',
secret: params.secret,
algorithm: params.algorithm || 'SHA1',
digits: params.digits || 6,
period: params.period || 30
}
}).catch(err => {
console.error('创建OTP失败:', err);
throw new Error('创建OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 获取用户所有OTP列表
* @returns {Promise} - 返回OTP列表
*/
export const getOTPList = () => {
return request({
url: '/otp',
method: 'GET'
}).catch(err => {
console.error('获取OTP列表失败:', err);
throw new Error('获取OTP列表失败: ' + (err.message || '未知错误'));
});
};
/**
* 获取指定OTP的当前验证码
* @param {string} id - OTP的ID
* @returns {Promise} - 返回当前验证码
*/
export const getOTPCode = (id) => {
if (!id) {
return Promise.reject(new Error('缺少必要的参数: id'));
}
return request({
url: `/otp/${id}/code`,
method: 'GET'
}).catch(err => {
console.error('获取OTP代码失败:', err);
throw new Error('获取OTP代码失败: ' + (err.message || '未知错误'));
});
};
/**
* 验证OTP
* @param {string} id - OTP的ID
* @param {string} code - 用户输入的验证码
* @returns {Promise} - 返回验证结果
*/
export const verifyOTP = (id, code) => {
if (!id || !code) {
return Promise.reject(new Error('缺少必要的参数: id或code'));
}
return request({
url: `/otp/${id}/verify`,
method: 'POST',
data: { code }
}).catch(err => {
console.error('验证OTP失败:', err);
throw new Error('验证OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 更新OTP信息
* @param {string} id - OTP的ID
* @param {Object} params - 更新的参数
* @returns {Promise} - 返回更新结果
*/
export const updateOTP = (id, params) => {
if (!id || !params) {
return Promise.reject(new Error('缺少必要的参数: id或params'));
}
return request({
url: `/otp/${id}`,
method: 'PUT',
data: params
}).catch(err => {
console.error('更新OTP失败:', err);
throw new Error('更新OTP失败: ' + (err.message || '未知错误'));
});
};
/**
* 删除OTP
* @param {string} id - OTP的ID
* @returns {Promise} - 返回删除结果
*/
export const deleteOTP = (id) => {
return request({
url: `/otp/${id}`,
method: 'DELETE'
});
};

View file

@ -0,0 +1,58 @@
// request.js - 网络请求工具类
const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名
// 请求拦截器
const request = (options) => {
return new Promise((resolve, reject) => {
const token = wx.getStorageSync('token');
const header = {
'Content-Type': 'application/json',
...options.header
};
// 如果有token添加到请求头
if (token) {
header['Authorization'] = `Bearer ${token}`;
}
wx.request({
url: `${BASE_URL}${options.url}`,
method: options.method || 'GET',
data: options.data,
header: header,
success: (res) => {
// 处理业务错误
if (res.data.code !== 0) {
// token过期直接清除并跳转登录
if (res.statusCode === 401) {
wx.removeStorageSync('token');
wx.removeStorageSync('openid');
reject(new Error('登录已过期,请重新登录'));
return;
}
reject(new Error(res.data.message || '请求失败'));
return;
}
resolve(res.data);
},
fail: reject
});
});
};
// 刷新token
const refreshToken = () => {
return request({
url: '/refresh-token',
method: 'POST'
}).then(res => {
if (res.data && res.data.token) {
wx.setStorageSync('token', res.data.token);
return res.data.token;
}
throw new Error('Failed to refresh token');
});
};
export default request;

195
models/otp.go Normal file
View file

@ -0,0 +1,195 @@
package models
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/jmoiron/sqlx"
)
// OTP represents a TOTP configuration
type OTP struct {
ID string `db:"id" json:"id"`
UserID string `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Issuer string `db:"issuer" json:"issuer"`
Secret string `db:"secret" json:"-"` // Never expose secret in JSON
Algorithm string `db:"algorithm" json:"algorithm"`
Digits int `db:"digits" json:"digits"`
Period int `db:"period" json:"period"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// OTPParams represents common OTP parameters used in creation and update
type OTPParams struct {
Name string
Issuer string
Secret string
Algorithm string
Digits int
Period int
}
// OTPRepository handles OTP data operations
type OTPRepository struct {
db *sqlx.DB
}
// NewOTPRepository creates a new OTPRepository
func NewOTPRepository(db *sqlx.DB) *OTPRepository {
return &OTPRepository{db: db}
}
// FindByID finds an OTP by ID and user ID
func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
var otp OTP
query := `SELECT * FROM otps WHERE id = ? AND user_id = ?`
err := r.db.GetContext(ctx, &otp, query, id, userID)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("otp not found: %w", err)
}
return nil, fmt.Errorf("failed to find otp: %w", err)
}
return &otp, nil
}
// FindAllByUserID finds all OTPs for a user
func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
var otps []*OTP
query := `SELECT * FROM otps WHERE user_id = ? ORDER BY created_at DESC`
err := r.db.SelectContext(ctx, &otps, query, userID)
if err != nil {
return nil, fmt.Errorf("failed to find otps: %w", err)
}
return otps, nil
}
// Create creates a new OTP
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
query := `
INSERT INTO otps (id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
now := time.Now()
otp.CreatedAt = now
otp.UpdatedAt = now
_, err := r.db.ExecContext(
ctx,
query,
otp.ID,
otp.UserID,
otp.Name,
otp.Issuer,
otp.Secret,
otp.Algorithm,
otp.Digits,
otp.Period,
otp.CreatedAt,
otp.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create otp: %w", err)
}
return nil
}
// Update updates an existing OTP
func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
query := `
UPDATE otps
SET name = ?, issuer = ?, algorithm = ?, digits = ?, period = ?, updated_at = ?
WHERE id = ? AND user_id = ?
`
otp.UpdatedAt = time.Now()
result, err := r.db.ExecContext(
ctx,
query,
otp.Name,
otp.Issuer,
otp.Algorithm,
otp.Digits,
otp.Period,
otp.UpdatedAt,
otp.ID,
otp.UserID,
)
if err != nil {
return fmt.Errorf("failed to update otp: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return fmt.Errorf("otp not found or not owned by user")
}
return nil
}
// Delete deletes an OTP
func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
query := `DELETE FROM otps WHERE id = ? AND user_id = ?`
result, err := r.db.ExecContext(ctx, query, id, userID)
if err != nil {
return fmt.Errorf("failed to delete otp: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return fmt.Errorf("otp not found or not owned by user")
}
return nil
}
// CountByUserID counts the number of OTPs for a user
func (r *OTPRepository) CountByUserID(ctx context.Context, userID string) (int, error) {
var count int
query := `SELECT COUNT(*) FROM otps WHERE user_id = ?`
err := r.db.GetContext(ctx, &count, query, userID)
if err != nil {
return 0, fmt.Errorf("failed to count otps: %w", err)
}
return count, nil
}
// Transaction executes a function within a transaction
func (r *OTPRepository) Transaction(ctx context.Context, fn func(*sqlx.Tx) error) error {
tx, err := r.db.BeginTxx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("tx failed: %v, rollback failed: %v", err, rbErr)
}
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}

114
models/user.go Normal file
View file

@ -0,0 +1,114 @@
package models
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/jmoiron/sqlx"
)
// User represents a user in the system
type User struct {
ID string `db:"id" json:"id"`
OpenID string `db:"openid" json:"openid"`
SessionKey string `db:"session_key" json:"-"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// UserRepository handles user data operations
type UserRepository struct {
db *sqlx.DB
}
// NewUserRepository creates a new UserRepository
func NewUserRepository(db *sqlx.DB) *UserRepository {
return &UserRepository{db: db}
}
// FindByID finds a user by ID
func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) {
var user User
query := `SELECT * FROM users WHERE id = ?`
err := r.db.GetContext(ctx, &user, query, id)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found: %w", err)
}
return nil, fmt.Errorf("failed to find user: %w", err)
}
return &user, nil
}
// FindByOpenID finds a user by OpenID
func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) {
var user User
query := `SELECT * FROM users WHERE openid = ?`
err := r.db.GetContext(ctx, &user, query, openID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil // User not found, but not an error
}
return nil, fmt.Errorf("failed to find user: %w", err)
}
return &user, nil
}
// Create creates a new user
func (r *UserRepository) Create(ctx context.Context, user *User) error {
query := `
INSERT INTO users (id, openid, session_key, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
`
now := time.Now()
user.CreatedAt = now
user.UpdatedAt = now
_, err := r.db.ExecContext(
ctx,
query,
user.ID,
user.OpenID,
user.SessionKey,
user.CreatedAt,
user.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// Update updates an existing user
func (r *UserRepository) Update(ctx context.Context, user *User) error {
query := `
UPDATE users
SET session_key = ?, updated_at = ?
WHERE id = ?
`
user.UpdatedAt = time.Now()
_, err := r.db.ExecContext(
ctx,
query,
user.SessionKey,
user.UpdatedAt,
user.ID,
)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// Delete deletes a user
func (r *UserRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM users WHERE id = ?`
_, err := r.db.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}

332
security/security.go Normal file
View file

@ -0,0 +1,332 @@
package security
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"net/http"
"strings"
"time"
"golang.org/x/crypto/argon2"
)
// SecurityService provides security functionality
type SecurityService struct {
config *Config
}
// Config represents security configuration
type Config struct {
// CSRF protection
CSRFTokenLength int
CSRFTokenExpiry time.Duration
CSRFCookieName string
CSRFHeaderName string
CSRFCookieSecure bool
CSRFCookieHTTPOnly bool
CSRFCookieSameSite http.SameSite
// Rate limiting
RateLimitRequests int
RateLimitWindow time.Duration
// Password hashing
Argon2Time uint32
Argon2Memory uint32
Argon2Threads uint8
Argon2KeyLen uint32
Argon2SaltLen uint32
}
// DefaultConfig returns the default security configuration
func DefaultConfig() *Config {
return &Config{
// CSRF protection
CSRFTokenLength: 32,
CSRFTokenExpiry: 24 * time.Hour,
CSRFCookieName: "csrf_token",
CSRFHeaderName: "X-CSRF-Token",
CSRFCookieSecure: true,
CSRFCookieHTTPOnly: true,
CSRFCookieSameSite: http.SameSiteStrictMode,
// Rate limiting
RateLimitRequests: 100,
RateLimitWindow: time.Minute,
// Password hashing
Argon2Time: 1,
Argon2Memory: 64 * 1024,
Argon2Threads: 4,
Argon2KeyLen: 32,
Argon2SaltLen: 16,
}
}
// NewSecurityService creates a new SecurityService
func NewSecurityService(config *Config) *SecurityService {
if config == nil {
config = DefaultConfig()
}
return &SecurityService{
config: config,
}
}
// GenerateCSRFToken generates a CSRF token
func (s *SecurityService) GenerateCSRFToken() (string, error) {
// Generate random bytes
bytes := make([]byte, s.config.CSRFTokenLength)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
// Encode as base64
token := base64.StdEncoding.EncodeToString(bytes)
return token, nil
}
// SetCSRFCookie sets a CSRF cookie
func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
http.SetCookie(w, &http.Cookie{
Name: s.config.CSRFCookieName,
Value: token,
Path: "/",
Expires: time.Now().Add(s.config.CSRFTokenExpiry),
Secure: s.config.CSRFCookieSecure,
HttpOnly: s.config.CSRFCookieHTTPOnly,
SameSite: s.config.CSRFCookieSameSite,
})
}
// ValidateCSRFToken validates a CSRF token
func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
// Get token from cookie
cookie, err := r.Cookie(s.config.CSRFCookieName)
if err != nil {
return false
}
cookieToken := cookie.Value
// Get token from header
headerToken := r.Header.Get(s.config.CSRFHeaderName)
// Compare tokens
return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1
}
// CSRFMiddleware creates a middleware that validates CSRF tokens
func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip for GET, HEAD, OPTIONS, TRACE
if r.Method == http.MethodGet ||
r.Method == http.MethodHead ||
r.Method == http.MethodOptions ||
r.Method == http.MethodTrace {
next.ServeHTTP(w, r)
return
}
// Validate CSRF token
if !s.ValidateCSRFToken(r) {
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
// RateLimiter represents a rate limiter
type RateLimiter struct {
requests map[string][]time.Time
config *Config
}
// NewRateLimiter creates a new RateLimiter
func NewRateLimiter(config *Config) *RateLimiter {
return &RateLimiter{
requests: make(map[string][]time.Time),
config: config,
}
}
// Allow checks if a request is allowed
func (r *RateLimiter) Allow(key string) bool {
now := time.Now()
windowStart := now.Add(-r.config.RateLimitWindow)
// Get requests for key
requests := r.requests[key]
// Filter out old requests
var newRequests []time.Time
for _, t := range requests {
if t.After(windowStart) {
newRequests = append(newRequests, t)
}
}
// Check if rate limit is exceeded
if len(newRequests) >= r.config.RateLimitRequests {
return false
}
// Add current request
newRequests = append(newRequests, now)
r.requests[key] = newRequests
return true
}
// RateLimitMiddleware creates a middleware that limits request rate
func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get client IP
ip := getClientIP(r)
// Check if request is allowed
if !limiter.Allow(ip) {
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// getClientIP gets the client IP address
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header
xForwardedFor := r.Header.Get("X-Forwarded-For")
if xForwardedFor != "" {
// X-Forwarded-For can contain multiple IPs, use the first one
ips := strings.Split(xForwardedFor, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
xRealIP := r.Header.Get("X-Real-IP")
if xRealIP != "" {
return xRealIP
}
// Use RemoteAddr
return r.RemoteAddr
}
// HashPassword hashes a password using Argon2
func (s *SecurityService) HashPassword(password string) (string, error) {
// Generate salt
salt := make([]byte, s.config.Argon2SaltLen)
if _, err := rand.Read(salt); err != nil {
return "", fmt.Errorf("failed to generate salt: %w", err)
}
// Hash password
hash := argon2.IDKey(
[]byte(password),
salt,
s.config.Argon2Time,
s.config.Argon2Memory,
s.config.Argon2Threads,
s.config.Argon2KeyLen,
)
// Encode as base64
saltBase64 := base64.StdEncoding.EncodeToString(salt)
hashBase64 := base64.StdEncoding.EncodeToString(hash)
// Format as $argon2id$v=19$m=65536,t=1,p=4$<salt>$<hash>
return fmt.Sprintf(
"$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
s.config.Argon2Memory,
s.config.Argon2Time,
s.config.Argon2Threads,
saltBase64,
hashBase64,
), nil
}
// VerifyPassword verifies a password against a hash
func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
// Parse encoded hash
parts := strings.Split(encodedHash, "$")
if len(parts) != 6 {
return false, fmt.Errorf("invalid hash format")
}
// Extract parameters
if parts[1] != "argon2id" {
return false, fmt.Errorf("unsupported hash algorithm")
}
var memory uint32
var time uint32
var threads uint8
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
if err != nil {
return false, fmt.Errorf("failed to parse hash parameters: %w", err)
}
// Decode salt and hash
salt, err := base64.StdEncoding.DecodeString(parts[4])
if err != nil {
return false, fmt.Errorf("failed to decode salt: %w", err)
}
hash, err := base64.StdEncoding.DecodeString(parts[5])
if err != nil {
return false, fmt.Errorf("failed to decode hash: %w", err)
}
// Hash password with same parameters
newHash := argon2.IDKey(
[]byte(password),
salt,
time,
memory,
threads,
uint32(len(hash)),
)
// Compare hashes
return subtle.ConstantTimeCompare(hash, newHash) == 1, nil
}
// SecureHeadersMiddleware adds security headers to responses
func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add security headers
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Content-Security-Policy", "default-src 'self'")
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
next.ServeHTTP(w, r)
})
}
// contextKey is a type for context keys
type contextKey string
// userIDKey is the key for user ID in context
const userIDKey = contextKey("user_id")
// WithUserID adds a user ID to the context
func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, userIDKey, userID)
}
// GetUserID gets the user ID from the context
func GetUserID(ctx context.Context) (string, bool) {
userID, ok := ctx.Value(userIDKey).(string)
return userID, ok
}

172
server/server.go Normal file
View file

@ -0,0 +1,172 @@
package server
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"otpm/config"
"otpm/middleware"
)
// Server represents the HTTP server
type Server struct {
server *http.Server
router *http.ServeMux
config *config.Config
}
// New creates a new server
func New(cfg *config.Config) *Server {
router := http.NewServeMux()
server := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
Handler: router,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: 120 * time.Second,
}
return &Server{
server: server,
router: router,
config: cfg,
}
}
// Start starts the server
func (s *Server) Start() error {
// Apply global middleware in correct order with enhanced error handling
var handler http.Handler = s.router
// Logger should be first to capture all request details
handler = middleware.Logger(handler)
// CORS next to handle pre-flight requests
handler = middleware.CORS(handler)
// Then Timeout to enforce request deadlines
handler = middleware.Timeout(s.config.Server.Timeout)(handler)
// Recover should be outermost to catch any panics
handler = middleware.Recover(handler)
s.server.Handler = handler
// Log server configuration at startup
log.Printf("Server configuration:\n"+
"Address: %s\n"+
"Read Timeout: %v\n"+
"Write Timeout: %v\n"+
"Idle Timeout: %v\n"+
"Request Timeout: %v",
s.server.Addr,
s.server.ReadTimeout,
s.server.WriteTimeout,
s.server.IdleTimeout,
s.config.Server.Timeout,
)
// Start server in a goroutine
serverErr := make(chan error, 1)
go func() {
log.Printf("Server starting on %s", s.server.Addr)
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
serverErr <- fmt.Errorf("server error: %w", err)
}
}()
// Wait for interrupt signal or server error
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
select {
case err := <-serverErr:
return err
case <-quit:
return s.Shutdown()
}
}
// Shutdown gracefully stops the server
func (s *Server) Shutdown() error {
log.Println("Shutting down server...")
ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout)
defer cancel()
if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("graceful shutdown failed: %w", err)
}
log.Println("Server stopped gracefully")
return nil
}
// Router returns the router
func (s *Server) Router() *http.ServeMux {
return s.router
}
// RegisterRoutes registers all routes
func (s *Server) RegisterRoutes(routes map[string]http.Handler) {
for pattern, handler := range routes {
s.router.Handle(pattern, handler)
}
}
// RegisterAuthRoutes registers routes that require authentication
func (s *Server) RegisterAuthRoutes(routes map[string]http.Handler) {
for pattern, handler := range routes {
// Apply authentication middleware
authHandler := middleware.Auth(s.config.JWT.Secret)(handler)
s.router.Handle(pattern, authHandler)
}
}
// RegisterHealthCheck registers an enhanced health check endpoint
func (s *Server) RegisterHealthCheck() {
s.router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"status": "ok",
"timestamp": time.Now().Format(time.RFC3339),
"version": "1.0.0", // Hardcoded version instead of from config
"system": map[string]interface{}{
"goroutines": runtime.NumGoroutine(),
"memory": getMemoryUsage(),
},
}
// Add database status if configured
if s.config.Database.DSN != "" { // Changed from URL to DSN to match config
dbStatus := "ok"
// Removed DB ping check since we don't have DB instance in config
response["database"] = dbStatus
}
middleware.SuccessResponse(w, response)
})
}
// getMemoryUsage returns current memory usage in MB
func getMemoryUsage() map[string]interface{} {
var m runtime.MemStats
runtime.ReadMemStats(&m)
return map[string]interface{}{
"alloc_mb": bToMb(m.Alloc),
"total_alloc_mb": bToMb(m.TotalAlloc),
"sys_mb": bToMb(m.Sys),
"num_gc": m.NumGC,
}
}
func bToMb(b uint64) float64 {
return float64(b) / 1024 / 1024
}

230
services/auth.go Normal file
View file

@ -0,0 +1,230 @@
package services
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"otpm/config"
"otpm/models"
)
// WeChatCode2SessionResponse represents the response from WeChat code2session API
type WeChatCode2SessionResponse struct {
OpenID string `json:"openid"`
SessionKey string `json:"session_key"`
UnionID string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
// AuthService handles authentication related operations
type AuthService struct {
config *config.Config
userRepo *models.UserRepository
httpClient *http.Client
}
// NewAuthService creates a new AuthService
func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
return &AuthService{
config: cfg,
userRepo: userRepo,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// LoginWithWeChatCode handles WeChat login
func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
start := time.Now()
// Get OpenID and SessionKey from WeChat
sessionInfo, err := s.getWeChatSession(code)
if err != nil {
log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
return "", fmt.Errorf("failed to get WeChat session: %w", err)
}
log.Printf("WeChat session obtained for code %s (took %v)",
maskCode(code), time.Since(start))
// Find or create user
user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
if err != nil {
log.Printf("User lookup failed for OpenID %s: %v",
maskOpenID(sessionInfo.OpenID), err)
return "", fmt.Errorf("failed to find user: %w", err)
}
if user == nil {
// Create new user
user = &models.User{
ID: uuid.New().String(),
OpenID: sessionInfo.OpenID,
SessionKey: sessionInfo.SessionKey,
}
if err := s.userRepo.Create(ctx, user); err != nil {
log.Printf("User creation failed for OpenID %s: %v",
maskOpenID(sessionInfo.OpenID), err)
return "", fmt.Errorf("failed to create user: %w", err)
}
log.Printf("New user created with ID %s for OpenID %s",
user.ID, maskOpenID(sessionInfo.OpenID))
} else {
// Update session key
user.SessionKey = sessionInfo.SessionKey
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("User update failed for ID %s: %v", user.ID, err)
return "", fmt.Errorf("failed to update user: %w", err)
}
log.Printf("User %s session key updated", user.ID)
}
// Generate JWT token
token, err := s.generateToken(user)
if err != nil {
log.Printf("Token generation failed for user %s: %v", user.ID, err)
return "", fmt.Errorf("failed to generate token: %w", err)
}
log.Printf("WeChat login completed for user %s (total time %v)",
user.ID, time.Since(start))
return token, nil
}
// maskCode masks sensitive parts of WeChat code for logging
func maskCode(code string) string {
if len(code) < 8 {
return "****"
}
return code[:2] + "****" + code[len(code)-2:]
}
// maskOpenID masks sensitive parts of OpenID for logging
func maskOpenID(openID string) string {
if len(openID) < 8 {
return "****"
}
return openID[:2] + "****" + openID[len(openID)-2:]
}
// getWeChatSession calls WeChat's code2session API
func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
url := fmt.Sprintf(
"https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
s.config.WeChat.AppID,
s.config.WeChat.AppSecret,
code,
)
resp, err := s.httpClient.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to call WeChat API: %w", err)
}
defer resp.Body.Close()
var result WeChatCode2SessionResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
}
if result.ErrCode != 0 {
return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
}
return &result, nil
}
// generateToken generates a JWT token for a user
func (s *AuthService) generateToken(user *models.User) (string, error) {
now := time.Now()
claims := jwt.MapClaims{
"user_id": user.ID,
"exp": now.Add(s.config.JWT.ExpireDelta).Unix(),
"iat": now.Unix(),
"iss": s.config.JWT.Issuer,
"aud": s.config.JWT.Audience,
"token_id": uuid.New().String(), // Unique token ID for tracking
}
// Use stronger signing method
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
log.Printf("Token generated for user %s (expires at %v)",
user.ID, now.Add(s.config.JWT.ExpireDelta))
return signedToken, nil
}
// ValidateToken validates a JWT token with additional checks
func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Verify signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.config.JWT.Secret), nil
})
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
switch {
case ve.Errors&jwt.ValidationErrorMalformed != 0:
return nil, fmt.Errorf("malformed token")
case ve.Errors&jwt.ValidationErrorExpired != 0:
return nil, fmt.Errorf("token expired")
case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
return nil, fmt.Errorf("token not active yet")
default:
return nil, fmt.Errorf("token validation error: %w", err)
}
}
return nil, fmt.Errorf("failed to parse token: %w", err)
}
// Additional claims validation
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// Check issuer
if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
return nil, fmt.Errorf("invalid token issuer")
}
// Check audience
if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
return nil, fmt.Errorf("invalid token audience")
}
} else {
return nil, fmt.Errorf("invalid token claims")
}
return token, nil
}
// GetUserFromToken gets user information from a JWT token
func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid token claims")
}
userID, ok := claims["user_id"].(string)
if !ok {
return nil, fmt.Errorf("user_id not found in token")
}
user, err := s.userRepo.FindByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to find user: %w", err)
}
return user, nil
}

358
services/otp.go Normal file
View file

@ -0,0 +1,358 @@
package services
import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base32"
"encoding/binary"
"fmt"
"hash"
"log"
"strings"
"time"
"otpm/models"
"github.com/google/uuid"
)
// OTPService handles OTP related operations
type OTPService struct {
otpRepo *models.OTPRepository
}
// NewOTPService creates a new OTPService
func NewOTPService(otpRepo *models.OTPRepository) *OTPService {
return &OTPService{
otpRepo: otpRepo,
}
}
// CreateOTP creates a new OTP with performance monitoring and logging
func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) {
start := time.Now()
// Validate input
if err := s.validateOTPInput(input); err != nil {
log.Printf("OTP validation failed for user %s: %v", userID, err)
return nil, err
}
// Clean and standardize secret
secret := cleanSecret(input.Secret)
// Set defaults for optional fields
algorithm := strings.ToUpper(input.Algorithm)
if algorithm == "" {
algorithm = "SHA1"
}
digits := input.Digits
if digits == 0 {
digits = 6
}
period := input.Period
if period == 0 {
period = 30
}
// Create OTP
otp := &models.OTP{
ID: uuid.New().String(),
UserID: userID,
Name: input.Name,
Issuer: input.Issuer,
Secret: secret,
Algorithm: algorithm,
Digits: digits,
Period: period,
}
if err := s.otpRepo.Create(ctx, otp); err != nil {
log.Printf("Failed to create OTP for user %s: %v", userID, err)
return nil, fmt.Errorf("failed to create OTP: %w", err)
}
// Log successful creation (without exposing secret)
log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)",
otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period)
return otp, nil
}
// GetOTP gets an OTP by ID
func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) {
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
return nil, fmt.Errorf("failed to get OTP: %w", err)
}
return otp, nil
}
// ListOTPs lists all OTPs for a user
func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) {
otps, err := s.otpRepo.FindAllByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to list OTPs: %w", err)
}
return otps, nil
}
// UpdateOTP updates an OTP
func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) {
// Get existing OTP
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
return nil, fmt.Errorf("failed to get OTP: %w", err)
}
// Update fields
if input.Name != "" {
otp.Name = input.Name
}
if input.Issuer != "" {
otp.Issuer = input.Issuer
}
if input.Algorithm != "" {
otp.Algorithm = strings.ToUpper(input.Algorithm)
}
if input.Digits > 0 {
otp.Digits = input.Digits
}
if input.Period > 0 {
otp.Period = input.Period
}
// Validate updated OTP
if err := s.validateOTPInput(models.OTPParams{
Name: otp.Name,
Issuer: otp.Issuer,
Secret: otp.Secret,
Algorithm: otp.Algorithm,
Digits: otp.Digits,
Period: otp.Period,
}); err != nil {
return nil, err
}
if err := s.otpRepo.Update(ctx, otp); err != nil {
return nil, fmt.Errorf("failed to update OTP: %w", err)
}
return otp, nil
}
// DeleteOTP deletes an OTP
func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error {
if err := s.otpRepo.Delete(ctx, id, userID); err != nil {
return fmt.Errorf("failed to delete OTP: %w", err)
}
return nil
}
// GenerateCode generates a TOTP code with enhanced logging and error handling
func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) {
start := time.Now()
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err)
return "", 0, fmt.Errorf("failed to get OTP: %w", err)
}
// Get current time step
now := time.Now().Unix()
timeStep := now / int64(otp.Period)
// Generate code
code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits)
if err != nil {
log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err)
return "", 0, fmt.Errorf("failed to generate code: %w", err)
}
// Calculate remaining seconds
remainingSeconds := otp.Period - int(now%int64(otp.Period))
// Log successful generation (without actual code)
log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)",
id, userID, time.Since(start), remainingSeconds)
return code, remainingSeconds, nil
}
// VerifyCode verifies a TOTP code with enhanced security and logging
func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) {
start := time.Now()
// Basic input validation
if len(code) == 0 {
log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID)
return false, fmt.Errorf("code is required")
}
otp, err := s.otpRepo.FindByID(ctx, id, userID)
if err != nil {
log.Printf("Failed to find OTP %s for user %s during verification: %v",
id, userID, err)
return false, fmt.Errorf("failed to get OTP: %w", err)
}
// Get current and adjacent time steps
now := time.Now().Unix()
timeSteps := []int64{
(now - int64(otp.Period)) / int64(otp.Period),
now / int64(otp.Period),
(now + int64(otp.Period)) / int64(otp.Period),
}
// Check code against all time steps
for _, ts := range timeSteps {
expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits)
if err != nil {
log.Printf("Code generation failed for time step %d: %v", ts, err)
continue
}
if expectedCode == code {
// Log successful verification
log.Printf("Code verified successfully for OTP %s (user %s) in %v",
id, userID, time.Since(start))
return true, nil
}
}
// Log failed verification attempt
log.Printf("Invalid code provided for OTP %s (user %s) in %v",
id, userID, time.Since(start))
return false, nil
}
// validateOTPInput validates OTP input with detailed error messages
func (s *OTPService) validateOTPInput(input models.OTPParams) error {
if input.Name == "" {
return fmt.Errorf("name is required")
}
if len(input.Name) > 100 {
return fmt.Errorf("name is too long (maximum 100 characters)")
}
if input.Secret == "" {
return fmt.Errorf("secret is required")
}
if !isValidBase32(input.Secret) {
return fmt.Errorf("invalid secret format: must be a valid base32 string")
}
// Secret length check (after base32 decoding)
secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "="))
if len(secretBytes) < 10 {
return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)")
}
if input.Algorithm != "" {
if !isValidAlgorithm(input.Algorithm) {
return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm)
}
}
if input.Digits != 0 {
if input.Digits < 6 || input.Digits > 8 {
return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits)
}
}
if input.Period != 0 {
if input.Period < 30 || input.Period > 60 {
return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period)
}
}
return nil
}
// Helper functions
func cleanSecret(secret string) string {
// Remove spaces and convert to upper case
secret = strings.TrimSpace(strings.ToUpper(secret))
// Remove any padding characters
return strings.TrimRight(secret, "=")
}
func isValidBase32(s string) bool {
// Try to decode the secret
_, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "="))
return err == nil
}
func isValidAlgorithm(algorithm string) bool {
switch strings.ToUpper(algorithm) {
case "SHA1", "SHA256", "SHA512":
return true
default:
return false
}
}
func getHasher(algorithm string, key []byte) (hash.Hash, error) {
switch strings.ToUpper(algorithm) {
case "SHA1":
return hmac.New(sha1.New, key), nil
case "SHA256":
return hmac.New(sha256.New, key), nil
case "SHA512":
return hmac.New(sha512.New, key), nil
default:
return nil, fmt.Errorf("unsupported algorithm: %s", algorithm)
}
}
func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) {
// Decode secret
secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "="))
if err != nil {
return "", fmt.Errorf("invalid secret: %w", err)
}
// Get initialized HMAC hasher with secret
hasher, err := getHasher(algorithm, secretBytes)
if err != nil {
return "", err
}
// Convert time step to bytes
timeBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timeBytes, uint64(timeStep))
// Calculate HMAC
hasher.Write(timeBytes)
hash := hasher.Sum(nil)
// Get offset
offset := hash[len(hash)-1] & 0xf
// Generate 4-byte code
code := binary.BigEndian.Uint32(hash[offset : offset+4])
code = code & 0x7fffffff
// Get the specified number of digits
code = code % uint32(pow10(digits))
// Format code with leading zeros
return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil
}
func pow10(n int) uint32 {
result := uint32(1)
for i := 0; i < n; i++ {
result *= 10
}
return result
}

View file

@ -1,20 +1,65 @@
package utils
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
"net/http"
"strings"
"github.com/golang-jwt/jwt"
"github.com/julienschmidt/httprouter"
"github.com/spf13/viper"
)
// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
// 返回一个httprouter.Handle函数该函数接受http.ResponseWriter和*http.Request作为参数
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// 调用传入的http.Handler函数将http.ResponseWriter和*http.Request作为参数传递
h(w, r)
}
}
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
return
}
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
secret := viper.GetString("auth.secret")
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method")
}
return []byte(secret), nil
})
if err != nil || !token.Valid {
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
return
}
type contextKey string
// 将 openid 存入上下文
ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
next.ServeHTTP(w, r.WithContext(ctx))
}
}
// AesDecrypt 函数用于AES解密
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
//Base64解码
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
return origData, nil
}
// PKCS7UnPadding 函数用于去除PKCS7填充的密文
func PKCS7UnPadding(plantText []byte) []byte {
// 获取密文的长度
length := len(plantText)
// 如果密文长度大于0
if length > 0 {
// 获取最后一个字节的值,即填充的位数
unPadding := int(plantText[length-1])
// 返回去除填充后的密文
return plantText[:(length - unPadding)]
}
// 如果密文长度为0则返回原密文
return plantText
}

159
validator/validator.go Normal file
View file

@ -0,0 +1,159 @@
package validator
import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
var (
validate *validator.Validate
// 自定义验证规则
customValidations = map[string]validator.Func{
"otpsecret": validateOTPSecret,
"password": validatePassword,
}
)
func init() {
validate = validator.New()
// 注册自定义验证规则
for tag, fn := range customValidations {
if err := validate.RegisterValidation(tag, fn); err != nil {
panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
}
}
// 使用json tag作为字段名
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
}
// ValidateRequest validates a request body against a struct
func ValidateRequest(r *http.Request, v interface{}) error {
if err := json.NewDecoder(r.Body).Decode(v); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
if err := validate.Struct(v); err != nil {
if validationErrors, ok := err.(validator.ValidationErrors); ok {
return NewValidationError(validationErrors)
}
return fmt.Errorf("validation error: %w", err)
}
return nil
}
// ValidationError represents a validation error
type ValidationError struct {
Fields map[string]string `json:"fields"`
}
// Error implements the error interface
func (e *ValidationError) Error() string {
var errors []string
for field, msg := range e.Fields {
errors = append(errors, fmt.Sprintf("%s: %s", field, msg))
}
return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; "))
}
// NewValidationError creates a new ValidationError from validator.ValidationErrors
func NewValidationError(errors validator.ValidationErrors) *ValidationError {
fields := make(map[string]string)
for _, err := range errors {
fields[err.Field()] = getErrorMessage(err)
}
return &ValidationError{Fields: fields}
}
// getErrorMessage returns a human-readable error message for a validation error
func getErrorMessage(err validator.FieldError) string {
switch err.Tag() {
case "required":
return "This field is required"
case "email":
return "Invalid email address"
case "min":
return fmt.Sprintf("Must be at least %s characters long", err.Param())
case "max":
return fmt.Sprintf("Must be at most %s characters long", err.Param())
case "otpsecret":
return "Invalid OTP secret format"
case "password":
return "Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character"
default:
return fmt.Sprintf("Failed validation on tag: %s", err.Tag())
}
}
// Custom validation functions
// validateOTPSecret validates an OTP secret
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
// OTP secret should be base32 encoded
matched, _ := regexp.MatchString(`^[A-Z2-7]+=*$`, secret)
return matched
}
// validatePassword validates a password
func validatePassword(fl validator.FieldLevel) bool {
password := fl.Field().String()
// At least 8 characters long
if len(password) < 8 {
return false
}
var (
hasUpper = regexp.MustCompile(`[A-Z]`).MatchString(password)
hasLower = regexp.MustCompile(`[a-z]`).MatchString(password)
hasNumber = regexp.MustCompile(`[0-9]`).MatchString(password)
hasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password)
)
return hasUpper && hasLower && hasNumber && hasSpecial
}
// Request validation structs
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code" validate:"required"`
}
// CreateOTPRequest represents a request to create an OTP
type CreateOTPRequest struct {
Name string `json:"name" validate:"required,min=1,max=100"`
Issuer string `json:"issuer" validate:"required,min=1,max=100"`
Secret string `json:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"required,oneof=6 8"`
Period int `json:"period" validate:"required,oneof=30 60"`
}
// UpdateOTPRequest represents a request to update an OTP
type UpdateOTPRequest struct {
Name string `json:"name" validate:"omitempty,min=1,max=100"`
Issuer string `json:"issuer" validate:"omitempty,min=1,max=100"`
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
Period int `json:"period" validate:"omitempty,oneof=30 60"`
}
// VerifyOTPRequest represents a request to verify an OTP code
type VerifyOTPRequest struct {
Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}