Compare commits
7 commits
Author | SHA1 | Date | |
---|---|---|---|
|
44500afd3f | ||
|
bcd986e3f7 | ||
|
a45ddf13d5 | ||
|
a6461a9a0e | ||
|
2d3698716e | ||
|
25c5f530b8 | ||
|
079542e431 |
49 changed files with 6260 additions and 284 deletions
149
api/response.go
Normal file
149
api/response.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Common error codes
|
||||
const (
|
||||
CodeSuccess = 0
|
||||
CodeInvalidParams = 400
|
||||
CodeUnauthorized = 401
|
||||
CodeForbidden = 403
|
||||
CodeNotFound = 404
|
||||
CodeInternalError = 500
|
||||
CodeServiceUnavail = 503
|
||||
)
|
||||
|
||||
// Error represents an API error
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *Error) Error() string {
|
||||
return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// NewError creates a new API error
|
||||
func NewError(code int, message string) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// Response represents a standard API response
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseWriter wraps common response writing functions
|
||||
type ResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// NewResponseWriter creates a new ResponseWriter
|
||||
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
|
||||
return &ResponseWriter{w}
|
||||
}
|
||||
|
||||
// WriteJSON writes a JSON response
|
||||
func (w *ResponseWriter) WriteJSON(code int, data interface{}) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
return json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// WriteSuccess writes a success response
|
||||
func (w *ResponseWriter) WriteSuccess(data interface{}) error {
|
||||
return w.WriteJSON(http.StatusOK, Response{
|
||||
Code: CodeSuccess,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// WriteError writes an error response
|
||||
func (w *ResponseWriter) WriteError(err error) error {
|
||||
var apiErr *Error
|
||||
if errors.As(err, &apiErr) {
|
||||
return w.WriteJSON(getHTTPStatus(apiErr.Code), Response{
|
||||
Code: apiErr.Code,
|
||||
Message: apiErr.Message,
|
||||
})
|
||||
}
|
||||
|
||||
// Handle unknown errors
|
||||
return w.WriteJSON(http.StatusInternalServerError, Response{
|
||||
Code: CodeInternalError,
|
||||
Message: "Internal Server Error",
|
||||
})
|
||||
}
|
||||
|
||||
// WriteErrorWithCode writes an error response with a specific code
|
||||
func (w *ResponseWriter) WriteErrorWithCode(code int, message string) error {
|
||||
return w.WriteJSON(getHTTPStatus(code), Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// getHTTPStatus maps API error codes to HTTP status codes
|
||||
func getHTTPStatus(code int) int {
|
||||
switch code {
|
||||
case CodeSuccess:
|
||||
return http.StatusOK
|
||||
case CodeInvalidParams:
|
||||
return http.StatusBadRequest
|
||||
case CodeUnauthorized:
|
||||
return http.StatusUnauthorized
|
||||
case CodeForbidden:
|
||||
return http.StatusForbidden
|
||||
case CodeNotFound:
|
||||
return http.StatusNotFound
|
||||
case CodeServiceUnavail:
|
||||
return http.StatusServiceUnavailable
|
||||
default:
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrInvalidParams = NewError(CodeInvalidParams, "Invalid parameters")
|
||||
ErrUnauthorized = NewError(CodeUnauthorized, "Unauthorized")
|
||||
ErrForbidden = NewError(CodeForbidden, "Forbidden")
|
||||
ErrNotFound = NewError(CodeNotFound, "Resource not found")
|
||||
ErrInternalError = NewError(CodeInternalError, "Internal server error")
|
||||
ErrServiceUnavail = NewError(CodeServiceUnavail, "Service unavailable")
|
||||
)
|
||||
|
||||
// ValidationError creates an error for invalid parameters
|
||||
func ValidationError(message string) *Error {
|
||||
return NewError(CodeInvalidParams, message)
|
||||
}
|
||||
|
||||
// NotFoundError creates an error for not found resources
|
||||
func NotFoundError(resource string) *Error {
|
||||
return NewError(CodeNotFound, fmt.Sprintf("%s not found", resource))
|
||||
}
|
||||
|
||||
// ForbiddenError creates an error for forbidden actions
|
||||
func ForbiddenError(message string) *Error {
|
||||
return NewError(CodeForbidden, message)
|
||||
}
|
||||
|
||||
// InternalError creates an error for internal server errors
|
||||
func InternalError(err error) *Error {
|
||||
if err == nil {
|
||||
return ErrInternalError
|
||||
}
|
||||
return NewError(CodeInternalError, err.Error())
|
||||
}
|
206
cache/cache.go
vendored
Normal file
206
cache/cache.go
vendored
Normal file
|
@ -0,0 +1,206 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Item represents a cache item
|
||||
type Item struct {
|
||||
Value []byte
|
||||
Expiration int64
|
||||
}
|
||||
|
||||
// Expired returns true if the item has expired
|
||||
func (item Item) Expired() bool {
|
||||
if item.Expiration == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().UnixNano() > item.Expiration
|
||||
}
|
||||
|
||||
// Cache represents an in-memory cache
|
||||
type Cache struct {
|
||||
items map[string]Item
|
||||
mu sync.RWMutex
|
||||
defaultExpiration time.Duration
|
||||
cleanupInterval time.Duration
|
||||
stopCleanup chan bool
|
||||
}
|
||||
|
||||
// New creates a new cache with the given default expiration and cleanup interval
|
||||
func New(defaultExpiration, cleanupInterval time.Duration) *Cache {
|
||||
cache := &Cache{
|
||||
items: make(map[string]Item),
|
||||
defaultExpiration: defaultExpiration,
|
||||
cleanupInterval: cleanupInterval,
|
||||
stopCleanup: make(chan bool),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine if cleanup interval > 0
|
||||
if cleanupInterval > 0 {
|
||||
go cache.startCleanup()
|
||||
}
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
// startCleanup starts the cleanup process
|
||||
func (c *Cache) startCleanup() {
|
||||
ticker := time.NewTicker(c.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.DeleteExpired()
|
||||
case <-c.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds an item to the cache with the given key and expiration
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) error {
|
||||
// Convert value to bytes
|
||||
var valueBytes []byte
|
||||
var err error
|
||||
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
valueBytes = v
|
||||
case string:
|
||||
valueBytes = []byte(v)
|
||||
default:
|
||||
valueBytes, err = json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal value: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate expiration
|
||||
var exp int64
|
||||
if expiration == 0 {
|
||||
if c.defaultExpiration > 0 {
|
||||
exp = time.Now().Add(c.defaultExpiration).UnixNano()
|
||||
}
|
||||
} else if expiration > 0 {
|
||||
exp = time.Now().Add(expiration).UnixNano()
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.items[key] = Item{
|
||||
Value: valueBytes,
|
||||
Expiration: exp,
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get gets an item from the cache
|
||||
func (c *Cache) Get(key string, value interface{}) (bool, error) {
|
||||
c.mu.RLock()
|
||||
item, found := c.items[key]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !found {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if item has expired
|
||||
if item.Expired() {
|
||||
c.mu.Lock()
|
||||
delete(c.items, key)
|
||||
c.mu.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Unmarshal value
|
||||
switch v := value.(type) {
|
||||
case *[]byte:
|
||||
*v = item.Value
|
||||
case *string:
|
||||
*v = string(item.Value)
|
||||
default:
|
||||
if err := json.Unmarshal(item.Value, value); err != nil {
|
||||
return true, fmt.Errorf("failed to unmarshal value: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Delete deletes an item from the cache
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
delete(c.items, key)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// DeleteExpired deletes all expired items from the cache
|
||||
func (c *Cache) DeleteExpired() {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
c.mu.Lock()
|
||||
for k, v := range c.items {
|
||||
if v.Expiration > 0 && now > v.Expiration {
|
||||
delete(c.items, k)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Clear deletes all items from the cache
|
||||
func (c *Cache) Clear() {
|
||||
c.mu.Lock()
|
||||
c.items = make(map[string]Item)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Close stops the cleanup goroutine
|
||||
func (c *Cache) Close() {
|
||||
if c.cleanupInterval > 0 {
|
||||
c.stopCleanup <- true
|
||||
}
|
||||
}
|
||||
|
||||
// CacheService provides caching functionality
|
||||
type CacheService struct {
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewCacheService creates a new CacheService
|
||||
func NewCacheService(defaultExpiration, cleanupInterval time.Duration) *CacheService {
|
||||
return &CacheService{
|
||||
cache: New(defaultExpiration, cleanupInterval),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds an item to the cache
|
||||
func (s *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
|
||||
return s.cache.Set(key, value, expiration)
|
||||
}
|
||||
|
||||
// Get gets an item from the cache
|
||||
func (s *CacheService) Get(ctx context.Context, key string, value interface{}) (bool, error) {
|
||||
return s.cache.Get(key, value)
|
||||
}
|
||||
|
||||
// Delete deletes an item from the cache
|
||||
func (s *CacheService) Delete(ctx context.Context, key string) {
|
||||
s.cache.Delete(key)
|
||||
}
|
||||
|
||||
// Clear deletes all items from the cache
|
||||
func (s *CacheService) Clear(ctx context.Context) {
|
||||
s.cache.Clear()
|
||||
}
|
||||
|
||||
// Close closes the cache
|
||||
func (s *CacheService) Close() {
|
||||
s.cache.Close()
|
||||
}
|
177
cmd/root.go
177
cmd/root.go
|
@ -1,83 +1,140 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"otpm/config"
|
||||
"otpm/database"
|
||||
"otpm/handlers"
|
||||
"otpm/utils"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"otpm/models"
|
||||
"otpm/server"
|
||||
"otpm/services"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "otpm",
|
||||
Short: "otp backend for microapp on wechat",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
startApp()
|
||||
},
|
||||
}
|
||||
|
||||
func Execute() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
cobra.OnInitialize(initConfig)
|
||||
// Set config file with multi-environment support
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath(".")
|
||||
|
||||
rootCmd.PersistentFlags().StringP("config", "c", "", "config file (default is $HOME/config.yaml)")
|
||||
rootCmd.PersistentFlags().StringP("driver", "d", "sqlite3", "database driver (sqlite3, postgres, mysql)")
|
||||
rootCmd.PersistentFlags().StringP("dsn", "s", "", "database connection string")
|
||||
rootCmd.PersistentFlags().StringP("port", "p", "8080", "port to listen on")
|
||||
|
||||
viper.BindPFlag("database.driver", rootCmd.PersistentFlags().Lookup("driver"))
|
||||
viper.BindPFlag("database.dsn", rootCmd.PersistentFlags().Lookup("dsn"))
|
||||
viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port"))
|
||||
}
|
||||
|
||||
func initConfig() {
|
||||
if cfgFile := viper.GetString("config"); cfgFile != "" {
|
||||
viper.SetConfigFile(cfgFile)
|
||||
} else {
|
||||
viper.AddConfigPath(".")
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
// Set environment specific config (e.g. config.production.yaml)
|
||||
env := os.Getenv("OTPM_ENV")
|
||||
if env != "" {
|
||||
viper.SetConfigName(fmt.Sprintf("config.%s", env))
|
||||
}
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
log.Fatalf("Error reading config file: %v", err)
|
||||
}
|
||||
// Set default values
|
||||
viper.SetDefault("server.port", "8080")
|
||||
viper.SetDefault("server.timeout.read", "15s")
|
||||
viper.SetDefault("server.timeout.write", "15s")
|
||||
viper.SetDefault("server.timeout.idle", "60s")
|
||||
viper.SetDefault("database.max_open_conns", 25)
|
||||
viper.SetDefault("database.max_idle_conns", 5)
|
||||
viper.SetDefault("database.conn_max_lifetime", "5m")
|
||||
|
||||
// Set environment variable prefix
|
||||
viper.SetEnvPrefix("OTPM")
|
||||
viper.AutomaticEnv()
|
||||
|
||||
// Bind environment variables
|
||||
viper.BindEnv("database.url", "OTPM_DB_URL")
|
||||
viper.BindEnv("database.password", "OTPM_DB_PASSWORD")
|
||||
}
|
||||
|
||||
func initApp(db *sqlx.DB) {
|
||||
if err := database.MigrateDB(db); err != nil {
|
||||
log.Fatalf("Error migrating the database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func startApp() {
|
||||
port := viper.GetInt("port")
|
||||
db, err := database.InitDB()
|
||||
// Execute is the entry point for the application
|
||||
func Execute() error {
|
||||
// Load configuration
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("Error connecting to the database: %v", err)
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
// Create context with cancellation
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Setup signal handling
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
log.Printf("Received signal: %v", sig)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Initialize database
|
||||
db, err := database.New(&cfg.Database)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
initApp(db)
|
||||
handler := &handlers.Handler{DB: db}
|
||||
|
||||
router := httprouter.New()
|
||||
router.POST("/login", utils.AdaptHandler(handler.Login))
|
||||
router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp))
|
||||
router.GET("/get", utils.AdaptHandler(handler.GetOtp))
|
||||
// Run database migrations
|
||||
if err := database.MigrateWithContext(ctx, db.DB, cfg.Database.SkipMigration); err != nil {
|
||||
return fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Starting server on :8080")
|
||||
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), router))
|
||||
// Initialize repositories
|
||||
userRepo := models.NewUserRepository(db.DB)
|
||||
otpRepo := models.NewOTPRepository(db.DB)
|
||||
|
||||
// Initialize services
|
||||
authService := services.NewAuthService(cfg, userRepo)
|
||||
otpService := services.NewOTPService(otpRepo)
|
||||
|
||||
// Initialize handlers
|
||||
authHandler := handlers.NewAuthHandler(authService)
|
||||
otpHandler := handlers.NewOTPHandler(otpService)
|
||||
|
||||
// Create and configure server
|
||||
srv := server.New(cfg)
|
||||
|
||||
// Register health check endpoint
|
||||
srv.RegisterHealthCheck()
|
||||
|
||||
// Register public routes with type conversion
|
||||
authRoutes := make(map[string]http.Handler)
|
||||
for path, handler := range authHandler.Routes() {
|
||||
authRoutes[path] = http.HandlerFunc(handler)
|
||||
}
|
||||
srv.RegisterRoutes(authRoutes)
|
||||
|
||||
// Register authenticated routes with type conversion
|
||||
otpRoutes := make(map[string]http.Handler)
|
||||
for path, handler := range otpHandler.Routes() {
|
||||
otpRoutes[path] = http.HandlerFunc(handler)
|
||||
}
|
||||
srv.RegisterAuthRoutes(otpRoutes)
|
||||
|
||||
// Start server in goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
log.Printf("Starting server on %s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||
if err := srv.Start(); err != nil {
|
||||
serverErr <- fmt.Errorf("server error: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for shutdown signal or server error
|
||||
select {
|
||||
case err := <-serverErr:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
// Graceful shutdown with timeout
|
||||
log.Println("Shutting down server...")
|
||||
if err := srv.Shutdown(); err != nil {
|
||||
return fmt.Errorf("server shutdown error: %w", err)
|
||||
}
|
||||
log.Println("Server stopped gracefully")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
23
config.yaml
23
config.yaml
|
@ -1,8 +1,23 @@
|
|||
server:
|
||||
port: 8080
|
||||
read_timeout: 15s
|
||||
write_timeout: 15s
|
||||
shutdown_timeout: 5s
|
||||
|
||||
database:
|
||||
driver: sqlite
|
||||
driver: sqlite3
|
||||
dsn: otpm.sqlite
|
||||
port: 8080
|
||||
max_open_conns: 25
|
||||
max_idle_conns: 25
|
||||
max_lifetime: 5m
|
||||
skip_migration: false
|
||||
|
||||
jwt:
|
||||
secret: "your-jwt-secret-key-change-this-in-production"
|
||||
expire_delta: 24h
|
||||
refresh_delta: 168h
|
||||
signing_method: HS256
|
||||
|
||||
wechat:
|
||||
appid: "wx57d1033974eb5250"
|
||||
secret: "be494c2a81df685a40b9a74e1736b15d"
|
||||
app_id: "your-wechat-app-id"
|
||||
app_secret: "your-wechat-app-secret"
|
128
config/config.go
Normal file
128
config/config.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Config holds all configuration for the application
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
WeChat WeChatConfig `mapstructure:"wechat"`
|
||||
}
|
||||
|
||||
// ServerConfig holds all server related configuration
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout"`
|
||||
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
|
||||
Timeout time.Duration `mapstructure:"timeout"` // Request processing timeout
|
||||
}
|
||||
|
||||
// DatabaseConfig holds all database related configuration
|
||||
type DatabaseConfig struct {
|
||||
Driver string `mapstructure:"driver"`
|
||||
DSN string `mapstructure:"dsn"`
|
||||
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
MaxLifetime time.Duration `mapstructure:"max_lifetime"`
|
||||
SkipMigration bool `mapstructure:"skip_migration"`
|
||||
}
|
||||
|
||||
// JWTConfig holds all JWT related configuration
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpireDelta time.Duration `mapstructure:"expire_delta"`
|
||||
RefreshDelta time.Duration `mapstructure:"refresh_delta"`
|
||||
SigningMethod string `mapstructure:"signing_method"`
|
||||
Issuer string `mapstructure:"issuer"`
|
||||
Audience string `mapstructure:"audience"`
|
||||
}
|
||||
|
||||
// WeChatConfig holds all WeChat related configuration
|
||||
type WeChatConfig struct {
|
||||
AppID string `mapstructure:"app_id"`
|
||||
AppSecret string `mapstructure:"app_secret"`
|
||||
}
|
||||
|
||||
// LoadConfig loads the configuration from file and environment variables
|
||||
func LoadConfig() (*Config, error) {
|
||||
// Set default values
|
||||
setDefaults()
|
||||
|
||||
// Read config file
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Validate config
|
||||
if err := validateConfig(&config); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// setDefaults sets default values for configuration
|
||||
func setDefaults() {
|
||||
// Server defaults
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.read_timeout", "15s")
|
||||
viper.SetDefault("server.write_timeout", "15s")
|
||||
viper.SetDefault("server.shutdown_timeout", "5s")
|
||||
viper.SetDefault("server.timeout", "30s") // Default request processing timeout
|
||||
|
||||
// Database defaults
|
||||
viper.SetDefault("database.driver", "sqlite3")
|
||||
viper.SetDefault("database.max_open_conns", 25)
|
||||
viper.SetDefault("database.max_idle_conns", 25)
|
||||
viper.SetDefault("database.max_lifetime", "5m")
|
||||
viper.SetDefault("database.skip_migration", false)
|
||||
|
||||
// JWT defaults
|
||||
viper.SetDefault("jwt.expire_delta", "24h")
|
||||
viper.SetDefault("jwt.refresh_delta", "168h") // 7 days
|
||||
viper.SetDefault("jwt.signing_method", "HS256")
|
||||
viper.SetDefault("jwt.issuer", "otpm")
|
||||
viper.SetDefault("jwt.audience", "otpm-client")
|
||||
}
|
||||
|
||||
// validateConfig validates the configuration
|
||||
func validateConfig(config *Config) error {
|
||||
if config.Server.Port < 1 || config.Server.Port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %d", config.Server.Port)
|
||||
}
|
||||
|
||||
if config.Database.Driver == "" {
|
||||
return fmt.Errorf("database driver is required")
|
||||
}
|
||||
|
||||
if config.Database.DSN == "" {
|
||||
return fmt.Errorf("database DSN is required")
|
||||
}
|
||||
|
||||
if config.JWT.Secret == "" {
|
||||
return fmt.Errorf("JWT secret is required")
|
||||
}
|
||||
|
||||
if config.WeChat.AppID == "" {
|
||||
return fmt.Errorf("WeChat AppID is required")
|
||||
}
|
||||
|
||||
if config.WeChat.AppSecret == "" {
|
||||
return fmt.Errorf("WeChat AppSecret is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,49 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"log"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/spf13/viper"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed init/users.sql
|
||||
userTable string
|
||||
//go:embed init/otp.sql
|
||||
otpTable string
|
||||
)
|
||||
|
||||
func InitDB() (*sqlx.DB, error) {
|
||||
driver := viper.GetString("database.driver")
|
||||
dsn := viper.GetString("database.dsn")
|
||||
|
||||
db, err := sqlx.Open(driver, dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Println("Connected to database!")
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func MigrateDB(db *sqlx.DB) error {
|
||||
_, err := db.Exec(userTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(otpTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
196
database/db.go
Normal file
196
database/db.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"otpm/config"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// DB wraps sqlx.DB to provide additional functionality
|
||||
type DB struct {
|
||||
*sqlx.DB
|
||||
}
|
||||
|
||||
// New creates a new database connection
|
||||
func New(cfg *config.DatabaseConfig) (*DB, error) {
|
||||
db, err := sqlx.Open(cfg.Driver, cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool with optimized settings
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
|
||||
db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn
|
||||
db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes
|
||||
|
||||
// Verify connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Successfully connected to database")
|
||||
return &DB{db}, nil
|
||||
}
|
||||
|
||||
// WithTx executes a function within a transaction with retry logic
|
||||
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
|
||||
const maxRetries = 3
|
||||
var lastErr error
|
||||
|
||||
// Default transaction options
|
||||
opts := &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
}
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
start := time.Now()
|
||||
|
||||
tx, err := db.BeginTxx(ctx, opts)
|
||||
if err != nil {
|
||||
if isRetryableError(err) && attempt < maxRetries {
|
||||
lastErr = err
|
||||
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
|
||||
}
|
||||
|
||||
if isRetryableError(err) && attempt < maxRetries {
|
||||
lastErr = err
|
||||
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
|
||||
}
|
||||
|
||||
// Log long-running transactions
|
||||
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
|
||||
log.Printf("Transaction completed in %v", elapsed)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
if isRetryableError(err) && attempt < maxRetries {
|
||||
lastErr = err
|
||||
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// isRetryableError checks if an error is likely to succeed on retry
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := strings.ToLower(err.Error())
|
||||
return strings.Contains(errStr, "deadlock") ||
|
||||
strings.Contains(errStr, "timeout") ||
|
||||
strings.Contains(errStr, "try again") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "busy") ||
|
||||
strings.Contains(errStr, "locked")
|
||||
}
|
||||
|
||||
// ExecContext executes a query with adaptive timeout
|
||||
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
|
||||
// Set timeout based on query complexity
|
||||
timeout := 5 * time.Second
|
||||
if strings.Contains(strings.ToLower(query), "insert") ||
|
||||
strings.Contains(strings.ToLower(query), "update") ||
|
||||
strings.Contains(strings.ToLower(query), "delete") {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err := db.DB.ExecContext(ctx, query, args...)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Log slow queries
|
||||
if elapsed > timeout/2 {
|
||||
log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryRowContext executes a query that returns a single row with timeout
|
||||
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return db.DB.QueryRowxContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// QueryContext executes a query that returns multiple rows with adaptive timeout
|
||||
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
|
||||
// Set timeout based on query complexity
|
||||
timeout := 5 * time.Second
|
||||
if strings.Contains(strings.ToLower(query), "join") ||
|
||||
strings.Contains(strings.ToLower(query), "group by") ||
|
||||
strings.Contains(strings.ToLower(query), "order by") {
|
||||
timeout = 15 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
rows, err := db.DB.QueryxContext(ctx, query, args...)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Log slow queries
|
||||
if elapsed > timeout/2 {
|
||||
log.Printf("Slow query detected: %s (took %v)", query, elapsed)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (db *DB) Close() error {
|
||||
if err := db.DB.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close database connection: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
CREATE TABLE IF NOT EXISTS otp (
|
||||
id SERIAL PRIMARY KEY,
|
||||
openid VARCHAR(255),
|
||||
num INTEGER,
|
||||
token VARCHAR(255)
|
||||
id SERIAL PRIMARY KEY,
|
||||
openid VARCHAR(255) UNIQUE NOT NULL,
|
||||
token VARCHAR(255),
|
||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
|
@ -2,4 +2,5 @@ CREATE TABLE IF NOT EXISTS users (
|
|||
id SERIAL PRIMARY KEY,
|
||||
openid VARCHAR(255) UNIQUE NOT NULL,
|
||||
session_key VARCHAR(255) UNIQUE NOT NULL
|
||||
);
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_users_openid ON users(openid);
|
160
database/migration.go
Normal file
160
database/migration.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed init/users.sql
|
||||
userTable string
|
||||
//go:embed init/otp.sql
|
||||
otpTable string
|
||||
)
|
||||
|
||||
// Migration represents a database migration
|
||||
type Migration struct {
|
||||
Name string
|
||||
SQL string
|
||||
Version int
|
||||
}
|
||||
|
||||
// Migrations is a list of all migrations
|
||||
var Migrations = []Migration{
|
||||
{
|
||||
Name: "Create users table",
|
||||
SQL: userTable,
|
||||
Version: 1,
|
||||
},
|
||||
{
|
||||
Name: "Create OTP table",
|
||||
SQL: otpTable,
|
||||
Version: 2,
|
||||
},
|
||||
}
|
||||
|
||||
// MigrationRecord represents a record in the migrations table
|
||||
type MigrationRecord struct {
|
||||
ID int `db:"id"`
|
||||
Version int `db:"version"`
|
||||
Name string `db:"name"`
|
||||
AppliedAt time.Time `db:"applied_at"`
|
||||
}
|
||||
|
||||
// ensureMigrationsTable ensures that the migrations table exists
|
||||
func ensureMigrationsTable(db *sqlx.DB) error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
version INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
`
|
||||
|
||||
_, err := db.Exec(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrations table: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAppliedMigrations gets all applied migrations
|
||||
func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
|
||||
var records []MigrationRecord
|
||||
if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
|
||||
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
|
||||
}
|
||||
|
||||
result := make(map[int]MigrationRecord)
|
||||
for _, record := range records {
|
||||
result[record.Version] = record
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Migrate runs all pending migrations
|
||||
func Migrate(db *sqlx.DB, skipMigration bool) error {
|
||||
if skipMigration {
|
||||
log.Println("Skipping database migration as configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure migrations table exists
|
||||
if err := ensureMigrationsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get applied migrations
|
||||
appliedMigrations, err := getAppliedMigrations(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply pending migrations
|
||||
for _, migration := range Migrations {
|
||||
if _, ok := appliedMigrations[migration.Version]; ok {
|
||||
log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
|
||||
|
||||
// Start a transaction for this migration
|
||||
tx, err := db.Beginx()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
// Execute migration
|
||||
if _, err := tx.Exec(migration.SQL); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
// Record migration
|
||||
if _, err := tx.Exec(
|
||||
"INSERT INTO migrations (version, name) VALUES (?, ?)",
|
||||
migration.Version, migration.Name,
|
||||
); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateWithContext runs all pending migrations with context
|
||||
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
|
||||
// Create a channel to signal completion
|
||||
done := make(chan error, 1)
|
||||
|
||||
// Run migration in a goroutine
|
||||
go func() {
|
||||
done <- Migrate(db, skipMigration)
|
||||
}()
|
||||
|
||||
// Wait for migration to complete or context to be canceled
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("migration canceled: %w", ctx.Err())
|
||||
}
|
||||
}
|
663
docs/swagger.go
Normal file
663
docs/swagger.go
Normal file
|
@ -0,0 +1,663 @@
|
|||
// Package docs provides API documentation using Swagger/OpenAPI
|
||||
package docs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// SwaggerInfo holds the API information
|
||||
var SwaggerInfo = struct {
|
||||
Title string
|
||||
Description string
|
||||
Version string
|
||||
Host string
|
||||
BasePath string
|
||||
Schemes []string
|
||||
}{
|
||||
Title: "OTPM API",
|
||||
Description: "API for One-Time Password Manager",
|
||||
Version: "1.0.0",
|
||||
Host: "localhost:8080",
|
||||
BasePath: "/",
|
||||
Schemes: []string{"http", "https"},
|
||||
}
|
||||
|
||||
// SwaggerJSON returns the Swagger JSON
|
||||
func SwaggerJSON() []byte {
|
||||
swagger := map[string]interface{}{
|
||||
"swagger": "2.0",
|
||||
"info": map[string]interface{}{
|
||||
"title": SwaggerInfo.Title,
|
||||
"description": SwaggerInfo.Description,
|
||||
"version": SwaggerInfo.Version,
|
||||
},
|
||||
"host": SwaggerInfo.Host,
|
||||
"basePath": SwaggerInfo.BasePath,
|
||||
"schemes": SwaggerInfo.Schemes,
|
||||
"paths": getPaths(),
|
||||
"definitions": map[string]interface{}{
|
||||
"LoginRequest": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"code": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "WeChat authorization code",
|
||||
},
|
||||
},
|
||||
"required": []string{"code"},
|
||||
},
|
||||
"LoginResponse": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"token": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "JWT token",
|
||||
},
|
||||
"openid": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "WeChat OpenID",
|
||||
},
|
||||
},
|
||||
},
|
||||
"CreateOTPRequest": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"name": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP name",
|
||||
},
|
||||
"issuer": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP issuer",
|
||||
},
|
||||
"secret": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP secret",
|
||||
},
|
||||
"algorithm": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP algorithm",
|
||||
"enum": []string{"SHA1", "SHA256", "SHA512"},
|
||||
},
|
||||
"digits": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "OTP digits",
|
||||
"enum": []int{6, 8},
|
||||
},
|
||||
"period": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "OTP period in seconds",
|
||||
"enum": []int{30, 60},
|
||||
},
|
||||
},
|
||||
"required": []string{"name", "issuer", "secret", "algorithm", "digits", "period"},
|
||||
},
|
||||
"OTP": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP ID",
|
||||
},
|
||||
"user_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "User ID",
|
||||
},
|
||||
"name": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP name",
|
||||
},
|
||||
"issuer": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP issuer",
|
||||
},
|
||||
"algorithm": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP algorithm",
|
||||
},
|
||||
"digits": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "OTP digits",
|
||||
},
|
||||
"period": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "OTP period in seconds",
|
||||
},
|
||||
"created_at": map[string]interface{}{
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "Creation time",
|
||||
},
|
||||
"updated_at": map[string]interface{}{
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "Last update time",
|
||||
},
|
||||
},
|
||||
},
|
||||
"OTPCodeResponse": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"code": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP code",
|
||||
},
|
||||
"expires_in": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Seconds until expiration",
|
||||
},
|
||||
},
|
||||
},
|
||||
"VerifyOTPRequest": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"code": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP code to verify",
|
||||
},
|
||||
},
|
||||
"required": []string{"code"},
|
||||
},
|
||||
"VerifyOTPResponse": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"valid": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "Whether the code is valid",
|
||||
},
|
||||
},
|
||||
},
|
||||
"UpdateOTPRequest": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"name": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP name",
|
||||
},
|
||||
"issuer": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP issuer",
|
||||
},
|
||||
"algorithm": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "OTP algorithm",
|
||||
"enum": []string{"SHA1", "SHA256", "SHA512"},
|
||||
},
|
||||
"digits": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "OTP digits",
|
||||
"enum": []int{6, 8},
|
||||
},
|
||||
"period": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "OTP period in seconds",
|
||||
"enum": []int{30, 60},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ErrorResponse": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"code": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Error code",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Error message",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"securityDefinitions": map[string]interface{}{
|
||||
"Bearer": map[string]interface{}{
|
||||
"type": "apiKey",
|
||||
"name": "Authorization",
|
||||
"in": "header",
|
||||
"description": "JWT token with Bearer prefix",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, _ := json.MarshalIndent(swagger, "", " ")
|
||||
return data
|
||||
}
|
||||
|
||||
// getPaths returns the API paths
|
||||
func getPaths() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"/login": map[string]interface{}{
|
||||
"post": map[string]interface{}{
|
||||
"summary": "Login with WeChat",
|
||||
"description": "Login with WeChat authorization code",
|
||||
"tags": []string{"auth"},
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "body",
|
||||
"in": "body",
|
||||
"description": "Login request",
|
||||
"required": true,
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/LoginRequest",
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "Successful login",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/LoginResponse",
|
||||
},
|
||||
},
|
||||
"400": map[string]interface{}{
|
||||
"description": "Invalid request",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"500": map[string]interface{}{
|
||||
"description": "Internal server error",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/verify-token": map[string]interface{}{
|
||||
"post": map[string]interface{}{
|
||||
"summary": "Verify token",
|
||||
"description": "Verify JWT token",
|
||||
"tags": []string{"auth"},
|
||||
"security": []map[string]interface{}{
|
||||
{
|
||||
"Bearer": []string{},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "Token is valid",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"valid": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "Whether the token is valid",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"401": map[string]interface{}{
|
||||
"description": "Unauthorized",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/otp": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"summary": "List OTPs",
|
||||
"description": "List all OTPs for the authenticated user",
|
||||
"tags": []string{"otp"},
|
||||
"security": []map[string]interface{}{
|
||||
{
|
||||
"Bearer": []string{},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "List of OTPs",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "array",
|
||||
"items": map[string]interface{}{
|
||||
"$ref": "#/definitions/OTP",
|
||||
},
|
||||
},
|
||||
},
|
||||
"401": map[string]interface{}{
|
||||
"description": "Unauthorized",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"500": map[string]interface{}{
|
||||
"description": "Internal server error",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"summary": "Create OTP",
|
||||
"description": "Create a new OTP",
|
||||
"tags": []string{"otp"},
|
||||
"security": []map[string]interface{}{
|
||||
{
|
||||
"Bearer": []string{},
|
||||
},
|
||||
},
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "body",
|
||||
"in": "body",
|
||||
"description": "OTP creation request",
|
||||
"required": true,
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/CreateOTPRequest",
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "OTP created",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/OTP",
|
||||
},
|
||||
},
|
||||
"400": map[string]interface{}{
|
||||
"description": "Invalid request",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"401": map[string]interface{}{
|
||||
"description": "Unauthorized",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"500": map[string]interface{}{
|
||||
"description": "Internal server error",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/otp/{id}": map[string]interface{}{
|
||||
"put": map[string]interface{}{
|
||||
"summary": "Update OTP",
|
||||
"description": "Update an existing OTP",
|
||||
"tags": []string{"otp"},
|
||||
"security": []map[string]interface{}{
|
||||
{
|
||||
"Bearer": []string{},
|
||||
},
|
||||
},
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"description": "OTP ID",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
{
|
||||
"name": "body",
|
||||
"in": "body",
|
||||
"description": "OTP update request",
|
||||
"required": true,
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/UpdateOTPRequest",
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "OTP updated",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/OTP",
|
||||
},
|
||||
},
|
||||
"400": map[string]interface{}{
|
||||
"description": "Invalid request",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"401": map[string]interface{}{
|
||||
"description": "Unauthorized",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"404": map[string]interface{}{
|
||||
"description": "OTP not found",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"500": map[string]interface{}{
|
||||
"description": "Internal server error",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"delete": map[string]interface{}{
|
||||
"summary": "Delete OTP",
|
||||
"description": "Delete an existing OTP",
|
||||
"tags": []string{"otp"},
|
||||
"security": []map[string]interface{}{
|
||||
{
|
||||
"Bearer": []string{},
|
||||
},
|
||||
},
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"description": "OTP ID",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "OTP deleted",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Success message",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"401": map[string]interface{}{
|
||||
"description": "Unauthorized",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"404": map[string]interface{}{
|
||||
"description": "OTP not found",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"500": map[string]interface{}{
|
||||
"description": "Internal server error",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/otp/{id}/code": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"summary": "Get OTP code",
|
||||
"description": "Get the current OTP code",
|
||||
"tags": []string{"otp"},
|
||||
"security": []map[string]interface{}{
|
||||
{
|
||||
"Bearer": []string{},
|
||||
},
|
||||
},
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"description": "OTP ID",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "OTP code",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/OTPCodeResponse",
|
||||
},
|
||||
},
|
||||
"401": map[string]interface{}{
|
||||
"description": "Unauthorized",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"404": map[string]interface{}{
|
||||
"description": "OTP not found",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"500": map[string]interface{}{
|
||||
"description": "Internal server error",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/otp/{id}/verify": map[string]interface{}{
|
||||
"post": map[string]interface{}{
|
||||
"summary": "Verify OTP code",
|
||||
"description": "Verify an OTP code",
|
||||
"tags": []string{"otp"},
|
||||
"security": []map[string]interface{}{
|
||||
{
|
||||
"Bearer": []string{},
|
||||
},
|
||||
},
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"description": "OTP ID",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
{
|
||||
"name": "body",
|
||||
"in": "body",
|
||||
"description": "OTP verification request",
|
||||
"required": true,
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/VerifyOTPRequest",
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "OTP verification result",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/VerifyOTPResponse",
|
||||
},
|
||||
},
|
||||
"400": map[string]interface{}{
|
||||
"description": "Invalid request",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"401": map[string]interface{}{
|
||||
"description": "Unauthorized",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"404": map[string]interface{}{
|
||||
"description": "OTP not found",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
"500": map[string]interface{}{
|
||||
"description": "Internal server error",
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/definitions/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/health": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"summary": "Health check",
|
||||
"description": "Check if the API is healthy",
|
||||
"tags": []string{"system"},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "API is healthy",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"status": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Health status",
|
||||
},
|
||||
"time": map[string]interface{}{
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "Current time",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/metrics": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"summary": "Metrics",
|
||||
"description": "Get application metrics",
|
||||
"tags": []string{"system"},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "Application metrics",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/swagger.json": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"summary": "Swagger JSON",
|
||||
"description": "Get Swagger JSON",
|
||||
"tags": []string{"system"},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "Swagger JSON",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns an HTTP handler for Swagger JSON
|
||||
func Handler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(SwaggerJSON())
|
||||
}
|
||||
}
|
24
go.mod
24
go.mod
|
@ -1,6 +1,8 @@
|
|||
module otpm
|
||||
|
||||
go 1.21.1
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.23.9
|
||||
|
||||
require (
|
||||
github.com/go-sql-driver/mysql v1.8.1
|
||||
|
@ -14,17 +16,30 @@ require (
|
|||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.26.0 // indirect
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/prometheus/client_golang v1.22.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
|
@ -35,9 +50,12 @@ require (
|
|||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/crypto v0.38.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
|
||||
golang.org/x/sys v0.22.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/net v0.34.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.25.0 // indirect
|
||||
google.golang.org/protobuf v1.36.5 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
|
||||
|
|
38
go.sum
38
go.sum
|
@ -1,5 +1,9 @@
|
|||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
|
@ -11,10 +15,21 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk
|
|||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
|
||||
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
|
@ -33,6 +48,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
|||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
|
@ -43,6 +60,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o
|
|||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
|
@ -50,6 +69,14 @@ github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
|
|||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
|
||||
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
|
@ -87,17 +114,28 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
|||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
||||
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
|
||||
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
|
||||
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
|
||||
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
|
||||
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
|
||||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
|
||||
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
|
147
handlers/auth_handler.go
Normal file
147
handlers/auth_handler.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"otpm/api"
|
||||
"otpm/services"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication related requests
|
||||
type AuthHandler struct {
|
||||
authService *services.AuthService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRequest represents a login request
|
||||
type LoginRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// LoginResponse represents a login response
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
OpenID string `json:"openid"`
|
||||
}
|
||||
|
||||
// Login handles WeChat login
|
||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Limit request body size to prevent DOS
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
|
||||
|
||||
// Parse request
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
fmt.Sprintf("Invalid request body: %v", err))
|
||||
log.Printf("Login request parse error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if req.Code == "" {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Code is required")
|
||||
log.Printf("Login request validation failed: empty code")
|
||||
return
|
||||
}
|
||||
|
||||
// Login with WeChat code
|
||||
token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
log.Printf("Login failed for code %s: %v", req.Code, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful login
|
||||
log.Printf("Login successful for code %s (took %v)",
|
||||
req.Code, time.Since(start))
|
||||
|
||||
// Return token
|
||||
api.NewResponseWriter(w).WriteSuccess(LoginResponse{
|
||||
Token: token,
|
||||
})
|
||||
}
|
||||
|
||||
// VerifyToken handles token verification
|
||||
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Get token from Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Authorization header is required")
|
||||
log.Printf("Token verification failed: missing Authorization header")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token format
|
||||
if len(authHeader) < 7 || authHeader[:7] != "Bearer " {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Invalid token format. Expected 'Bearer <token>'")
|
||||
log.Printf("Token verification failed: invalid token format")
|
||||
return
|
||||
}
|
||||
|
||||
token := authHeader[7:]
|
||||
if len(token) < 32 { // Basic length check
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Invalid token length")
|
||||
log.Printf("Token verification failed: token too short")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token
|
||||
claims, err := h.authService.ValidateToken(token)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
log.Printf("Token verification failed for token %s: %v",
|
||||
maskToken(token), err) // Mask token in logs
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful verification
|
||||
userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string)
|
||||
if !ok {
|
||||
log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
|
||||
} else {
|
||||
log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
|
||||
}
|
||||
|
||||
// Token is valid
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
|
||||
"valid": true,
|
||||
})
|
||||
}
|
||||
|
||||
// maskToken masks sensitive parts of token for logging
|
||||
func maskToken(token string) string {
|
||||
if len(token) < 8 {
|
||||
return "****"
|
||||
}
|
||||
return token[:4] + "****" + token[len(token)-4:]
|
||||
}
|
||||
|
||||
// Routes returns all routes for the auth handler
|
||||
func (h *AuthHandler) Routes() map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/login": h.Login,
|
||||
"/verify-token": h.VerifyToken,
|
||||
}
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
DB *sqlx.DB
|
||||
}
|
|
@ -1,93 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var code2Session = "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code"
|
||||
|
||||
type LoginRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// 封装code2session接口返回数据
|
||||
type LoginResponse struct {
|
||||
OpenId string `json:"openid"`
|
||||
SessionKey string `json:"session_key"`
|
||||
UnionId string `json:"unionid"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
func getLoginResponse(code string) (*LoginResponse, error) {
|
||||
appid := viper.GetString("wechat.appid")
|
||||
secret := viper.GetString("wechat.secret")
|
||||
url := fmt.Sprintf(code2Session, appid, secret, code)
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var loginResponse LoginResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if loginResponse.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
|
||||
}
|
||||
return &loginResponse, nil
|
||||
}
|
||||
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var req LoginRequest
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, "Failed to parse request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loginResponse, err := getLoginResponse(req.Code)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get session key", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// // 插入或更新用户的openid和sessionid
|
||||
// query := `
|
||||
// INSERT INTO users (openid, sessionid)
|
||||
// VALUES ($1, $2)
|
||||
// ON CONFLICT (openid) DO UPDATE SET sessionid = $2
|
||||
// RETURNING id;
|
||||
// `
|
||||
|
||||
// var ID int
|
||||
// if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil {
|
||||
// http.Error(w, "Failed to log in user", http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
|
||||
data := map[string]interface{}{
|
||||
"openid": loginResponse.OpenId,
|
||||
"session_key": loginResponse.SessionKey,
|
||||
}
|
||||
respData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to marshal response data", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(respData))
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type OtpRequest struct {
|
||||
OpenID string `json:"openid"`
|
||||
Num int `json:"num"`
|
||||
Token *[]OTP `json:"token"`
|
||||
}
|
||||
type OTP struct {
|
||||
Issuer string `json:"issuer"`
|
||||
Remark string `json:"remark"`
|
||||
Secret string `json:"secret"`
|
||||
}
|
||||
|
||||
func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) {
|
||||
var req OtpRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
num := len(*req.Token)
|
||||
|
||||
// 插入或更新 OTP 记录
|
||||
query := `
|
||||
INSERT INTO otp (openid, num, token)
|
||||
VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
_, err := h.DB.Exec(query, req.OpenID, req.Token, num)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to update or create OTP", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OTP updated or created successfully"))
|
||||
}
|
||||
|
||||
func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
|
||||
openid := r.URL.Query().Get("openid")
|
||||
if openid == "" {
|
||||
http.Error(w, "未登录", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var otp OtpRequest
|
||||
|
||||
err := h.DB.Get(&otp, "SELECT token, num, openid FROM otp WHERE openid=$1", openid)
|
||||
if err != nil {
|
||||
http.Error(w, "OTP not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(otp)
|
||||
}
|
286
handlers/otp_handler.go
Normal file
286
handlers/otp_handler.go
Normal file
|
@ -0,0 +1,286 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"otpm/api"
|
||||
"otpm/middleware"
|
||||
"otpm/models"
|
||||
"otpm/services"
|
||||
)
|
||||
|
||||
// OTPHandler handles OTP related requests
|
||||
type OTPHandler struct {
|
||||
otpService *services.OTPService
|
||||
}
|
||||
|
||||
// NewOTPHandler creates a new OTPHandler
|
||||
func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
|
||||
return &OTPHandler{
|
||||
otpService: otpService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateOTPRequest represents a request to create an OTP
|
||||
type CreateOTPRequest struct {
|
||||
Name string `json:"name"`
|
||||
Issuer string `json:"issuer"`
|
||||
Secret string `json:"secret"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Digits int `json:"digits"`
|
||||
Period int `json:"period"`
|
||||
}
|
||||
|
||||
// CreateOTP handles OTP creation
|
||||
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Limit request body size
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation
|
||||
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
log.Printf("CreateOTP unauthorized attempt")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request
|
||||
var req CreateOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
fmt.Sprintf("Invalid request body: %v", err))
|
||||
log.Printf("CreateOTP request parse error for user %s: %v", userID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate OTP parameters
|
||||
if req.Secret == "" {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Secret is required")
|
||||
log.Printf("CreateOTP validation failed for user %s: empty secret", userID)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate algorithm
|
||||
supportedAlgos := map[string]bool{
|
||||
"SHA1": true,
|
||||
"SHA256": true,
|
||||
"SHA512": true,
|
||||
}
|
||||
if !supportedAlgos[strings.ToUpper(req.Algorithm)] {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Unsupported algorithm. Supported: SHA1, SHA256, SHA512")
|
||||
log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s",
|
||||
userID, req.Algorithm)
|
||||
return
|
||||
}
|
||||
|
||||
// Create OTP
|
||||
otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{
|
||||
Name: req.Name,
|
||||
Issuer: req.Issuer,
|
||||
Secret: req.Secret,
|
||||
Algorithm: req.Algorithm,
|
||||
Digits: req.Digits,
|
||||
Period: req.Period,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
|
||||
log.Printf("CreateOTP failed for user %s: %v", userID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful creation (mask secret in logs)
|
||||
log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d",
|
||||
userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period)
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(otp)
|
||||
}
|
||||
|
||||
// ListOTPs handles listing all OTPs for a user
|
||||
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTPs
|
||||
otps, err := h.otpService.ListOTPs(r.Context(), userID)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(otps)
|
||||
}
|
||||
|
||||
// GetOTPCode handles generating OTP code
|
||||
func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
otpID = strings.TrimSuffix(otpID, "/code")
|
||||
|
||||
// Validate OTP ID format
|
||||
if len(otpID) != 36 { // Assuming UUID format
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Invalid OTP ID format")
|
||||
log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID)
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limiting check could be added here
|
||||
// (would require redis or similar rate limiter)
|
||||
|
||||
// Generate code
|
||||
code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful generation (without actual code)
|
||||
log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)",
|
||||
userID, otpID, time.Since(start), expiresIn)
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{
|
||||
"code": code,
|
||||
"expires_in": expiresIn,
|
||||
})
|
||||
}
|
||||
|
||||
// VerifyOTPRequest represents a request to verify an OTP code
|
||||
type VerifyOTPRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// VerifyOTP handles OTP code verification
|
||||
func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
otpID = strings.TrimSuffix(otpID, "/verify")
|
||||
|
||||
// Parse request
|
||||
var req VerifyOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify code
|
||||
valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
|
||||
"valid": valid,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateOTPRequest represents a request to update an OTP
|
||||
type UpdateOTPRequest struct {
|
||||
Name string `json:"name"`
|
||||
Issuer string `json:"issuer"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Digits int `json:"digits"`
|
||||
Period int `json:"period"`
|
||||
}
|
||||
|
||||
// UpdateOTP handles OTP update
|
||||
func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
|
||||
// Parse request
|
||||
var req UpdateOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Update OTP
|
||||
otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{
|
||||
Name: req.Name,
|
||||
Issuer: req.Issuer,
|
||||
Algorithm: req.Algorithm,
|
||||
Digits: req.Digits,
|
||||
Period: req.Period,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(otp)
|
||||
}
|
||||
|
||||
// DeleteOTP handles OTP deletion
|
||||
func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
|
||||
// Delete OTP
|
||||
if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]string{
|
||||
"message": "OTP deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// Routes returns all routes for the OTP handler
|
||||
func (h *OTPHandler) Routes() map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/otp": h.CreateOTP,
|
||||
"/otp/": h.ListOTPs,
|
||||
"/otp/{id}": h.UpdateOTP,
|
||||
"/otp/{id}/code": h.GetOTPCode,
|
||||
"/otp/{id}/verify": h.VerifyOTP,
|
||||
}
|
||||
}
|
204
logger/logger.go
Normal file
204
logger/logger.go
Normal file
|
@ -0,0 +1,204 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Level represents a log level
|
||||
type Level int
|
||||
|
||||
const (
|
||||
// DEBUG level
|
||||
DEBUG Level = iota
|
||||
// INFO level
|
||||
INFO
|
||||
// WARN level
|
||||
WARN
|
||||
// ERROR level
|
||||
ERROR
|
||||
// FATAL level
|
||||
FATAL
|
||||
)
|
||||
|
||||
// String returns the string representation of the log level
|
||||
func (l Level) String() string {
|
||||
switch l {
|
||||
case DEBUG:
|
||||
return "DEBUG"
|
||||
case INFO:
|
||||
return "INFO"
|
||||
case WARN:
|
||||
return "WARN"
|
||||
case ERROR:
|
||||
return "ERROR"
|
||||
case FATAL:
|
||||
return "FATAL"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// Logger represents a logger
|
||||
type Logger struct {
|
||||
level Level
|
||||
output io.Writer
|
||||
}
|
||||
|
||||
// contextKey is a type for context keys
|
||||
type contextKey string
|
||||
|
||||
// requestIDKey is the key for request ID in context
|
||||
const requestIDKey = contextKey("request_id")
|
||||
|
||||
// New creates a new logger
|
||||
func New(level Level, output io.Writer) *Logger {
|
||||
if output == nil {
|
||||
output = os.Stdout
|
||||
}
|
||||
return &Logger{
|
||||
level: level,
|
||||
output: output,
|
||||
}
|
||||
}
|
||||
|
||||
// WithLevel creates a new logger with the specified level
|
||||
func (l *Logger) WithLevel(level Level) *Logger {
|
||||
return &Logger{
|
||||
level: level,
|
||||
output: l.output,
|
||||
}
|
||||
}
|
||||
|
||||
// WithOutput creates a new logger with the specified output
|
||||
func (l *Logger) WithOutput(output io.Writer) *Logger {
|
||||
return &Logger{
|
||||
level: l.level,
|
||||
output: output,
|
||||
}
|
||||
}
|
||||
|
||||
// log logs a message with the specified level
|
||||
func (l *Logger) log(ctx context.Context, level Level, format string, args ...interface{}) {
|
||||
if level < l.level {
|
||||
return
|
||||
}
|
||||
|
||||
// Get request ID from context
|
||||
requestID := getRequestID(ctx)
|
||||
|
||||
// Get caller information
|
||||
_, file, line, ok := runtime.Caller(2)
|
||||
if !ok {
|
||||
file = "unknown"
|
||||
line = 0
|
||||
}
|
||||
// Extract just the filename
|
||||
if idx := strings.LastIndex(file, "/"); idx >= 0 {
|
||||
file = file[idx+1:]
|
||||
}
|
||||
|
||||
// Format message
|
||||
message := fmt.Sprintf(format, args...)
|
||||
|
||||
// Format log entry
|
||||
timestamp := time.Now().Format(time.RFC3339)
|
||||
logEntry := fmt.Sprintf("%s [%s] %s:%d [%s] %s\n",
|
||||
timestamp, level.String(), file, line, requestID, message)
|
||||
|
||||
// Write log entry
|
||||
_, _ = l.output.Write([]byte(logEntry))
|
||||
|
||||
// Exit if fatal
|
||||
if level == FATAL {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) {
|
||||
l.log(ctx, DEBUG, format, args...)
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) {
|
||||
l.log(ctx, INFO, format, args...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message
|
||||
func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) {
|
||||
l.log(ctx, WARN, format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) {
|
||||
l.log(ctx, ERROR, format, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal message and exits
|
||||
func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) {
|
||||
l.log(ctx, FATAL, format, args...)
|
||||
}
|
||||
|
||||
// WithRequestID adds a request ID to the context
|
||||
func WithRequestID(ctx context.Context) context.Context {
|
||||
requestID := uuid.New().String()
|
||||
return context.WithValue(ctx, requestIDKey, requestID)
|
||||
}
|
||||
|
||||
// GetRequestID gets the request ID from the context
|
||||
func GetRequestID(ctx context.Context) string {
|
||||
return getRequestID(ctx)
|
||||
}
|
||||
|
||||
// getRequestID gets the request ID from the context
|
||||
func getRequestID(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return "-"
|
||||
}
|
||||
requestID, ok := ctx.Value(requestIDKey).(string)
|
||||
if !ok {
|
||||
return "-"
|
||||
}
|
||||
return requestID
|
||||
}
|
||||
|
||||
// Default logger
|
||||
var defaultLogger = New(INFO, os.Stdout)
|
||||
|
||||
// SetDefaultLogger sets the default logger
|
||||
func SetDefaultLogger(logger *Logger) {
|
||||
defaultLogger = logger
|
||||
}
|
||||
|
||||
// Debug logs a debug message using the default logger
|
||||
func Debug(ctx context.Context, format string, args ...interface{}) {
|
||||
defaultLogger.Debug(ctx, format, args...)
|
||||
}
|
||||
|
||||
// Info logs an info message using the default logger
|
||||
func Info(ctx context.Context, format string, args ...interface{}) {
|
||||
defaultLogger.Info(ctx, format, args...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message using the default logger
|
||||
func Warn(ctx context.Context, format string, args ...interface{}) {
|
||||
defaultLogger.Warn(ctx, format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message using the default logger
|
||||
func Error(ctx context.Context, format string, args ...interface{}) {
|
||||
defaultLogger.Error(ctx, format, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal message and exits using the default logger
|
||||
func Fatal(ctx context.Context, format string, args ...interface{}) {
|
||||
defaultLogger.Fatal(ctx, format, args...)
|
||||
}
|
193
metrics/metrics.go
Normal file
193
metrics/metrics.go
Normal file
|
@ -0,0 +1,193 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
var (
|
||||
// Default metrics
|
||||
requestDuration = prometheus.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "http_request_duration_seconds",
|
||||
Help: "Duration of HTTP requests in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
},
|
||||
[]string{"method", "path", "status"},
|
||||
)
|
||||
|
||||
requestTotal = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "http_requests_total",
|
||||
Help: "Total number of HTTP requests",
|
||||
},
|
||||
[]string{"method", "path", "status"},
|
||||
)
|
||||
|
||||
otpGenerationTotal = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "otp_generation_total",
|
||||
Help: "Total number of OTP generations",
|
||||
},
|
||||
[]string{"user_id", "otp_id"},
|
||||
)
|
||||
|
||||
otpVerificationTotal = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "otp_verification_total",
|
||||
Help: "Total number of OTP verifications",
|
||||
},
|
||||
[]string{"user_id", "otp_id", "success"},
|
||||
)
|
||||
|
||||
activeUsers = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "active_users",
|
||||
Help: "Number of active users",
|
||||
},
|
||||
)
|
||||
|
||||
cacheHits = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cache_hits_total",
|
||||
Help: "Total number of cache hits",
|
||||
},
|
||||
[]string{"cache"},
|
||||
)
|
||||
|
||||
cacheMisses = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cache_misses_total",
|
||||
Help: "Total number of cache misses",
|
||||
},
|
||||
[]string{"cache"},
|
||||
)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Register metrics with prometheus
|
||||
prometheus.MustRegister(
|
||||
requestDuration,
|
||||
requestTotal,
|
||||
otpGenerationTotal,
|
||||
otpVerificationTotal,
|
||||
activeUsers,
|
||||
cacheHits,
|
||||
cacheMisses,
|
||||
)
|
||||
}
|
||||
|
||||
// MetricsService provides metrics functionality
|
||||
type MetricsService struct {
|
||||
activeUsersMutex sync.RWMutex
|
||||
activeUserIDs map[string]bool
|
||||
}
|
||||
|
||||
// NewMetricsService creates a new MetricsService
|
||||
func NewMetricsService() *MetricsService {
|
||||
return &MetricsService{
|
||||
activeUserIDs: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns an HTTP handler for metrics
|
||||
func (s *MetricsService) Handler() http.Handler {
|
||||
return promhttp.Handler()
|
||||
}
|
||||
|
||||
// RecordRequest records metrics for an HTTP request
|
||||
func (s *MetricsService) RecordRequest(method, path string, status int, duration time.Duration) {
|
||||
labels := prometheus.Labels{
|
||||
"method": method,
|
||||
"path": path,
|
||||
"status": fmt.Sprintf("%d", status),
|
||||
}
|
||||
|
||||
requestDuration.With(labels).Observe(duration.Seconds())
|
||||
requestTotal.With(labels).Inc()
|
||||
}
|
||||
|
||||
// RecordOTPGeneration records metrics for OTP generation
|
||||
func (s *MetricsService) RecordOTPGeneration(userID, otpID string) {
|
||||
otpGenerationTotal.With(prometheus.Labels{
|
||||
"user_id": userID,
|
||||
"otp_id": otpID,
|
||||
}).Inc()
|
||||
}
|
||||
|
||||
// RecordOTPVerification records metrics for OTP verification
|
||||
func (s *MetricsService) RecordOTPVerification(userID, otpID string, success bool) {
|
||||
otpVerificationTotal.With(prometheus.Labels{
|
||||
"user_id": userID,
|
||||
"otp_id": otpID,
|
||||
"success": fmt.Sprintf("%t", success),
|
||||
}).Inc()
|
||||
}
|
||||
|
||||
// RecordUserActivity records user activity
|
||||
func (s *MetricsService) RecordUserActivity(userID string) {
|
||||
s.activeUsersMutex.Lock()
|
||||
defer s.activeUsersMutex.Unlock()
|
||||
|
||||
if !s.activeUserIDs[userID] {
|
||||
s.activeUserIDs[userID] = true
|
||||
activeUsers.Inc()
|
||||
}
|
||||
}
|
||||
|
||||
// RecordUserInactivity records user inactivity
|
||||
func (s *MetricsService) RecordUserInactivity(userID string) {
|
||||
s.activeUsersMutex.Lock()
|
||||
defer s.activeUsersMutex.Unlock()
|
||||
|
||||
if s.activeUserIDs[userID] {
|
||||
delete(s.activeUserIDs, userID)
|
||||
activeUsers.Dec()
|
||||
}
|
||||
}
|
||||
|
||||
// RecordCacheHit records a cache hit
|
||||
func (s *MetricsService) RecordCacheHit(cache string) {
|
||||
cacheHits.With(prometheus.Labels{
|
||||
"cache": cache,
|
||||
}).Inc()
|
||||
}
|
||||
|
||||
// RecordCacheMiss records a cache miss
|
||||
func (s *MetricsService) RecordCacheMiss(cache string) {
|
||||
cacheMisses.With(prometheus.Labels{
|
||||
"cache": cache,
|
||||
}).Inc()
|
||||
}
|
||||
|
||||
// Middleware creates a middleware that records request metrics
|
||||
func (s *MetricsService) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Create response writer that captures status code
|
||||
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
|
||||
|
||||
// Call next handler
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
// Record metrics
|
||||
s.RecordRequest(r.Method, r.URL.Path, rw.status, time.Since(start))
|
||||
})
|
||||
}
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.status = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
353
middleware/middleware.go
Normal file
353
middleware/middleware.go
Normal file
|
@ -0,0 +1,353 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// Response represents a standard API response
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponse sends a JSON error response
|
||||
func ErrorResponse(w http.ResponseWriter, code int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// SuccessResponse sends a JSON success response
|
||||
func SuccessResponse(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Code: http.StatusOK,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Logger is a middleware that logs request details with structured format
|
||||
func Logger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
r.Header.Set("X-Request-ID", requestID)
|
||||
}
|
||||
|
||||
// Create a custom response writer to capture status code
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
// Process request
|
||||
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), "request_id", requestID)))
|
||||
|
||||
// Log structured request details
|
||||
log.Printf(
|
||||
"method=%s path=%s status=%d duration=%s ip=%s request_id=%s",
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
rw.status,
|
||||
time.Since(start).String(),
|
||||
r.RemoteAddr,
|
||||
requestID,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// generateRequestID creates a unique request identifier
|
||||
func generateRequestID() string {
|
||||
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
|
||||
}
|
||||
|
||||
// randomString generates a random string of given length
|
||||
func randomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[rand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// Recover is a middleware that recovers from panics with detailed logging
|
||||
func Recover(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
// Get request ID from context
|
||||
requestID := ""
|
||||
if ctx := r.Context(); ctx != nil {
|
||||
if id, ok := ctx.Value("request_id").(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Log error with stack trace and request context
|
||||
log.Printf(
|
||||
"panic: %v\nrequest_id=%s\nmethod=%s\npath=%s\nremote_addr=%s\nstack:\n%s",
|
||||
err,
|
||||
requestID,
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
r.RemoteAddr,
|
||||
debug.Stack(),
|
||||
)
|
||||
|
||||
// Determine error type
|
||||
var message string
|
||||
var status int
|
||||
|
||||
switch e := err.(type) {
|
||||
case error:
|
||||
message = e.Error()
|
||||
if isClientError(e) {
|
||||
status = http.StatusBadRequest
|
||||
} else {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
case string:
|
||||
message = e
|
||||
status = http.StatusInternalServerError
|
||||
default:
|
||||
message = "Internal Server Error"
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
ErrorResponse(w, status, message)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// isClientError checks if error should be treated as client error
|
||||
func isClientError(err error) bool {
|
||||
// Add more client error types as needed
|
||||
return strings.Contains(err.Error(), "validation") ||
|
||||
strings.Contains(err.Error(), "invalid") ||
|
||||
strings.Contains(err.Error(), "missing")
|
||||
}
|
||||
|
||||
// CORS is a middleware that handles CORS
|
||||
func CORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Timeout is a middleware that safely handles request timeouts
|
||||
func Timeout(duration time.Duration) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), duration)
|
||||
defer cancel()
|
||||
|
||||
// Use buffered channels to prevent goroutine leaks
|
||||
done := make(chan struct{}, 1)
|
||||
panicChan := make(chan interface{}, 1)
|
||||
|
||||
// Track request processing in goroutine
|
||||
go func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicChan <- p
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
// Wait for completion, timeout or panic
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case p := <-panicChan:
|
||||
panic(p) // Re-throw panic to be caught by Recover middleware
|
||||
case <-ctx.Done():
|
||||
// Get request context for logging
|
||||
requestID := ""
|
||||
if ctx := r.Context(); ctx != nil {
|
||||
if id, ok := ctx.Value("request_id").(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Log timeout details
|
||||
log.Printf(
|
||||
"request_timeout: request_id=%s method=%s path=%s timeout=%s",
|
||||
requestID,
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
duration.String(),
|
||||
)
|
||||
|
||||
// Send timeout response
|
||||
ErrorResponse(w, http.StatusGatewayTimeout, fmt.Sprintf(
|
||||
"Request timed out after %s", duration.String(),
|
||||
))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Auth is a middleware that validates JWT tokens with enhanced security
|
||||
func Auth(jwtSecret string, requiredRoles ...string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get request ID for logging
|
||||
requestID := ""
|
||||
if ctx := r.Context(); ctx != nil {
|
||||
if id, ok := ctx.Value("request_id").(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Get token from Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
log.Printf("auth_failed: request_id=%s error=missing_authorization_header", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Authorization header is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate header format
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
log.Printf("auth_failed: request_id=%s error=invalid_header_format", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Authorization header format must be 'Bearer <token>'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// Parse and validate token
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("auth_failed: request_id=%s error=token_parse_failed reason=%v", requestID, err)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
log.Printf("auth_failed: request_id=%s error=invalid_token", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate claims
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
log.Printf("auth_failed: request_id=%s error=invalid_claims", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Invalid token claims")
|
||||
return
|
||||
}
|
||||
|
||||
// Check required claims
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
log.Printf("auth_failed: request_id=%s error=missing_user_id", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
// Check token expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
log.Printf("auth_failed: request_id=%s error=token_expired", requestID)
|
||||
ErrorResponse(w, http.StatusUnauthorized, "Token has expired")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check required roles if specified
|
||||
if len(requiredRoles) > 0 {
|
||||
roles, ok := claims["roles"].([]interface{})
|
||||
if !ok {
|
||||
log.Printf("auth_failed: request_id=%s error=missing_roles", requestID)
|
||||
ErrorResponse(w, http.StatusForbidden, "Access denied: missing roles")
|
||||
return
|
||||
}
|
||||
|
||||
hasRequiredRole := false
|
||||
for _, requiredRole := range requiredRoles {
|
||||
for _, role := range roles {
|
||||
if r, ok := role.(string); ok && r == requiredRole {
|
||||
hasRequiredRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRequiredRole {
|
||||
log.Printf("auth_failed: request_id=%s error=insufficient_permissions", requestID)
|
||||
ErrorResponse(w, http.StatusForbidden, "Access denied: insufficient permissions")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Add claims to context
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, "user_id", userID)
|
||||
ctx = context.WithValue(ctx, "claims", claims)
|
||||
|
||||
log.Printf("auth_success: request_id=%s user_id=%s", requestID, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// responseWriter is a custom response writer that captures the status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.status = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// GetUserID gets the user ID from the request context
|
||||
func GetUserID(r *http.Request) (string, error) {
|
||||
userID, ok := r.Context().Value("user_id").(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("user ID not found in context")
|
||||
}
|
||||
return userID, nil
|
||||
}
|
50
miniprogram-example/app.js
Normal file
50
miniprogram-example/app.js
Normal file
|
@ -0,0 +1,50 @@
|
|||
// app.js
|
||||
App({
|
||||
onLaunch() {
|
||||
// 检查更新
|
||||
if (wx.canIUse('getUpdateManager')) {
|
||||
const updateManager = wx.getUpdateManager();
|
||||
updateManager.onCheckForUpdate(function (res) {
|
||||
if (res.hasUpdate) {
|
||||
updateManager.onUpdateReady(function () {
|
||||
wx.showModal({
|
||||
title: '更新提示',
|
||||
content: '新版本已经准备好,是否重启应用?',
|
||||
success: function (res) {
|
||||
if (res.confirm) {
|
||||
updateManager.applyUpdate();
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
updateManager.onUpdateFailed(function () {
|
||||
wx.showModal({
|
||||
title: '更新提示',
|
||||
content: '新版本下载失败,请检查网络后重试',
|
||||
showCancel: false
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 获取系统信息
|
||||
try {
|
||||
const systemInfo = wx.getSystemInfoSync();
|
||||
this.globalData.systemInfo = systemInfo;
|
||||
|
||||
// 计算安全区域
|
||||
const { screenHeight, safeArea } = systemInfo;
|
||||
this.globalData.safeAreaBottom = screenHeight - safeArea.bottom;
|
||||
} catch (e) {
|
||||
console.error('获取系统信息失败', e);
|
||||
}
|
||||
},
|
||||
|
||||
globalData: {
|
||||
userInfo: null,
|
||||
systemInfo: {},
|
||||
safeAreaBottom: 0
|
||||
}
|
||||
});
|
23
miniprogram-example/app.json
Normal file
23
miniprogram-example/app.json
Normal file
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"pages": [
|
||||
"pages/login/login",
|
||||
"pages/otp-list/index",
|
||||
"pages/otp-add/index"
|
||||
],
|
||||
"window": {
|
||||
"backgroundTextStyle": "light",
|
||||
"navigationBarBackgroundColor": "#fff",
|
||||
"navigationBarTitleText": "OTPM",
|
||||
"navigationBarTextStyle": "black",
|
||||
"backgroundColor": "#F8F8F8"
|
||||
},
|
||||
"permission": {
|
||||
"scope.camera": {
|
||||
"desc": "需要使用相机扫描二维码"
|
||||
}
|
||||
},
|
||||
"usingComponents": {},
|
||||
"style": "v2",
|
||||
"sitemapLocation": "sitemap.json",
|
||||
"lazyCodeLoading": "requiredComponents"
|
||||
}
|
238
miniprogram-example/app.wxss
Normal file
238
miniprogram-example/app.wxss
Normal file
|
@ -0,0 +1,238 @@
|
|||
/**app.wxss**/
|
||||
page {
|
||||
--primary-color: #1890ff;
|
||||
--danger-color: #ff4d4f;
|
||||
--success-color: #52c41a;
|
||||
--warning-color: #faad14;
|
||||
--text-color: #333333;
|
||||
--text-color-secondary: #666666;
|
||||
--text-color-light: #999999;
|
||||
--border-color: #e8e8e8;
|
||||
--background-color: #f8f8f8;
|
||||
--border-radius: 8rpx;
|
||||
--safe-area-bottom: env(safe-area-inset-bottom);
|
||||
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Helvetica Neue', Helvetica,
|
||||
Segoe UI, Arial, Roboto, 'PingFang SC', 'miui', 'Hiragino Sans GB', 'Microsoft Yahei',
|
||||
sans-serif;
|
||||
font-size: 28rpx;
|
||||
line-height: 1.5;
|
||||
color: var(--text-color);
|
||||
background-color: var(--background-color);
|
||||
}
|
||||
|
||||
/* 清除默认样式 */
|
||||
button {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
background: none;
|
||||
border: none;
|
||||
text-align: left;
|
||||
line-height: inherit;
|
||||
overflow: visible;
|
||||
}
|
||||
|
||||
button::after {
|
||||
border: none;
|
||||
}
|
||||
|
||||
/* 通用样式类 */
|
||||
.container {
|
||||
min-height: 100vh;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.safe-area-bottom {
|
||||
padding-bottom: var(--safe-area-bottom);
|
||||
}
|
||||
|
||||
.flex-center {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.flex-between {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.flex-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.text-primary {
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.text-danger {
|
||||
color: var(--danger-color);
|
||||
}
|
||||
|
||||
.text-success {
|
||||
color: var(--success-color);
|
||||
}
|
||||
|
||||
.text-warning {
|
||||
color: var(--warning-color);
|
||||
}
|
||||
|
||||
.text-secondary {
|
||||
color: var(--text-color-secondary);
|
||||
}
|
||||
|
||||
.text-light {
|
||||
color: var(--text-color-light);
|
||||
}
|
||||
|
||||
.text-center {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.text-left {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.text-right {
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
.text-bold {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.text-ellipsis {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.bg-white {
|
||||
background-color: #ffffff;
|
||||
}
|
||||
|
||||
.bg-primary {
|
||||
background-color: var(--primary-color);
|
||||
}
|
||||
|
||||
.bg-danger {
|
||||
background-color: var(--danger-color);
|
||||
}
|
||||
|
||||
.bg-success {
|
||||
background-color: var(--success-color);
|
||||
}
|
||||
|
||||
.bg-warning {
|
||||
background-color: var(--warning-color);
|
||||
}
|
||||
|
||||
.shadow {
|
||||
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.rounded {
|
||||
border-radius: var(--border-radius);
|
||||
}
|
||||
|
||||
.border {
|
||||
border: 2rpx solid var(--border-color);
|
||||
}
|
||||
|
||||
.border-top {
|
||||
border-top: 2rpx solid var(--border-color);
|
||||
}
|
||||
|
||||
.border-bottom {
|
||||
border-bottom: 2rpx solid var(--border-color);
|
||||
}
|
||||
|
||||
/* 动画类 */
|
||||
.fade-in {
|
||||
animation: fadeIn 0.3s ease-in-out;
|
||||
}
|
||||
|
||||
.fade-out {
|
||||
animation: fadeOut 0.3s ease-in-out;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes fadeOut {
|
||||
from {
|
||||
opacity: 1;
|
||||
}
|
||||
to {
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
||||
|
||||
/* 间距类 */
|
||||
.m-0 { margin: 0; }
|
||||
.m-1 { margin: 10rpx; }
|
||||
.m-2 { margin: 20rpx; }
|
||||
.m-3 { margin: 30rpx; }
|
||||
.m-4 { margin: 40rpx; }
|
||||
|
||||
.mt-0 { margin-top: 0; }
|
||||
.mt-1 { margin-top: 10rpx; }
|
||||
.mt-2 { margin-top: 20rpx; }
|
||||
.mt-3 { margin-top: 30rpx; }
|
||||
.mt-4 { margin-top: 40rpx; }
|
||||
|
||||
.mb-0 { margin-bottom: 0; }
|
||||
.mb-1 { margin-bottom: 10rpx; }
|
||||
.mb-2 { margin-bottom: 20rpx; }
|
||||
.mb-3 { margin-bottom: 30rpx; }
|
||||
.mb-4 { margin-bottom: 40rpx; }
|
||||
|
||||
.ml-0 { margin-left: 0; }
|
||||
.ml-1 { margin-left: 10rpx; }
|
||||
.ml-2 { margin-left: 20rpx; }
|
||||
.ml-3 { margin-left: 30rpx; }
|
||||
.ml-4 { margin-left: 40rpx; }
|
||||
|
||||
.mr-0 { margin-right: 0; }
|
||||
.mr-1 { margin-right: 10rpx; }
|
||||
.mr-2 { margin-right: 20rpx; }
|
||||
.mr-3 { margin-right: 30rpx; }
|
||||
.mr-4 { margin-right: 40rpx; }
|
||||
|
||||
.p-0 { padding: 0; }
|
||||
.p-1 { padding: 10rpx; }
|
||||
.p-2 { padding: 20rpx; }
|
||||
.p-3 { padding: 30rpx; }
|
||||
.p-4 { padding: 40rpx; }
|
||||
|
||||
.pt-0 { padding-top: 0; }
|
||||
.pt-1 { padding-top: 10rpx; }
|
||||
.pt-2 { padding-top: 20rpx; }
|
||||
.pt-3 { padding-top: 30rpx; }
|
||||
.pt-4 { padding-top: 40rpx; }
|
||||
|
||||
.pb-0 { padding-bottom: 0; }
|
||||
.pb-1 { padding-bottom: 10rpx; }
|
||||
.pb-2 { padding-bottom: 20rpx; }
|
||||
.pb-3 { padding-bottom: 30rpx; }
|
||||
.pb-4 { padding-bottom: 40rpx; }
|
||||
|
||||
.pl-0 { padding-left: 0; }
|
||||
.pl-1 { padding-left: 10rpx; }
|
||||
.pl-2 { padding-left: 20rpx; }
|
||||
.pl-3 { padding-left: 30rpx; }
|
||||
.pl-4 { padding-left: 40rpx; }
|
||||
|
||||
.pr-0 { padding-right: 0; }
|
||||
.pr-1 { padding-right: 10rpx; }
|
||||
.pr-2 { padding-right: 20rpx; }
|
||||
.pr-3 { padding-right: 30rpx; }
|
||||
.pr-4 { padding-right: 40rpx; }
|
48
miniprogram-example/pages/login/login.js
Normal file
48
miniprogram-example/pages/login/login.js
Normal file
|
@ -0,0 +1,48 @@
|
|||
// login.js
|
||||
import { wxLogin } from '../../services/auth';
|
||||
|
||||
Page({
|
||||
data: {
|
||||
loading: false
|
||||
},
|
||||
|
||||
onLoad() {
|
||||
// 页面加载时检查是否已经登录
|
||||
const token = wx.getStorageSync('token');
|
||||
if (token) {
|
||||
this.redirectToHome();
|
||||
}
|
||||
},
|
||||
|
||||
// 处理登录按钮点击
|
||||
handleLogin() {
|
||||
if (this.data.loading) return;
|
||||
|
||||
this.setData({ loading: true });
|
||||
|
||||
wxLogin()
|
||||
.then(() => {
|
||||
wx.showToast({
|
||||
title: '登录成功',
|
||||
icon: 'success'
|
||||
});
|
||||
this.redirectToHome();
|
||||
})
|
||||
.catch(err => {
|
||||
wx.showToast({
|
||||
title: err.message || '登录失败',
|
||||
icon: 'none'
|
||||
});
|
||||
})
|
||||
.finally(() => {
|
||||
this.setData({ loading: false });
|
||||
});
|
||||
},
|
||||
|
||||
// 跳转到首页
|
||||
redirectToHome() {
|
||||
wx.reLaunch({
|
||||
url: '/pages/otp-list/index'
|
||||
});
|
||||
}
|
||||
});
|
3
miniprogram-example/pages/login/login.json
Normal file
3
miniprogram-example/pages/login/login.json
Normal file
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"usingComponents": {}
|
||||
}
|
30
miniprogram-example/pages/login/login.wxml
Normal file
30
miniprogram-example/pages/login/login.wxml
Normal file
|
@ -0,0 +1,30 @@
|
|||
<!-- login.wxml -->
|
||||
<view class="container">
|
||||
<view class="logo-container">
|
||||
<image class="logo" src="/assets/images/logo.png" mode="aspectFit"></image>
|
||||
<text class="app-name">OTPM 小程序</text>
|
||||
</view>
|
||||
|
||||
<view class="login-container">
|
||||
<text class="login-title">欢迎使用 OTPM</text>
|
||||
<text class="login-subtitle">一次性密码管理工具</text>
|
||||
|
||||
<button
|
||||
class="login-button {{loading ? 'loading' : ''}}"
|
||||
type="primary"
|
||||
bindtap="handleLogin"
|
||||
disabled="{{loading}}"
|
||||
>
|
||||
<text wx:if="{{!loading}}">微信一键登录</text>
|
||||
<view wx:else class="loading-container">
|
||||
<view class="loading-icon"></view>
|
||||
<text>登录中...</text>
|
||||
</view>
|
||||
</button>
|
||||
|
||||
<view class="privacy-policy">
|
||||
<text>登录即表示您同意</text>
|
||||
<navigator url="/pages/privacy/index" class="policy-link">《隐私政策》</navigator>
|
||||
</view>
|
||||
</view>
|
||||
</view>
|
97
miniprogram-example/pages/login/login.wxss
Normal file
97
miniprogram-example/pages/login/login.wxss
Normal file
|
@ -0,0 +1,97 @@
|
|||
/* login.wxss */
|
||||
.container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
height: 100vh;
|
||||
padding: 60rpx 40rpx;
|
||||
box-sizing: border-box;
|
||||
background-color: #f8f8f8;
|
||||
}
|
||||
|
||||
.logo-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
margin-top: 80rpx;
|
||||
}
|
||||
|
||||
.logo {
|
||||
width: 180rpx;
|
||||
height: 180rpx;
|
||||
margin-bottom: 20rpx;
|
||||
}
|
||||
|
||||
.app-name {
|
||||
font-size: 36rpx;
|
||||
font-weight: bold;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.login-container {
|
||||
width: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
margin-bottom: 100rpx;
|
||||
}
|
||||
|
||||
.login-title {
|
||||
font-size: 48rpx;
|
||||
font-weight: bold;
|
||||
color: #333;
|
||||
margin-bottom: 20rpx;
|
||||
}
|
||||
|
||||
.login-subtitle {
|
||||
font-size: 28rpx;
|
||||
color: #666;
|
||||
margin-bottom: 80rpx;
|
||||
}
|
||||
|
||||
.login-button {
|
||||
width: 80%;
|
||||
height: 88rpx;
|
||||
border-radius: 44rpx;
|
||||
font-size: 32rpx;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.login-button.loading {
|
||||
background-color: #8cc4ff;
|
||||
}
|
||||
|
||||
.loading-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.loading-icon {
|
||||
width: 36rpx;
|
||||
height: 36rpx;
|
||||
margin-right: 10rpx;
|
||||
border: 4rpx solid #ffffff;
|
||||
border-radius: 50%;
|
||||
border-top-color: transparent;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.privacy-policy {
|
||||
margin-top: 40rpx;
|
||||
font-size: 24rpx;
|
||||
color: #999;
|
||||
}
|
||||
|
||||
.policy-link {
|
||||
color: #1890ff;
|
||||
display: inline;
|
||||
}
|
169
miniprogram-example/pages/otp-add/index.js
Normal file
169
miniprogram-example/pages/otp-add/index.js
Normal file
|
@ -0,0 +1,169 @@
|
|||
// otp-add/index.js
|
||||
import { createOTP } from '../../services/otp';
|
||||
|
||||
Page({
|
||||
data: {
|
||||
form: {
|
||||
name: '',
|
||||
issuer: '',
|
||||
secret: '',
|
||||
algorithm: 'SHA1',
|
||||
digits: 6,
|
||||
period: 30
|
||||
},
|
||||
algorithms: ['SHA1', 'SHA256', 'SHA512'],
|
||||
digitOptions: [6, 8],
|
||||
periodOptions: [30, 60],
|
||||
submitting: false,
|
||||
scanMode: false
|
||||
},
|
||||
|
||||
// 处理输入变化
|
||||
handleInputChange(e) {
|
||||
const { field } = e.currentTarget.dataset;
|
||||
const { value } = e.detail;
|
||||
|
||||
this.setData({
|
||||
[`form.${field}`]: value
|
||||
});
|
||||
},
|
||||
|
||||
// 处理选择器变化
|
||||
handlePickerChange(e) {
|
||||
const { field } = e.currentTarget.dataset;
|
||||
const { value } = e.detail;
|
||||
|
||||
const options = this.data[`${field}Options`] || this.data[field];
|
||||
const selectedValue = options[value];
|
||||
|
||||
this.setData({
|
||||
[`form.${field}`]: selectedValue
|
||||
});
|
||||
},
|
||||
|
||||
// 扫描二维码
|
||||
handleScanQRCode() {
|
||||
this.setData({ scanMode: true });
|
||||
|
||||
wx.scanCode({
|
||||
scanType: ['qrCode'],
|
||||
success: (res) => {
|
||||
try {
|
||||
// 解析otpauth://协议的URL
|
||||
const url = res.result;
|
||||
if (url.startsWith('otpauth://totp/')) {
|
||||
const parsedUrl = new URL(url);
|
||||
const path = parsedUrl.pathname.substring(1); // 移除开头的斜杠
|
||||
|
||||
// 解析路径中的issuer和name
|
||||
let issuer = '';
|
||||
let name = path;
|
||||
|
||||
if (path.includes(':')) {
|
||||
const parts = path.split(':');
|
||||
issuer = parts[0];
|
||||
name = parts[1];
|
||||
}
|
||||
|
||||
// 从查询参数中获取其他信息
|
||||
const secret = parsedUrl.searchParams.get('secret') || '';
|
||||
const algorithm = parsedUrl.searchParams.get('algorithm') || 'SHA1';
|
||||
const digits = parseInt(parsedUrl.searchParams.get('digits') || '6');
|
||||
const period = parseInt(parsedUrl.searchParams.get('period') || '30');
|
||||
|
||||
// 如果查询参数中有issuer,优先使用
|
||||
if (parsedUrl.searchParams.get('issuer')) {
|
||||
issuer = parsedUrl.searchParams.get('issuer');
|
||||
}
|
||||
|
||||
this.setData({
|
||||
form: {
|
||||
name,
|
||||
issuer,
|
||||
secret,
|
||||
algorithm,
|
||||
digits,
|
||||
period
|
||||
}
|
||||
});
|
||||
|
||||
wx.showToast({
|
||||
title: '二维码解析成功',
|
||||
icon: 'success'
|
||||
});
|
||||
} else {
|
||||
wx.showToast({
|
||||
title: '不支持的二维码格式',
|
||||
icon: 'none'
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
wx.showToast({
|
||||
title: '二维码解析失败',
|
||||
icon: 'none'
|
||||
});
|
||||
}
|
||||
},
|
||||
fail: () => {
|
||||
wx.showToast({
|
||||
title: '扫描取消',
|
||||
icon: 'none'
|
||||
});
|
||||
},
|
||||
complete: () => {
|
||||
this.setData({ scanMode: false });
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
// 提交表单
|
||||
handleSubmit() {
|
||||
const { form } = this.data;
|
||||
|
||||
// 表单验证
|
||||
if (!form.name) {
|
||||
wx.showToast({
|
||||
title: '请输入名称',
|
||||
icon: 'none'
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (!form.secret) {
|
||||
wx.showToast({
|
||||
title: '请输入密钥',
|
||||
icon: 'none'
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
this.setData({ submitting: true });
|
||||
|
||||
createOTP(form)
|
||||
.then(() => {
|
||||
wx.showToast({
|
||||
title: '添加成功',
|
||||
icon: 'success'
|
||||
});
|
||||
|
||||
// 返回上一页
|
||||
setTimeout(() => {
|
||||
wx.navigateBack();
|
||||
}, 1500);
|
||||
})
|
||||
.catch(err => {
|
||||
wx.showToast({
|
||||
title: err.message || '添加失败',
|
||||
icon: 'none'
|
||||
});
|
||||
})
|
||||
.finally(() => {
|
||||
this.setData({ submitting: false });
|
||||
});
|
||||
},
|
||||
|
||||
// 取消
|
||||
handleCancel() {
|
||||
wx.navigateBack();
|
||||
}
|
||||
});
|
3
miniprogram-example/pages/otp-add/index.json
Normal file
3
miniprogram-example/pages/otp-add/index.json
Normal file
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"usingComponents": {}
|
||||
}
|
119
miniprogram-example/pages/otp-add/index.wxml
Normal file
119
miniprogram-example/pages/otp-add/index.wxml
Normal file
|
@ -0,0 +1,119 @@
|
|||
<!-- otp-add/index.wxml -->
|
||||
<view class="container">
|
||||
<view class="header">
|
||||
<text class="title">添加OTP</text>
|
||||
</view>
|
||||
|
||||
<view class="form-container">
|
||||
<view class="form-group">
|
||||
<text class="form-label">名称 <text class="required">*</text></text>
|
||||
<input
|
||||
class="form-input"
|
||||
placeholder="请输入OTP名称"
|
||||
value="{{form.name}}"
|
||||
bindinput="handleInputChange"
|
||||
data-field="name"
|
||||
/>
|
||||
</view>
|
||||
|
||||
<view class="form-group">
|
||||
<text class="form-label">发行方</text>
|
||||
<input
|
||||
class="form-input"
|
||||
placeholder="请输入发行方名称"
|
||||
value="{{form.issuer}}"
|
||||
bindinput="handleInputChange"
|
||||
data-field="issuer"
|
||||
/>
|
||||
</view>
|
||||
|
||||
<view class="form-group">
|
||||
<text class="form-label">密钥 <text class="required">*</text></text>
|
||||
<view class="secret-input-container">
|
||||
<input
|
||||
class="form-input"
|
||||
placeholder="请输入密钥或扫描二维码"
|
||||
value="{{form.secret}}"
|
||||
bindinput="handleInputChange"
|
||||
data-field="secret"
|
||||
/>
|
||||
<view class="scan-button" bindtap="handleScanQRCode" wx:if="{{!scanMode}}">
|
||||
<text class="scan-icon">🔍</text>
|
||||
</view>
|
||||
<view class="scanning-indicator" wx:else>
|
||||
<view class="scanning-spinner"></view>
|
||||
</view>
|
||||
</view>
|
||||
</view>
|
||||
|
||||
<view class="form-group">
|
||||
<text class="form-label">算法</text>
|
||||
<picker
|
||||
mode="selector"
|
||||
range="{{algorithms}}"
|
||||
value="{{algorithms.indexOf(form.algorithm)}}"
|
||||
bindchange="handlePickerChange"
|
||||
data-field="algorithm"
|
||||
>
|
||||
<view class="picker-view">
|
||||
<text>{{form.algorithm}}</text>
|
||||
<text class="picker-arrow">▼</text>
|
||||
</view>
|
||||
</picker>
|
||||
</view>
|
||||
|
||||
<view class="form-row">
|
||||
<view class="form-group half">
|
||||
<text class="form-label">位数</text>
|
||||
<picker
|
||||
mode="selector"
|
||||
range="{{digitOptions}}"
|
||||
value="{{digitOptions.indexOf(form.digits)}}"
|
||||
bindchange="handlePickerChange"
|
||||
data-field="digits"
|
||||
>
|
||||
<view class="picker-view">
|
||||
<text>{{form.digits}}</text>
|
||||
<text class="picker-arrow">▼</text>
|
||||
</view>
|
||||
</picker>
|
||||
</view>
|
||||
|
||||
<view class="form-group half">
|
||||
<text class="form-label">周期(秒)</text>
|
||||
<picker
|
||||
mode="selector"
|
||||
range="{{periodOptions}}"
|
||||
value="{{periodOptions.indexOf(form.period)}}"
|
||||
bindchange="handlePickerChange"
|
||||
data-field="period"
|
||||
>
|
||||
<view class="picker-view">
|
||||
<text>{{form.period}}</text>
|
||||
<text class="picker-arrow">▼</text>
|
||||
</view>
|
||||
</picker>
|
||||
</view>
|
||||
</view>
|
||||
</view>
|
||||
|
||||
<view class="button-group">
|
||||
<button
|
||||
class="cancel-button"
|
||||
bindtap="handleCancel"
|
||||
disabled="{{submitting}}"
|
||||
>取消</button>
|
||||
|
||||
<button
|
||||
class="submit-button {{submitting ? 'loading' : ''}}"
|
||||
bindtap="handleSubmit"
|
||||
disabled="{{submitting}}"
|
||||
>
|
||||
<text wx:if="{{!submitting}}">保存</text>
|
||||
<view wx:else class="loading-container">
|
||||
<view class="loading-icon"></view>
|
||||
<text>保存中...</text>
|
||||
</view>
|
||||
</button>
|
||||
</view>
|
||||
</view>
|
176
miniprogram-example/pages/otp-add/index.wxss
Normal file
176
miniprogram-example/pages/otp-add/index.wxss
Normal file
|
@ -0,0 +1,176 @@
|
|||
/* otp-add/index.wxss */
|
||||
.container {
|
||||
min-height: 100vh;
|
||||
background-color: #f8f8f8;
|
||||
padding: 0 0 40rpx 0;
|
||||
}
|
||||
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 40rpx 32rpx;
|
||||
background-color: #ffffff;
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 100;
|
||||
box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.title {
|
||||
font-size: 36rpx;
|
||||
font-weight: bold;
|
||||
color: #333333;
|
||||
}
|
||||
|
||||
.form-container {
|
||||
background-color: #ffffff;
|
||||
padding: 32rpx;
|
||||
margin: 32rpx;
|
||||
border-radius: 16rpx;
|
||||
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.form-group {
|
||||
margin-bottom: 32rpx;
|
||||
}
|
||||
|
||||
.form-row {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.form-group.half {
|
||||
width: 48%;
|
||||
}
|
||||
|
||||
.form-label {
|
||||
display: block;
|
||||
font-size: 28rpx;
|
||||
color: #666666;
|
||||
margin-bottom: 12rpx;
|
||||
}
|
||||
|
||||
.required {
|
||||
color: #ff4d4f;
|
||||
}
|
||||
|
||||
.form-input {
|
||||
width: 100%;
|
||||
height: 80rpx;
|
||||
background-color: #f5f5f5;
|
||||
border-radius: 8rpx;
|
||||
padding: 0 24rpx;
|
||||
font-size: 28rpx;
|
||||
color: #333333;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.secret-input-container {
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.scan-button {
|
||||
position: absolute;
|
||||
right: 20rpx;
|
||||
top: 50%;
|
||||
transform: translateY(-50%);
|
||||
width: 60rpx;
|
||||
height: 60rpx;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.scan-icon {
|
||||
font-size: 40rpx;
|
||||
color: #1890ff;
|
||||
}
|
||||
|
||||
.scanning-indicator {
|
||||
position: absolute;
|
||||
right: 20rpx;
|
||||
top: 50%;
|
||||
transform: translateY(-50%);
|
||||
width: 40rpx;
|
||||
height: 40rpx;
|
||||
}
|
||||
|
||||
.scanning-spinner {
|
||||
width: 40rpx;
|
||||
height: 40rpx;
|
||||
border: 4rpx solid #f3f3f3;
|
||||
border-top: 4rpx solid #1890ff;
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.picker-view {
|
||||
width: 100%;
|
||||
height: 80rpx;
|
||||
background-color: #f5f5f5;
|
||||
border-radius: 8rpx;
|
||||
padding: 0 24rpx;
|
||||
font-size: 28rpx;
|
||||
color: #333333;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.picker-arrow {
|
||||
font-size: 24rpx;
|
||||
color: #999999;
|
||||
}
|
||||
|
||||
.button-group {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
padding: 32rpx;
|
||||
}
|
||||
|
||||
.cancel-button, .submit-button {
|
||||
width: 48%;
|
||||
height: 88rpx;
|
||||
border-radius: 44rpx;
|
||||
font-size: 32rpx;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.cancel-button {
|
||||
background-color: #f5f5f5;
|
||||
color: #666666;
|
||||
}
|
||||
|
||||
.submit-button {
|
||||
background-color: #1890ff;
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.submit-button.loading {
|
||||
background-color: #8cc4ff;
|
||||
}
|
||||
|
||||
.loading-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.loading-icon {
|
||||
width: 36rpx;
|
||||
height: 36rpx;
|
||||
margin-right: 10rpx;
|
||||
border: 4rpx solid #ffffff;
|
||||
border-radius: 50%;
|
||||
border-top-color: transparent;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
213
miniprogram-example/pages/otp-list/index.js
Normal file
213
miniprogram-example/pages/otp-list/index.js
Normal file
|
@ -0,0 +1,213 @@
|
|||
// otp-list/index.js
|
||||
import { getOTPList, getOTPCode, deleteOTP } from '../../services/otp';
|
||||
import { checkLoginStatus } from '../../services/auth';
|
||||
|
||||
Page({
|
||||
data: {
|
||||
otpList: [],
|
||||
loading: true,
|
||||
refreshing: false
|
||||
},
|
||||
|
||||
onLoad() {
|
||||
this.checkLogin();
|
||||
},
|
||||
|
||||
onShow() {
|
||||
// 每次页面显示时刷新OTP列表
|
||||
if (!this.data.loading) {
|
||||
this.fetchOTPList();
|
||||
}
|
||||
},
|
||||
|
||||
// 下拉刷新
|
||||
onPullDownRefresh() {
|
||||
this.setData({ refreshing: true });
|
||||
this.fetchOTPList().finally(() => {
|
||||
wx.stopPullDownRefresh();
|
||||
this.setData({ refreshing: false });
|
||||
});
|
||||
},
|
||||
|
||||
// 检查登录状态
|
||||
checkLogin() {
|
||||
checkLoginStatus().then(isLoggedIn => {
|
||||
if (isLoggedIn) {
|
||||
this.fetchOTPList();
|
||||
} else {
|
||||
wx.redirectTo({
|
||||
url: '/pages/login/login'
|
||||
});
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
// 获取OTP列表
|
||||
fetchOTPList() {
|
||||
this.setData({ loading: true });
|
||||
return getOTPList()
|
||||
.then(res => {
|
||||
if (res.data && Array.isArray(res.data)) {
|
||||
this.setData({
|
||||
otpList: res.data,
|
||||
loading: false
|
||||
});
|
||||
|
||||
// 获取每个OTP的当前验证码
|
||||
this.refreshOTPCodes();
|
||||
}
|
||||
})
|
||||
.catch(err => {
|
||||
wx.showToast({
|
||||
title: '获取OTP列表失败',
|
||||
icon: 'none'
|
||||
});
|
||||
this.setData({ loading: false });
|
||||
});
|
||||
},
|
||||
|
||||
// 刷新所有OTP的验证码
|
||||
refreshOTPCodes() {
|
||||
const { otpList } = this.data;
|
||||
|
||||
// 为每个OTP获取当前验证码
|
||||
const promises = otpList.map(otp => {
|
||||
return getOTPCode(otp.id)
|
||||
.then(res => {
|
||||
if (res.data && res.data.code) {
|
||||
return {
|
||||
id: otp.id,
|
||||
code: res.data.code,
|
||||
expiresIn: res.data.expires_in || 30
|
||||
};
|
||||
}
|
||||
return null;
|
||||
})
|
||||
.catch(() => null);
|
||||
});
|
||||
|
||||
Promise.all(promises).then(results => {
|
||||
const updatedList = [...this.data.otpList];
|
||||
|
||||
results.forEach(result => {
|
||||
if (result) {
|
||||
const index = updatedList.findIndex(otp => otp.id === result.id);
|
||||
if (index !== -1) {
|
||||
updatedList[index] = {
|
||||
...updatedList[index],
|
||||
currentCode: result.code,
|
||||
expiresIn: result.expiresIn
|
||||
};
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
this.setData({ otpList: updatedList });
|
||||
|
||||
// 设置定时器,每秒更新倒计时
|
||||
this.startCountdown();
|
||||
});
|
||||
},
|
||||
|
||||
// 开始倒计时
|
||||
startCountdown() {
|
||||
// 清除之前的定时器
|
||||
if (this.countdownTimer) {
|
||||
clearInterval(this.countdownTimer);
|
||||
}
|
||||
|
||||
// 创建新的定时器,每秒更新一次
|
||||
this.countdownTimer = setInterval(() => {
|
||||
const { otpList } = this.data;
|
||||
let needRefresh = false;
|
||||
|
||||
const updatedList = otpList.map(otp => {
|
||||
if (!otp.countdown) {
|
||||
otp.countdown = otp.expiresIn || 30;
|
||||
}
|
||||
|
||||
otp.countdown -= 1;
|
||||
|
||||
// 如果倒计时结束,标记需要刷新
|
||||
if (otp.countdown <= 0) {
|
||||
needRefresh = true;
|
||||
}
|
||||
|
||||
return otp;
|
||||
});
|
||||
|
||||
this.setData({ otpList: updatedList });
|
||||
|
||||
// 如果有OTP需要刷新,重新获取验证码
|
||||
if (needRefresh) {
|
||||
this.refreshOTPCodes();
|
||||
}
|
||||
}, 1000);
|
||||
},
|
||||
|
||||
// 添加新的OTP
|
||||
handleAddOTP() {
|
||||
wx.navigateTo({
|
||||
url: '/pages/otp-add/index'
|
||||
});
|
||||
},
|
||||
|
||||
// 编辑OTP
|
||||
handleEditOTP(e) {
|
||||
const { id } = e.currentTarget.dataset;
|
||||
wx.navigateTo({
|
||||
url: `/pages/otp-edit/index?id=${id}`
|
||||
});
|
||||
},
|
||||
|
||||
// 删除OTP
|
||||
handleDeleteOTP(e) {
|
||||
const { id, name } = e.currentTarget.dataset;
|
||||
|
||||
wx.showModal({
|
||||
title: '确认删除',
|
||||
content: `确定要删除 ${name} 吗?`,
|
||||
confirmColor: '#ff4d4f',
|
||||
success: (res) => {
|
||||
if (res.confirm) {
|
||||
deleteOTP(id)
|
||||
.then(() => {
|
||||
wx.showToast({
|
||||
title: '删除成功',
|
||||
icon: 'success'
|
||||
});
|
||||
this.fetchOTPList();
|
||||
})
|
||||
.catch(err => {
|
||||
wx.showToast({
|
||||
title: '删除失败',
|
||||
icon: 'none'
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
// 复制验证码
|
||||
handleCopyCode(e) {
|
||||
const { code } = e.currentTarget.dataset;
|
||||
|
||||
wx.setClipboardData({
|
||||
data: code,
|
||||
success: () => {
|
||||
wx.showToast({
|
||||
title: '验证码已复制',
|
||||
icon: 'success'
|
||||
});
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
onUnload() {
|
||||
// 页面卸载时清除定时器
|
||||
if (this.countdownTimer) {
|
||||
clearInterval(this.countdownTimer);
|
||||
}
|
||||
}
|
||||
});
|
3
miniprogram-example/pages/otp-list/index.json
Normal file
3
miniprogram-example/pages/otp-list/index.json
Normal file
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"usingComponents": {}
|
||||
}
|
59
miniprogram-example/pages/otp-list/index.wxml
Normal file
59
miniprogram-example/pages/otp-list/index.wxml
Normal file
|
@ -0,0 +1,59 @@
|
|||
<!-- otp-list/index.wxml -->
|
||||
<view class="container">
|
||||
<view class="header">
|
||||
<text class="title">我的OTP列表</text>
|
||||
<view class="add-button" bindtap="handleAddOTP">
|
||||
<text class="add-icon">+</text>
|
||||
</view>
|
||||
</view>
|
||||
|
||||
<!-- 加载中 -->
|
||||
<view class="loading-container" wx:if="{{loading}}">
|
||||
<view class="loading-spinner"></view>
|
||||
<text class="loading-text">加载中...</text>
|
||||
</view>
|
||||
|
||||
<!-- OTP列表 -->
|
||||
<view class="otp-list" wx:else>
|
||||
<block wx:if="{{otpList.length > 0}}">
|
||||
<view class="otp-item" wx:for="{{otpList}}" wx:key="id">
|
||||
<view class="otp-info">
|
||||
<view class="otp-name-row">
|
||||
<text class="otp-name">{{item.name}}</text>
|
||||
<text class="otp-issuer">{{item.issuer}}</text>
|
||||
</view>
|
||||
|
||||
<view class="otp-code-row" bindtap="handleCopyCode" data-code="{{item.currentCode}}">
|
||||
<text class="otp-code">{{item.currentCode || '******'}}</text>
|
||||
<text class="copy-hint">点击复制</text>
|
||||
</view>
|
||||
|
||||
<view class="otp-countdown">
|
||||
<progress
|
||||
percent="{{(item.countdown / item.expiresIn) * 100}}"
|
||||
stroke-width="3"
|
||||
activeColor="{{item.countdown < 10 ? '#ff4d4f' : '#1890ff'}}"
|
||||
backgroundColor="#e9e9e9"
|
||||
/>
|
||||
<text class="countdown-text">{{item.countdown || 0}}s</text>
|
||||
</view>
|
||||
</view>
|
||||
|
||||
<view class="otp-actions">
|
||||
<view class="action-button edit" bindtap="handleEditOTP" data-id="{{item.id}}">
|
||||
<text class="action-icon">✎</text>
|
||||
</view>
|
||||
<view class="action-button delete" bindtap="handleDeleteOTP" data-id="{{item.id}}" data-name="{{item.name}}">
|
||||
<text class="action-icon">✕</text>
|
||||
</view>
|
||||
</view>
|
||||
</view>
|
||||
</block>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<view class="empty-state" wx:else>
|
||||
<image class="empty-image" src="/assets/images/empty.png" mode="aspectFit"></image>
|
||||
<text class="empty-text">暂无OTP,点击右上角添加</text>
|
||||
</view>
|
||||
</view>
|
||||
</view>
|
201
miniprogram-example/pages/otp-list/index.wxss
Normal file
201
miniprogram-example/pages/otp-list/index.wxss
Normal file
|
@ -0,0 +1,201 @@
|
|||
/* otp-list/index.wxss */
|
||||
.container {
|
||||
min-height: 100vh;
|
||||
background-color: #f8f8f8;
|
||||
padding: 0 0 40rpx 0;
|
||||
}
|
||||
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 40rpx 32rpx;
|
||||
background-color: #ffffff;
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 100;
|
||||
box-shadow: 0 2rpx 10rpx rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.title {
|
||||
font-size: 36rpx;
|
||||
font-weight: bold;
|
||||
color: #333333;
|
||||
}
|
||||
|
||||
.add-button {
|
||||
width: 64rpx;
|
||||
height: 64rpx;
|
||||
border-radius: 32rpx;
|
||||
background-color: #1890ff;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
box-shadow: 0 4rpx 12rpx rgba(24, 144, 255, 0.3);
|
||||
}
|
||||
|
||||
.add-icon {
|
||||
color: #ffffff;
|
||||
font-size: 40rpx;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
/* 加载状态 */
|
||||
.loading-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 120rpx 0;
|
||||
}
|
||||
|
||||
.loading-spinner {
|
||||
width: 64rpx;
|
||||
height: 64rpx;
|
||||
border: 6rpx solid #f3f3f3;
|
||||
border-top: 6rpx solid #1890ff;
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
.loading-text {
|
||||
margin-top: 20rpx;
|
||||
font-size: 28rpx;
|
||||
color: #999999;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
/* OTP列表 */
|
||||
.otp-list {
|
||||
padding: 20rpx 32rpx;
|
||||
}
|
||||
|
||||
.otp-item {
|
||||
background-color: #ffffff;
|
||||
border-radius: 16rpx;
|
||||
padding: 32rpx;
|
||||
margin-bottom: 20rpx;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
box-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.otp-info {
|
||||
flex: 1;
|
||||
margin-right: 20rpx;
|
||||
}
|
||||
|
||||
.otp-name-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 16rpx;
|
||||
}
|
||||
|
||||
.otp-name {
|
||||
font-size: 32rpx;
|
||||
font-weight: bold;
|
||||
color: #333333;
|
||||
margin-right: 16rpx;
|
||||
}
|
||||
|
||||
.otp-issuer {
|
||||
font-size: 24rpx;
|
||||
color: #666666;
|
||||
background-color: #f5f5f5;
|
||||
padding: 4rpx 12rpx;
|
||||
border-radius: 8rpx;
|
||||
}
|
||||
|
||||
.otp-code-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 20rpx;
|
||||
}
|
||||
|
||||
.otp-code {
|
||||
font-size: 44rpx;
|
||||
font-family: monospace;
|
||||
font-weight: bold;
|
||||
color: #1890ff;
|
||||
letter-spacing: 4rpx;
|
||||
margin-right: 16rpx;
|
||||
}
|
||||
|
||||
.copy-hint {
|
||||
font-size: 24rpx;
|
||||
color: #999999;
|
||||
}
|
||||
|
||||
.otp-countdown {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.countdown-text {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
top: -30rpx;
|
||||
font-size: 24rpx;
|
||||
color: #999999;
|
||||
}
|
||||
|
||||
/* OTP操作按钮 */
|
||||
.otp-actions {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.action-button {
|
||||
width: 56rpx;
|
||||
height: 56rpx;
|
||||
border-radius: 28rpx;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin-bottom: 16rpx;
|
||||
}
|
||||
|
||||
.action-button.edit {
|
||||
background-color: #f0f7ff;
|
||||
}
|
||||
|
||||
.action-button.delete {
|
||||
background-color: #fff1f0;
|
||||
}
|
||||
|
||||
.action-icon {
|
||||
font-size: 32rpx;
|
||||
}
|
||||
|
||||
.edit .action-icon {
|
||||
color: #1890ff;
|
||||
}
|
||||
|
||||
.delete .action-icon {
|
||||
color: #ff4d4f;
|
||||
}
|
||||
|
||||
/* 空状态 */
|
||||
.empty-state {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 120rpx 0;
|
||||
}
|
||||
|
||||
.empty-image {
|
||||
width: 240rpx;
|
||||
height: 240rpx;
|
||||
margin-bottom: 32rpx;
|
||||
}
|
||||
|
||||
.empty-text {
|
||||
font-size: 28rpx;
|
||||
color: #999999;
|
||||
}
|
47
miniprogram-example/project.config.json
Normal file
47
miniprogram-example/project.config.json
Normal file
|
@ -0,0 +1,47 @@
|
|||
{
|
||||
"description": "项目配置文件",
|
||||
"packOptions": {
|
||||
"ignore": [],
|
||||
"include": []
|
||||
},
|
||||
"miniprogramRoot": "",
|
||||
"compileType": "miniprogram",
|
||||
"projectname": "OTPM",
|
||||
"setting": {
|
||||
"useCompilerPlugins": [
|
||||
"sass"
|
||||
],
|
||||
"babelSetting": {
|
||||
"ignore": [],
|
||||
"disablePlugins": [],
|
||||
"outputPath": ""
|
||||
},
|
||||
"es6": true,
|
||||
"enhance": true,
|
||||
"minified": true,
|
||||
"postcss": true,
|
||||
"minifyWXSS": true,
|
||||
"minifyWXML": true,
|
||||
"uglifyFileName": true,
|
||||
"packNpmManually": false,
|
||||
"packNpmRelationList": [],
|
||||
"ignoreUploadUnusedFiles": true,
|
||||
"compileWorklet": false,
|
||||
"uploadWithSourceMap": true,
|
||||
"localPlugins": false,
|
||||
"disableUseStrict": false,
|
||||
"condition": false,
|
||||
"swc": false,
|
||||
"disableSWC": true
|
||||
},
|
||||
"simulatorType": "wechat",
|
||||
"simulatorPluginLibVersion": {},
|
||||
"condition": {},
|
||||
"srcMiniprogramRoot": "",
|
||||
"appid": "wxb6599459668b6b55",
|
||||
"libVersion": "2.30.2",
|
||||
"editorSetting": {
|
||||
"tabIndent": "insertSpaces",
|
||||
"tabSize": 2
|
||||
}
|
||||
}
|
23
miniprogram-example/project.private.config.json
Normal file
23
miniprogram-example/project.private.config.json
Normal file
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"libVersion": "3.8.5",
|
||||
"projectname": "OTPM",
|
||||
"condition": {},
|
||||
"setting": {
|
||||
"urlCheck": true,
|
||||
"coverView": false,
|
||||
"lazyloadPlaceholderEnable": false,
|
||||
"skylineRenderEnable": false,
|
||||
"preloadBackgroundData": false,
|
||||
"autoAudits": false,
|
||||
"useApiHook": true,
|
||||
"useApiHostProcess": true,
|
||||
"showShadowRootInWxmlPanel": false,
|
||||
"useStaticServer": false,
|
||||
"useLanDebug": false,
|
||||
"showES6CompileOption": false,
|
||||
"compileHotReLoad": true,
|
||||
"checkInvalidKey": true,
|
||||
"ignoreDevUnusedFiles": true,
|
||||
"bigPackageSizeSupport": false
|
||||
}
|
||||
}
|
84
miniprogram-example/services/auth.js
Normal file
84
miniprogram-example/services/auth.js
Normal file
|
@ -0,0 +1,84 @@
|
|||
// auth.js - 认证相关服务
|
||||
|
||||
import request from '../utils/request';
|
||||
|
||||
/**
|
||||
* 微信登录
|
||||
* 1. 调用wx.login获取code
|
||||
* 2. 发送code到服务端换取token和openid
|
||||
* 3. 保存token和openid到本地存储
|
||||
*/
|
||||
export const wxLogin = () => {
|
||||
return new Promise((resolve, reject) => {
|
||||
wx.login({
|
||||
success: (res) => {
|
||||
if (res.code) {
|
||||
// 发送code到服务端
|
||||
request({
|
||||
url: '/login',
|
||||
method: 'POST',
|
||||
data: {
|
||||
code: res.code
|
||||
}
|
||||
}).then(response => {
|
||||
// 保存token和openid
|
||||
if (response.data && response.data.token && response.data.openid) {
|
||||
wx.setStorageSync('token', response.data.token);
|
||||
wx.setStorageSync('openid', response.data.openid);
|
||||
resolve(response.data);
|
||||
} else {
|
||||
reject(new Error('登录失败,服务器返回数据格式错误'));
|
||||
}
|
||||
}).catch(err => {
|
||||
reject(err);
|
||||
});
|
||||
} else {
|
||||
reject(new Error('登录失败,获取code失败: ' + res.errMsg));
|
||||
}
|
||||
},
|
||||
fail: (err) => {
|
||||
reject(new Error('微信登录失败: ' + err.errMsg));
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 检查登录状态
|
||||
* 1. 检查本地是否有token和openid
|
||||
* 2. 如果有,验证token是否有效
|
||||
* 3. 如果无效,清除本地存储并返回false
|
||||
*/
|
||||
export const checkLoginStatus = () => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const token = wx.getStorageSync('token');
|
||||
const openid = wx.getStorageSync('openid');
|
||||
|
||||
if (!token || !openid) {
|
||||
resolve(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// 验证token有效性
|
||||
request({
|
||||
url: '/verify-token',
|
||||
method: 'POST'
|
||||
}).then(() => {
|
||||
resolve(true);
|
||||
}).catch(() => {
|
||||
// token无效,清除本地存储
|
||||
wx.removeStorageSync('token');
|
||||
wx.removeStorageSync('openid');
|
||||
resolve(false);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 退出登录
|
||||
*/
|
||||
export const logout = () => {
|
||||
wx.removeStorageSync('token');
|
||||
wx.removeStorageSync('openid');
|
||||
return Promise.resolve();
|
||||
};
|
119
miniprogram-example/services/otp.js
Normal file
119
miniprogram-example/services/otp.js
Normal file
|
@ -0,0 +1,119 @@
|
|||
// otp.js - OTP相关服务
|
||||
|
||||
import request from '../utils/request';
|
||||
|
||||
/**
|
||||
* 创建新的OTP
|
||||
* @param {Object} params - 创建OTP的参数
|
||||
* @param {string} params.name - OTP名称
|
||||
* @param {string} params.issuer - 发行方
|
||||
* @param {string} params.secret - 密钥
|
||||
* @param {string} params.algorithm - 算法,默认为SHA1
|
||||
* @param {number} params.digits - 位数,默认为6
|
||||
* @param {number} params.period - 周期,默认为30秒
|
||||
* @returns {Promise} - 返回创建结果
|
||||
*/
|
||||
export const createOTP = (params) => {
|
||||
if (!params || !params.secret) {
|
||||
return Promise.reject(new Error('缺少必要的参数: secret'));
|
||||
}
|
||||
return request({
|
||||
url: '/otp',
|
||||
method: 'POST',
|
||||
data: {
|
||||
name: params.name || '',
|
||||
issuer: params.issuer || '',
|
||||
secret: params.secret,
|
||||
algorithm: params.algorithm || 'SHA1',
|
||||
digits: params.digits || 6,
|
||||
period: params.period || 30
|
||||
}
|
||||
}).catch(err => {
|
||||
console.error('创建OTP失败:', err);
|
||||
throw new Error('创建OTP失败: ' + (err.message || '未知错误'));
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 获取用户所有OTP列表
|
||||
* @returns {Promise} - 返回OTP列表
|
||||
*/
|
||||
export const getOTPList = () => {
|
||||
return request({
|
||||
url: '/otp',
|
||||
method: 'GET'
|
||||
}).catch(err => {
|
||||
console.error('获取OTP列表失败:', err);
|
||||
throw new Error('获取OTP列表失败: ' + (err.message || '未知错误'));
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 获取指定OTP的当前验证码
|
||||
* @param {string} id - OTP的ID
|
||||
* @returns {Promise} - 返回当前验证码
|
||||
*/
|
||||
export const getOTPCode = (id) => {
|
||||
if (!id) {
|
||||
return Promise.reject(new Error('缺少必要的参数: id'));
|
||||
}
|
||||
return request({
|
||||
url: `/otp/${id}/code`,
|
||||
method: 'GET'
|
||||
}).catch(err => {
|
||||
console.error('获取OTP代码失败:', err);
|
||||
throw new Error('获取OTP代码失败: ' + (err.message || '未知错误'));
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 验证OTP
|
||||
* @param {string} id - OTP的ID
|
||||
* @param {string} code - 用户输入的验证码
|
||||
* @returns {Promise} - 返回验证结果
|
||||
*/
|
||||
export const verifyOTP = (id, code) => {
|
||||
if (!id || !code) {
|
||||
return Promise.reject(new Error('缺少必要的参数: id或code'));
|
||||
}
|
||||
return request({
|
||||
url: `/otp/${id}/verify`,
|
||||
method: 'POST',
|
||||
data: { code }
|
||||
}).catch(err => {
|
||||
console.error('验证OTP失败:', err);
|
||||
throw new Error('验证OTP失败: ' + (err.message || '未知错误'));
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 更新OTP信息
|
||||
* @param {string} id - OTP的ID
|
||||
* @param {Object} params - 更新的参数
|
||||
* @returns {Promise} - 返回更新结果
|
||||
*/
|
||||
export const updateOTP = (id, params) => {
|
||||
if (!id || !params) {
|
||||
return Promise.reject(new Error('缺少必要的参数: id或params'));
|
||||
}
|
||||
return request({
|
||||
url: `/otp/${id}`,
|
||||
method: 'PUT',
|
||||
data: params
|
||||
}).catch(err => {
|
||||
console.error('更新OTP失败:', err);
|
||||
throw new Error('更新OTP失败: ' + (err.message || '未知错误'));
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 删除OTP
|
||||
* @param {string} id - OTP的ID
|
||||
* @returns {Promise} - 返回删除结果
|
||||
*/
|
||||
export const deleteOTP = (id) => {
|
||||
return request({
|
||||
url: `/otp/${id}`,
|
||||
method: 'DELETE'
|
||||
});
|
||||
};
|
58
miniprogram-example/utils/request.js
Normal file
58
miniprogram-example/utils/request.js
Normal file
|
@ -0,0 +1,58 @@
|
|||
// request.js - 网络请求工具类
|
||||
|
||||
const BASE_URL = 'https://otpm.zeroc.net'; // 替换为实际的API域名
|
||||
|
||||
// 请求拦截器
|
||||
const request = (options) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const token = wx.getStorageSync('token');
|
||||
const header = {
|
||||
'Content-Type': 'application/json',
|
||||
...options.header
|
||||
};
|
||||
|
||||
// 如果有token,添加到请求头
|
||||
if (token) {
|
||||
header['Authorization'] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
wx.request({
|
||||
url: `${BASE_URL}${options.url}`,
|
||||
method: options.method || 'GET',
|
||||
data: options.data,
|
||||
header: header,
|
||||
success: (res) => {
|
||||
// 处理业务错误
|
||||
if (res.data.code !== 0) {
|
||||
// token过期,直接清除并跳转登录
|
||||
if (res.statusCode === 401) {
|
||||
wx.removeStorageSync('token');
|
||||
wx.removeStorageSync('openid');
|
||||
reject(new Error('登录已过期,请重新登录'));
|
||||
return;
|
||||
}
|
||||
reject(new Error(res.data.message || '请求失败'));
|
||||
return;
|
||||
}
|
||||
resolve(res.data);
|
||||
},
|
||||
fail: reject
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// 刷新token
|
||||
const refreshToken = () => {
|
||||
return request({
|
||||
url: '/refresh-token',
|
||||
method: 'POST'
|
||||
}).then(res => {
|
||||
if (res.data && res.data.token) {
|
||||
wx.setStorageSync('token', res.data.token);
|
||||
return res.data.token;
|
||||
}
|
||||
throw new Error('Failed to refresh token');
|
||||
});
|
||||
};
|
||||
|
||||
export default request;
|
195
models/otp.go
Normal file
195
models/otp.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// OTP represents a TOTP configuration
|
||||
type OTP struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
UserID string `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Issuer string `db:"issuer" json:"issuer"`
|
||||
Secret string `db:"secret" json:"-"` // Never expose secret in JSON
|
||||
Algorithm string `db:"algorithm" json:"algorithm"`
|
||||
Digits int `db:"digits" json:"digits"`
|
||||
Period int `db:"period" json:"period"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// OTPParams represents common OTP parameters used in creation and update
|
||||
type OTPParams struct {
|
||||
Name string
|
||||
Issuer string
|
||||
Secret string
|
||||
Algorithm string
|
||||
Digits int
|
||||
Period int
|
||||
}
|
||||
|
||||
// OTPRepository handles OTP data operations
|
||||
type OTPRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewOTPRepository creates a new OTPRepository
|
||||
func NewOTPRepository(db *sqlx.DB) *OTPRepository {
|
||||
return &OTPRepository{db: db}
|
||||
}
|
||||
|
||||
// FindByID finds an OTP by ID and user ID
|
||||
func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
|
||||
var otp OTP
|
||||
query := `SELECT * FROM otps WHERE id = ? AND user_id = ?`
|
||||
err := r.db.GetContext(ctx, &otp, query, id, userID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("otp not found: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find otp: %w", err)
|
||||
}
|
||||
return &otp, nil
|
||||
}
|
||||
|
||||
// FindAllByUserID finds all OTPs for a user
|
||||
func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
|
||||
var otps []*OTP
|
||||
query := `SELECT * FROM otps WHERE user_id = ? ORDER BY created_at DESC`
|
||||
err := r.db.SelectContext(ctx, &otps, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find otps: %w", err)
|
||||
}
|
||||
return otps, nil
|
||||
}
|
||||
|
||||
// Create creates a new OTP
|
||||
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
|
||||
query := `
|
||||
INSERT INTO otps (id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
now := time.Now()
|
||||
otp.CreatedAt = now
|
||||
otp.UpdatedAt = now
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
otp.ID,
|
||||
otp.UserID,
|
||||
otp.Name,
|
||||
otp.Issuer,
|
||||
otp.Secret,
|
||||
otp.Algorithm,
|
||||
otp.Digits,
|
||||
otp.Period,
|
||||
otp.CreatedAt,
|
||||
otp.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create otp: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update updates an existing OTP
|
||||
func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
|
||||
query := `
|
||||
UPDATE otps
|
||||
SET name = ?, issuer = ?, algorithm = ?, digits = ?, period = ?, updated_at = ?
|
||||
WHERE id = ? AND user_id = ?
|
||||
`
|
||||
otp.UpdatedAt = time.Now()
|
||||
|
||||
result, err := r.db.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
otp.Name,
|
||||
otp.Issuer,
|
||||
otp.Algorithm,
|
||||
otp.Digits,
|
||||
otp.Period,
|
||||
otp.UpdatedAt,
|
||||
otp.ID,
|
||||
otp.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update otp: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("otp not found or not owned by user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an OTP
|
||||
func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
query := `DELETE FROM otps WHERE id = ? AND user_id = ?`
|
||||
result, err := r.db.ExecContext(ctx, query, id, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete otp: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("otp not found or not owned by user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountByUserID counts the number of OTPs for a user
|
||||
func (r *OTPRepository) CountByUserID(ctx context.Context, userID string) (int, error) {
|
||||
var count int
|
||||
query := `SELECT COUNT(*) FROM otps WHERE user_id = ?`
|
||||
err := r.db.GetContext(ctx, &count, query, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count otps: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Transaction executes a function within a transaction
|
||||
func (r *OTPRepository) Transaction(ctx context.Context, fn func(*sqlx.Tx) error) error {
|
||||
tx, err := r.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return fmt.Errorf("tx failed: %v, rollback failed: %v", err, rbErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
114
models/user.go
Normal file
114
models/user.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// User represents a user in the system
|
||||
type User struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
OpenID string `db:"openid" json:"openid"`
|
||||
SessionKey string `db:"session_key" json:"-"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// UserRepository handles user data operations
|
||||
type UserRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewUserRepository creates a new UserRepository
|
||||
func NewUserRepository(db *sqlx.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
// FindByID finds a user by ID
|
||||
func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) {
|
||||
var user User
|
||||
query := `SELECT * FROM users WHERE id = ?`
|
||||
err := r.db.GetContext(ctx, &user, query, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindByOpenID finds a user by OpenID
|
||||
func (r *UserRepository) FindByOpenID(ctx context.Context, openID string) (*User, error) {
|
||||
var user User
|
||||
query := `SELECT * FROM users WHERE openid = ?`
|
||||
err := r.db.GetContext(ctx, &user, query, openID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil // User not found, but not an error
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// Create creates a new user
|
||||
func (r *UserRepository) Create(ctx context.Context, user *User) error {
|
||||
query := `
|
||||
INSERT INTO users (id, openid, session_key, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`
|
||||
now := time.Now()
|
||||
user.CreatedAt = now
|
||||
user.UpdatedAt = now
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
user.ID,
|
||||
user.OpenID,
|
||||
user.SessionKey,
|
||||
user.CreatedAt,
|
||||
user.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update updates an existing user
|
||||
func (r *UserRepository) Update(ctx context.Context, user *User) error {
|
||||
query := `
|
||||
UPDATE users
|
||||
SET session_key = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
user.SessionKey,
|
||||
user.UpdatedAt,
|
||||
user.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes a user
|
||||
func (r *UserRepository) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM users WHERE id = ?`
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
332
security/security.go
Normal file
332
security/security.go
Normal file
|
@ -0,0 +1,332 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// SecurityService provides security functionality
|
||||
type SecurityService struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
// Config represents security configuration
|
||||
type Config struct {
|
||||
// CSRF protection
|
||||
CSRFTokenLength int
|
||||
CSRFTokenExpiry time.Duration
|
||||
CSRFCookieName string
|
||||
CSRFHeaderName string
|
||||
CSRFCookieSecure bool
|
||||
CSRFCookieHTTPOnly bool
|
||||
CSRFCookieSameSite http.SameSite
|
||||
|
||||
// Rate limiting
|
||||
RateLimitRequests int
|
||||
RateLimitWindow time.Duration
|
||||
|
||||
// Password hashing
|
||||
Argon2Time uint32
|
||||
Argon2Memory uint32
|
||||
Argon2Threads uint8
|
||||
Argon2KeyLen uint32
|
||||
Argon2SaltLen uint32
|
||||
}
|
||||
|
||||
// DefaultConfig returns the default security configuration
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
// CSRF protection
|
||||
CSRFTokenLength: 32,
|
||||
CSRFTokenExpiry: 24 * time.Hour,
|
||||
CSRFCookieName: "csrf_token",
|
||||
CSRFHeaderName: "X-CSRF-Token",
|
||||
CSRFCookieSecure: true,
|
||||
CSRFCookieHTTPOnly: true,
|
||||
CSRFCookieSameSite: http.SameSiteStrictMode,
|
||||
|
||||
// Rate limiting
|
||||
RateLimitRequests: 100,
|
||||
RateLimitWindow: time.Minute,
|
||||
|
||||
// Password hashing
|
||||
Argon2Time: 1,
|
||||
Argon2Memory: 64 * 1024,
|
||||
Argon2Threads: 4,
|
||||
Argon2KeyLen: 32,
|
||||
Argon2SaltLen: 16,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSecurityService creates a new SecurityService
|
||||
func NewSecurityService(config *Config) *SecurityService {
|
||||
if config == nil {
|
||||
config = DefaultConfig()
|
||||
}
|
||||
return &SecurityService{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateCSRFToken generates a CSRF token
|
||||
func (s *SecurityService) GenerateCSRFToken() (string, error) {
|
||||
// Generate random bytes
|
||||
bytes := make([]byte, s.config.CSRFTokenLength)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Encode as base64
|
||||
token := base64.StdEncoding.EncodeToString(bytes)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SetCSRFCookie sets a CSRF cookie
|
||||
func (s *SecurityService) SetCSRFCookie(w http.ResponseWriter, token string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: s.config.CSRFCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
Expires: time.Now().Add(s.config.CSRFTokenExpiry),
|
||||
Secure: s.config.CSRFCookieSecure,
|
||||
HttpOnly: s.config.CSRFCookieHTTPOnly,
|
||||
SameSite: s.config.CSRFCookieSameSite,
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateCSRFToken validates a CSRF token
|
||||
func (s *SecurityService) ValidateCSRFToken(r *http.Request) bool {
|
||||
// Get token from cookie
|
||||
cookie, err := r.Cookie(s.config.CSRFCookieName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
cookieToken := cookie.Value
|
||||
|
||||
// Get token from header
|
||||
headerToken := r.Header.Get(s.config.CSRFHeaderName)
|
||||
|
||||
// Compare tokens
|
||||
return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(headerToken)) == 1
|
||||
}
|
||||
|
||||
// CSRFMiddleware creates a middleware that validates CSRF tokens
|
||||
func (s *SecurityService) CSRFMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip for GET, HEAD, OPTIONS, TRACE
|
||||
if r.Method == http.MethodGet ||
|
||||
r.Method == http.MethodHead ||
|
||||
r.Method == http.MethodOptions ||
|
||||
r.Method == http.MethodTrace {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF token
|
||||
if !s.ValidateCSRFToken(r) {
|
||||
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// RateLimiter represents a rate limiter
|
||||
type RateLimiter struct {
|
||||
requests map[string][]time.Time
|
||||
config *Config
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new RateLimiter
|
||||
func NewRateLimiter(config *Config) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
requests: make(map[string][]time.Time),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request is allowed
|
||||
func (r *RateLimiter) Allow(key string) bool {
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-r.config.RateLimitWindow)
|
||||
|
||||
// Get requests for key
|
||||
requests := r.requests[key]
|
||||
|
||||
// Filter out old requests
|
||||
var newRequests []time.Time
|
||||
for _, t := range requests {
|
||||
if t.After(windowStart) {
|
||||
newRequests = append(newRequests, t)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if rate limit is exceeded
|
||||
if len(newRequests) >= r.config.RateLimitRequests {
|
||||
return false
|
||||
}
|
||||
|
||||
// Add current request
|
||||
newRequests = append(newRequests, now)
|
||||
r.requests[key] = newRequests
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// RateLimitMiddleware creates a middleware that limits request rate
|
||||
func (s *SecurityService) RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get client IP
|
||||
ip := getClientIP(r)
|
||||
|
||||
// Check if request is allowed
|
||||
if !limiter.Allow(ip) {
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// getClientIP gets the client IP address
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header
|
||||
xForwardedFor := r.Header.Get("X-Forwarded-For")
|
||||
if xForwardedFor != "" {
|
||||
// X-Forwarded-For can contain multiple IPs, use the first one
|
||||
ips := strings.Split(xForwardedFor, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
xRealIP := r.Header.Get("X-Real-IP")
|
||||
if xRealIP != "" {
|
||||
return xRealIP
|
||||
}
|
||||
|
||||
// Use RemoteAddr
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// HashPassword hashes a password using Argon2
|
||||
func (s *SecurityService) HashPassword(password string) (string, error) {
|
||||
// Generate salt
|
||||
salt := make([]byte, s.config.Argon2SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hash := argon2.IDKey(
|
||||
[]byte(password),
|
||||
salt,
|
||||
s.config.Argon2Time,
|
||||
s.config.Argon2Memory,
|
||||
s.config.Argon2Threads,
|
||||
s.config.Argon2KeyLen,
|
||||
)
|
||||
|
||||
// Encode as base64
|
||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
||||
hashBase64 := base64.StdEncoding.EncodeToString(hash)
|
||||
|
||||
// Format as $argon2id$v=19$m=65536,t=1,p=4$<salt>$<hash>
|
||||
return fmt.Sprintf(
|
||||
"$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
|
||||
s.config.Argon2Memory,
|
||||
s.config.Argon2Time,
|
||||
s.config.Argon2Threads,
|
||||
saltBase64,
|
||||
hashBase64,
|
||||
), nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a password against a hash
|
||||
func (s *SecurityService) VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
// Parse encoded hash
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != 6 {
|
||||
return false, fmt.Errorf("invalid hash format")
|
||||
}
|
||||
|
||||
// Extract parameters
|
||||
if parts[1] != "argon2id" {
|
||||
return false, fmt.Errorf("unsupported hash algorithm")
|
||||
}
|
||||
|
||||
var memory uint32
|
||||
var time uint32
|
||||
var threads uint8
|
||||
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to parse hash parameters: %w", err)
|
||||
}
|
||||
|
||||
// Decode salt and hash
|
||||
salt, err := base64.StdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decode salt: %w", err)
|
||||
}
|
||||
|
||||
hash, err := base64.StdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decode hash: %w", err)
|
||||
}
|
||||
|
||||
// Hash password with same parameters
|
||||
newHash := argon2.IDKey(
|
||||
[]byte(password),
|
||||
salt,
|
||||
time,
|
||||
memory,
|
||||
threads,
|
||||
uint32(len(hash)),
|
||||
)
|
||||
|
||||
// Compare hashes
|
||||
return subtle.ConstantTimeCompare(hash, newHash) == 1, nil
|
||||
}
|
||||
|
||||
// SecureHeadersMiddleware adds security headers to responses
|
||||
func (s *SecurityService) SecureHeadersMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add security headers
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'")
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// contextKey is a type for context keys
|
||||
type contextKey string
|
||||
|
||||
// userIDKey is the key for user ID in context
|
||||
const userIDKey = contextKey("user_id")
|
||||
|
||||
// WithUserID adds a user ID to the context
|
||||
func WithUserID(ctx context.Context, userID string) context.Context {
|
||||
return context.WithValue(ctx, userIDKey, userID)
|
||||
}
|
||||
|
||||
// GetUserID gets the user ID from the context
|
||||
func GetUserID(ctx context.Context) (string, bool) {
|
||||
userID, ok := ctx.Value(userIDKey).(string)
|
||||
return userID, ok
|
||||
}
|
172
server/server.go
Normal file
172
server/server.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"otpm/config"
|
||||
"otpm/middleware"
|
||||
)
|
||||
|
||||
// Server represents the HTTP server
|
||||
type Server struct {
|
||||
server *http.Server
|
||||
router *http.ServeMux
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// New creates a new server
|
||||
func New(cfg *config.Config) *Server {
|
||||
router := http.NewServeMux()
|
||||
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||
Handler: router,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
return &Server{
|
||||
server: server,
|
||||
router: router,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the server
|
||||
func (s *Server) Start() error {
|
||||
// Apply global middleware in correct order with enhanced error handling
|
||||
var handler http.Handler = s.router
|
||||
|
||||
// Logger should be first to capture all request details
|
||||
handler = middleware.Logger(handler)
|
||||
|
||||
// CORS next to handle pre-flight requests
|
||||
handler = middleware.CORS(handler)
|
||||
|
||||
// Then Timeout to enforce request deadlines
|
||||
handler = middleware.Timeout(s.config.Server.Timeout)(handler)
|
||||
|
||||
// Recover should be outermost to catch any panics
|
||||
handler = middleware.Recover(handler)
|
||||
|
||||
s.server.Handler = handler
|
||||
|
||||
// Log server configuration at startup
|
||||
log.Printf("Server configuration:\n"+
|
||||
"Address: %s\n"+
|
||||
"Read Timeout: %v\n"+
|
||||
"Write Timeout: %v\n"+
|
||||
"Idle Timeout: %v\n"+
|
||||
"Request Timeout: %v",
|
||||
s.server.Addr,
|
||||
s.server.ReadTimeout,
|
||||
s.server.WriteTimeout,
|
||||
s.server.IdleTimeout,
|
||||
s.config.Server.Timeout,
|
||||
)
|
||||
|
||||
// Start server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
log.Printf("Server starting on %s", s.server.Addr)
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
serverErr <- fmt.Errorf("server error: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal or server error
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case err := <-serverErr:
|
||||
return err
|
||||
case <-quit:
|
||||
return s.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the server
|
||||
func (s *Server) Shutdown() error {
|
||||
log.Println("Shutting down server...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.config.Server.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("graceful shutdown failed: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Server stopped gracefully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Router returns the router
|
||||
func (s *Server) Router() *http.ServeMux {
|
||||
return s.router
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all routes
|
||||
func (s *Server) RegisterRoutes(routes map[string]http.Handler) {
|
||||
for pattern, handler := range routes {
|
||||
s.router.Handle(pattern, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterAuthRoutes registers routes that require authentication
|
||||
func (s *Server) RegisterAuthRoutes(routes map[string]http.Handler) {
|
||||
for pattern, handler := range routes {
|
||||
// Apply authentication middleware
|
||||
authHandler := middleware.Auth(s.config.JWT.Secret)(handler)
|
||||
s.router.Handle(pattern, authHandler)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHealthCheck registers an enhanced health check endpoint
|
||||
func (s *Server) RegisterHealthCheck() {
|
||||
s.router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]interface{}{
|
||||
"status": "ok",
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
"version": "1.0.0", // Hardcoded version instead of from config
|
||||
"system": map[string]interface{}{
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"memory": getMemoryUsage(),
|
||||
},
|
||||
}
|
||||
|
||||
// Add database status if configured
|
||||
if s.config.Database.DSN != "" { // Changed from URL to DSN to match config
|
||||
dbStatus := "ok"
|
||||
// Removed DB ping check since we don't have DB instance in config
|
||||
response["database"] = dbStatus
|
||||
}
|
||||
|
||||
middleware.SuccessResponse(w, response)
|
||||
})
|
||||
}
|
||||
|
||||
// getMemoryUsage returns current memory usage in MB
|
||||
func getMemoryUsage() map[string]interface{} {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
return map[string]interface{}{
|
||||
"alloc_mb": bToMb(m.Alloc),
|
||||
"total_alloc_mb": bToMb(m.TotalAlloc),
|
||||
"sys_mb": bToMb(m.Sys),
|
||||
"num_gc": m.NumGC,
|
||||
}
|
||||
}
|
||||
|
||||
func bToMb(b uint64) float64 {
|
||||
return float64(b) / 1024 / 1024
|
||||
}
|
230
services/auth.go
Normal file
230
services/auth.go
Normal file
|
@ -0,0 +1,230 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"otpm/config"
|
||||
"otpm/models"
|
||||
)
|
||||
|
||||
// WeChatCode2SessionResponse represents the response from WeChat code2session API
|
||||
type WeChatCode2SessionResponse struct {
|
||||
OpenID string `json:"openid"`
|
||||
SessionKey string `json:"session_key"`
|
||||
UnionID string `json:"unionid"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
// AuthService handles authentication related operations
|
||||
type AuthService struct {
|
||||
config *config.Config
|
||||
userRepo *models.UserRepository
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAuthService creates a new AuthService
|
||||
func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
|
||||
return &AuthService{
|
||||
config: cfg,
|
||||
userRepo: userRepo,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LoginWithWeChatCode handles WeChat login
|
||||
func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
|
||||
start := time.Now()
|
||||
|
||||
// Get OpenID and SessionKey from WeChat
|
||||
sessionInfo, err := s.getWeChatSession(code)
|
||||
if err != nil {
|
||||
log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
|
||||
return "", fmt.Errorf("failed to get WeChat session: %w", err)
|
||||
}
|
||||
log.Printf("WeChat session obtained for code %s (took %v)",
|
||||
maskCode(code), time.Since(start))
|
||||
|
||||
// Find or create user
|
||||
user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
|
||||
if err != nil {
|
||||
log.Printf("User lookup failed for OpenID %s: %v",
|
||||
maskOpenID(sessionInfo.OpenID), err)
|
||||
return "", fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
// Create new user
|
||||
user = &models.User{
|
||||
ID: uuid.New().String(),
|
||||
OpenID: sessionInfo.OpenID,
|
||||
SessionKey: sessionInfo.SessionKey,
|
||||
}
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
log.Printf("User creation failed for OpenID %s: %v",
|
||||
maskOpenID(sessionInfo.OpenID), err)
|
||||
return "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
log.Printf("New user created with ID %s for OpenID %s",
|
||||
user.ID, maskOpenID(sessionInfo.OpenID))
|
||||
} else {
|
||||
// Update session key
|
||||
user.SessionKey = sessionInfo.SessionKey
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("User update failed for ID %s: %v", user.ID, err)
|
||||
return "", fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
log.Printf("User %s session key updated", user.ID)
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := s.generateToken(user)
|
||||
if err != nil {
|
||||
log.Printf("Token generation failed for user %s: %v", user.ID, err)
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("WeChat login completed for user %s (total time %v)",
|
||||
user.ID, time.Since(start))
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// maskCode masks sensitive parts of WeChat code for logging
|
||||
func maskCode(code string) string {
|
||||
if len(code) < 8 {
|
||||
return "****"
|
||||
}
|
||||
return code[:2] + "****" + code[len(code)-2:]
|
||||
}
|
||||
|
||||
// maskOpenID masks sensitive parts of OpenID for logging
|
||||
func maskOpenID(openID string) string {
|
||||
if len(openID) < 8 {
|
||||
return "****"
|
||||
}
|
||||
return openID[:2] + "****" + openID[len(openID)-2:]
|
||||
}
|
||||
|
||||
// getWeChatSession calls WeChat's code2session API
|
||||
func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
|
||||
url := fmt.Sprintf(
|
||||
"https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
|
||||
s.config.WeChat.AppID,
|
||||
s.config.WeChat.AppSecret,
|
||||
code,
|
||||
)
|
||||
|
||||
resp, err := s.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call WeChat API: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result WeChatCode2SessionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// generateToken generates a JWT token for a user
|
||||
func (s *AuthService) generateToken(user *models.User) (string, error) {
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"exp": now.Add(s.config.JWT.ExpireDelta).Unix(),
|
||||
"iat": now.Unix(),
|
||||
"iss": s.config.JWT.Issuer,
|
||||
"aud": s.config.JWT.Audience,
|
||||
"token_id": uuid.New().String(), // Unique token ID for tracking
|
||||
}
|
||||
|
||||
// Use stronger signing method
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
|
||||
signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Token generated for user %s (expires at %v)",
|
||||
user.ID, now.Add(s.config.JWT.ExpireDelta))
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT token with additional checks
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(s.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
switch {
|
||||
case ve.Errors&jwt.ValidationErrorMalformed != 0:
|
||||
return nil, fmt.Errorf("malformed token")
|
||||
case ve.Errors&jwt.ValidationErrorExpired != 0:
|
||||
return nil, fmt.Errorf("token expired")
|
||||
case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
|
||||
return nil, fmt.Errorf("token not active yet")
|
||||
default:
|
||||
return nil, fmt.Errorf("token validation error: %w", err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
// Additional claims validation
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
// Check issuer
|
||||
if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
|
||||
return nil, fmt.Errorf("invalid token issuer")
|
||||
}
|
||||
// Check audience
|
||||
if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
|
||||
return nil, fmt.Errorf("invalid token audience")
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetUserFromToken gets user information from a JWT token
|
||||
func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("user_id not found in token")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.FindByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
358
services/otp.go
Normal file
358
services/otp.go
Normal file
|
@ -0,0 +1,358 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"otpm/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// OTPService handles OTP related operations
|
||||
type OTPService struct {
|
||||
otpRepo *models.OTPRepository
|
||||
}
|
||||
|
||||
// NewOTPService creates a new OTPService
|
||||
func NewOTPService(otpRepo *models.OTPRepository) *OTPService {
|
||||
return &OTPService{
|
||||
otpRepo: otpRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateOTP creates a new OTP with performance monitoring and logging
|
||||
func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) {
|
||||
start := time.Now()
|
||||
|
||||
// Validate input
|
||||
if err := s.validateOTPInput(input); err != nil {
|
||||
log.Printf("OTP validation failed for user %s: %v", userID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Clean and standardize secret
|
||||
secret := cleanSecret(input.Secret)
|
||||
|
||||
// Set defaults for optional fields
|
||||
algorithm := strings.ToUpper(input.Algorithm)
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA1"
|
||||
}
|
||||
|
||||
digits := input.Digits
|
||||
if digits == 0 {
|
||||
digits = 6
|
||||
}
|
||||
|
||||
period := input.Period
|
||||
if period == 0 {
|
||||
period = 30
|
||||
}
|
||||
|
||||
// Create OTP
|
||||
otp := &models.OTP{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
Name: input.Name,
|
||||
Issuer: input.Issuer,
|
||||
Secret: secret,
|
||||
Algorithm: algorithm,
|
||||
Digits: digits,
|
||||
Period: period,
|
||||
}
|
||||
|
||||
if err := s.otpRepo.Create(ctx, otp); err != nil {
|
||||
log.Printf("Failed to create OTP for user %s: %v", userID, err)
|
||||
return nil, fmt.Errorf("failed to create OTP: %w", err)
|
||||
}
|
||||
|
||||
// Log successful creation (without exposing secret)
|
||||
log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)",
|
||||
otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period)
|
||||
|
||||
return otp, nil
|
||||
}
|
||||
|
||||
// GetOTP gets an OTP by ID
|
||||
func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) {
|
||||
otp, err := s.otpRepo.FindByID(ctx, id, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OTP: %w", err)
|
||||
}
|
||||
return otp, nil
|
||||
}
|
||||
|
||||
// ListOTPs lists all OTPs for a user
|
||||
func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) {
|
||||
otps, err := s.otpRepo.FindAllByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list OTPs: %w", err)
|
||||
}
|
||||
return otps, nil
|
||||
}
|
||||
|
||||
// UpdateOTP updates an OTP
|
||||
func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) {
|
||||
// Get existing OTP
|
||||
otp, err := s.otpRepo.FindByID(ctx, id, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OTP: %w", err)
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if input.Name != "" {
|
||||
otp.Name = input.Name
|
||||
}
|
||||
if input.Issuer != "" {
|
||||
otp.Issuer = input.Issuer
|
||||
}
|
||||
if input.Algorithm != "" {
|
||||
otp.Algorithm = strings.ToUpper(input.Algorithm)
|
||||
}
|
||||
if input.Digits > 0 {
|
||||
otp.Digits = input.Digits
|
||||
}
|
||||
if input.Period > 0 {
|
||||
otp.Period = input.Period
|
||||
}
|
||||
|
||||
// Validate updated OTP
|
||||
if err := s.validateOTPInput(models.OTPParams{
|
||||
Name: otp.Name,
|
||||
Issuer: otp.Issuer,
|
||||
Secret: otp.Secret,
|
||||
Algorithm: otp.Algorithm,
|
||||
Digits: otp.Digits,
|
||||
Period: otp.Period,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.otpRepo.Update(ctx, otp); err != nil {
|
||||
return nil, fmt.Errorf("failed to update OTP: %w", err)
|
||||
}
|
||||
|
||||
return otp, nil
|
||||
}
|
||||
|
||||
// DeleteOTP deletes an OTP
|
||||
func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error {
|
||||
if err := s.otpRepo.Delete(ctx, id, userID); err != nil {
|
||||
return fmt.Errorf("failed to delete OTP: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode generates a TOTP code with enhanced logging and error handling
|
||||
func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) {
|
||||
start := time.Now()
|
||||
|
||||
otp, err := s.otpRepo.FindByID(ctx, id, userID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err)
|
||||
return "", 0, fmt.Errorf("failed to get OTP: %w", err)
|
||||
}
|
||||
|
||||
// Get current time step
|
||||
now := time.Now().Unix()
|
||||
timeStep := now / int64(otp.Period)
|
||||
|
||||
// Generate code
|
||||
code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits)
|
||||
if err != nil {
|
||||
log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err)
|
||||
return "", 0, fmt.Errorf("failed to generate code: %w", err)
|
||||
}
|
||||
|
||||
// Calculate remaining seconds
|
||||
remainingSeconds := otp.Period - int(now%int64(otp.Period))
|
||||
|
||||
// Log successful generation (without actual code)
|
||||
log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)",
|
||||
id, userID, time.Since(start), remainingSeconds)
|
||||
|
||||
return code, remainingSeconds, nil
|
||||
}
|
||||
|
||||
// VerifyCode verifies a TOTP code with enhanced security and logging
|
||||
func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) {
|
||||
start := time.Now()
|
||||
|
||||
// Basic input validation
|
||||
if len(code) == 0 {
|
||||
log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID)
|
||||
return false, fmt.Errorf("code is required")
|
||||
}
|
||||
|
||||
otp, err := s.otpRepo.FindByID(ctx, id, userID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to find OTP %s for user %s during verification: %v",
|
||||
id, userID, err)
|
||||
return false, fmt.Errorf("failed to get OTP: %w", err)
|
||||
}
|
||||
|
||||
// Get current and adjacent time steps
|
||||
now := time.Now().Unix()
|
||||
timeSteps := []int64{
|
||||
(now - int64(otp.Period)) / int64(otp.Period),
|
||||
now / int64(otp.Period),
|
||||
(now + int64(otp.Period)) / int64(otp.Period),
|
||||
}
|
||||
|
||||
// Check code against all time steps
|
||||
for _, ts := range timeSteps {
|
||||
expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits)
|
||||
if err != nil {
|
||||
log.Printf("Code generation failed for time step %d: %v", ts, err)
|
||||
continue
|
||||
}
|
||||
if expectedCode == code {
|
||||
// Log successful verification
|
||||
log.Printf("Code verified successfully for OTP %s (user %s) in %v",
|
||||
id, userID, time.Since(start))
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Log failed verification attempt
|
||||
log.Printf("Invalid code provided for OTP %s (user %s) in %v",
|
||||
id, userID, time.Since(start))
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// validateOTPInput validates OTP input with detailed error messages
|
||||
func (s *OTPService) validateOTPInput(input models.OTPParams) error {
|
||||
if input.Name == "" {
|
||||
return fmt.Errorf("name is required")
|
||||
}
|
||||
|
||||
if len(input.Name) > 100 {
|
||||
return fmt.Errorf("name is too long (maximum 100 characters)")
|
||||
}
|
||||
|
||||
if input.Secret == "" {
|
||||
return fmt.Errorf("secret is required")
|
||||
}
|
||||
|
||||
if !isValidBase32(input.Secret) {
|
||||
return fmt.Errorf("invalid secret format: must be a valid base32 string")
|
||||
}
|
||||
|
||||
// Secret length check (after base32 decoding)
|
||||
secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "="))
|
||||
if len(secretBytes) < 10 {
|
||||
return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)")
|
||||
}
|
||||
|
||||
if input.Algorithm != "" {
|
||||
if !isValidAlgorithm(input.Algorithm) {
|
||||
return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
if input.Digits != 0 {
|
||||
if input.Digits < 6 || input.Digits > 8 {
|
||||
return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits)
|
||||
}
|
||||
}
|
||||
|
||||
if input.Period != 0 {
|
||||
if input.Period < 30 || input.Period > 60 {
|
||||
return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func cleanSecret(secret string) string {
|
||||
// Remove spaces and convert to upper case
|
||||
secret = strings.TrimSpace(strings.ToUpper(secret))
|
||||
// Remove any padding characters
|
||||
return strings.TrimRight(secret, "=")
|
||||
}
|
||||
|
||||
func isValidBase32(s string) bool {
|
||||
// Try to decode the secret
|
||||
_, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "="))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func isValidAlgorithm(algorithm string) bool {
|
||||
switch strings.ToUpper(algorithm) {
|
||||
case "SHA1", "SHA256", "SHA512":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func getHasher(algorithm string, key []byte) (hash.Hash, error) {
|
||||
switch strings.ToUpper(algorithm) {
|
||||
case "SHA1":
|
||||
return hmac.New(sha1.New, key), nil
|
||||
case "SHA256":
|
||||
return hmac.New(sha256.New, key), nil
|
||||
case "SHA512":
|
||||
return hmac.New(sha512.New, key), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported algorithm: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) {
|
||||
// Decode secret
|
||||
secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "="))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid secret: %w", err)
|
||||
}
|
||||
|
||||
// Get initialized HMAC hasher with secret
|
||||
hasher, err := getHasher(algorithm, secretBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Convert time step to bytes
|
||||
timeBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(timeBytes, uint64(timeStep))
|
||||
|
||||
// Calculate HMAC
|
||||
hasher.Write(timeBytes)
|
||||
hash := hasher.Sum(nil)
|
||||
|
||||
// Get offset
|
||||
offset := hash[len(hash)-1] & 0xf
|
||||
|
||||
// Generate 4-byte code
|
||||
code := binary.BigEndian.Uint32(hash[offset : offset+4])
|
||||
code = code & 0x7fffffff
|
||||
|
||||
// Get the specified number of digits
|
||||
code = code % uint32(pow10(digits))
|
||||
|
||||
// Format code with leading zeros
|
||||
return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil
|
||||
}
|
||||
|
||||
func pow10(n int) uint32 {
|
||||
result := uint32(1)
|
||||
for i := 0; i < n; i++ {
|
||||
result *= 10
|
||||
}
|
||||
return result
|
||||
}
|
|
@ -1,20 +1,65 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
|
||||
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
|
||||
// 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数
|
||||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
// 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
secret := viper.GetString("auth.secret")
|
||||
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method")
|
||||
}
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
// 将 openid 存入上下文
|
||||
ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// AesDecrypt 函数用于AES解密
|
||||
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
|
||||
//Base64解码
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
|
||||
|
@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
|
|||
return origData, nil
|
||||
}
|
||||
|
||||
// PKCS7UnPadding 函数用于去除PKCS7填充的密文
|
||||
func PKCS7UnPadding(plantText []byte) []byte {
|
||||
// 获取密文的长度
|
||||
length := len(plantText)
|
||||
// 如果密文长度大于0
|
||||
if length > 0 {
|
||||
// 获取最后一个字节的值,即填充的位数
|
||||
unPadding := int(plantText[length-1])
|
||||
// 返回去除填充后的密文
|
||||
return plantText[:(length - unPadding)]
|
||||
}
|
||||
// 如果密文长度为0,则返回原密文
|
||||
return plantText
|
||||
}
|
||||
|
|
159
validator/validator.go
Normal file
159
validator/validator.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
var (
|
||||
validate *validator.Validate
|
||||
// 自定义验证规则
|
||||
customValidations = map[string]validator.Func{
|
||||
"otpsecret": validateOTPSecret,
|
||||
"password": validatePassword,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
validate = validator.New()
|
||||
|
||||
// 注册自定义验证规则
|
||||
for tag, fn := range customValidations {
|
||||
if err := validate.RegisterValidation(tag, fn); err != nil {
|
||||
panic(fmt.Sprintf("failed to register validation %s: %v", tag, err))
|
||||
}
|
||||
}
|
||||
|
||||
// 使用json tag作为字段名
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
|
||||
if name == "-" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateRequest validates a request body against a struct
|
||||
func ValidateRequest(r *http.Request, v interface{}) error {
|
||||
if err := json.NewDecoder(r.Body).Decode(v); err != nil {
|
||||
return fmt.Errorf("invalid request body: %w", err)
|
||||
}
|
||||
|
||||
if err := validate.Struct(v); err != nil {
|
||||
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
||||
return NewValidationError(validationErrors)
|
||||
}
|
||||
return fmt.Errorf("validation error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidationError represents a validation error
|
||||
type ValidationError struct {
|
||||
Fields map[string]string `json:"fields"`
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ValidationError) Error() string {
|
||||
var errors []string
|
||||
for field, msg := range e.Fields {
|
||||
errors = append(errors, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return fmt.Sprintf("validation failed: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
|
||||
// NewValidationError creates a new ValidationError from validator.ValidationErrors
|
||||
func NewValidationError(errors validator.ValidationErrors) *ValidationError {
|
||||
fields := make(map[string]string)
|
||||
for _, err := range errors {
|
||||
fields[err.Field()] = getErrorMessage(err)
|
||||
}
|
||||
return &ValidationError{Fields: fields}
|
||||
}
|
||||
|
||||
// getErrorMessage returns a human-readable error message for a validation error
|
||||
func getErrorMessage(err validator.FieldError) string {
|
||||
switch err.Tag() {
|
||||
case "required":
|
||||
return "This field is required"
|
||||
case "email":
|
||||
return "Invalid email address"
|
||||
case "min":
|
||||
return fmt.Sprintf("Must be at least %s characters long", err.Param())
|
||||
case "max":
|
||||
return fmt.Sprintf("Must be at most %s characters long", err.Param())
|
||||
case "otpsecret":
|
||||
return "Invalid OTP secret format"
|
||||
case "password":
|
||||
return "Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character"
|
||||
default:
|
||||
return fmt.Sprintf("Failed validation on tag: %s", err.Tag())
|
||||
}
|
||||
}
|
||||
|
||||
// Custom validation functions
|
||||
|
||||
// validateOTPSecret validates an OTP secret
|
||||
func validateOTPSecret(fl validator.FieldLevel) bool {
|
||||
secret := fl.Field().String()
|
||||
// OTP secret should be base32 encoded
|
||||
matched, _ := regexp.MatchString(`^[A-Z2-7]+=*$`, secret)
|
||||
return matched
|
||||
}
|
||||
|
||||
// validatePassword validates a password
|
||||
func validatePassword(fl validator.FieldLevel) bool {
|
||||
password := fl.Field().String()
|
||||
// At least 8 characters long
|
||||
if len(password) < 8 {
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
hasUpper = regexp.MustCompile(`[A-Z]`).MatchString(password)
|
||||
hasLower = regexp.MustCompile(`[a-z]`).MatchString(password)
|
||||
hasNumber = regexp.MustCompile(`[0-9]`).MatchString(password)
|
||||
hasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password)
|
||||
)
|
||||
|
||||
return hasUpper && hasLower && hasNumber && hasSpecial
|
||||
}
|
||||
|
||||
// Request validation structs
|
||||
|
||||
// LoginRequest represents a login request
|
||||
type LoginRequest struct {
|
||||
Code string `json:"code" validate:"required"`
|
||||
}
|
||||
|
||||
// CreateOTPRequest represents a request to create an OTP
|
||||
type CreateOTPRequest struct {
|
||||
Name string `json:"name" validate:"required,min=1,max=100"`
|
||||
Issuer string `json:"issuer" validate:"required,min=1,max=100"`
|
||||
Secret string `json:"secret" validate:"required,otpsecret"`
|
||||
Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
|
||||
Digits int `json:"digits" validate:"required,oneof=6 8"`
|
||||
Period int `json:"period" validate:"required,oneof=30 60"`
|
||||
}
|
||||
|
||||
// UpdateOTPRequest represents a request to update an OTP
|
||||
type UpdateOTPRequest struct {
|
||||
Name string `json:"name" validate:"omitempty,min=1,max=100"`
|
||||
Issuer string `json:"issuer" validate:"omitempty,min=1,max=100"`
|
||||
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
|
||||
Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
|
||||
Period int `json:"period" validate:"omitempty,oneof=30 60"`
|
||||
}
|
||||
|
||||
// VerifyOTPRequest represents a request to verify an OTP code
|
||||
type VerifyOTPRequest struct {
|
||||
Code string `json:"code" validate:"required,len=6|len=8,numeric"`
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue