fix api
This commit is contained in:
parent
01b8951dd5
commit
10ebc59ffb
17 changed files with 1087 additions and 238 deletions
207
api_server.go
207
api_server.go
|
@ -1,18 +1,80 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"auth"
|
||||
"config"
|
||||
"otpm/auth"
|
||||
"otpm/config"
|
||||
"otpm/db"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
//go:embed init/sqlite/init.sql
|
||||
//go:embed init/postgresql/init.sql
|
||||
var initScripts embed.FS
|
||||
|
||||
// 全局数据库连接
|
||||
var tokenDB *db.DB
|
||||
|
||||
// InitDB 初始化数据库连接
|
||||
func InitDB(dbConfig config.DatabaseConfig) error {
|
||||
var err error
|
||||
|
||||
// 创建数据库配置
|
||||
dbCfg := db.Config{
|
||||
Driver: dbConfig.Driver,
|
||||
}
|
||||
|
||||
// 根据驱动类型设置配置
|
||||
if dbConfig.Driver == "postgres" {
|
||||
dbCfg.Host = dbConfig.Postgres.Host
|
||||
dbCfg.Port = dbConfig.Postgres.Port
|
||||
dbCfg.User = dbConfig.Postgres.User
|
||||
dbCfg.Password = dbConfig.Postgres.Password
|
||||
dbCfg.DBName = dbConfig.Postgres.DBName
|
||||
dbCfg.SSLMode = dbConfig.Postgres.SSLMode
|
||||
} else if dbConfig.Driver == "sqlite3" {
|
||||
dbCfg.DBName = dbConfig.SQLite.Path
|
||||
}
|
||||
|
||||
// 使用db包创建数据库连接
|
||||
tokenDB, err = db.New(dbCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening database: %v", err)
|
||||
}
|
||||
|
||||
// 读取并执行初始化脚本
|
||||
var scriptPath string
|
||||
if dbConfig.Driver == "sqlite3" {
|
||||
scriptPath = "init/sqlite/init.sql"
|
||||
} else if dbConfig.Driver == "postgres" {
|
||||
scriptPath = "init/postgresql/init.sql"
|
||||
} else {
|
||||
return fmt.Errorf("unsupported database driver: %s", dbConfig.Driver)
|
||||
}
|
||||
|
||||
script, err := initScripts.ReadFile(scriptPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading init script: %v", err)
|
||||
}
|
||||
|
||||
_, err = tokenDB.Exec(string(script))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error executing init script: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Server represents the API server
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
|
@ -33,6 +95,7 @@ func NewServer(cfg *config.Config) *Server {
|
|||
config: cfg,
|
||||
router: mux.NewRouter(),
|
||||
}
|
||||
|
||||
s.setupRoutes()
|
||||
return s
|
||||
}
|
||||
|
@ -41,12 +104,15 @@ func NewServer(cfg *config.Config) *Server {
|
|||
func (s *Server) setupRoutes() {
|
||||
// 公开路由
|
||||
s.router.HandleFunc("/auth/login", s.WechatLoginHandler)
|
||||
s.router.HandleFunc("/auth/refresh", s.RefreshTokenHandler)
|
||||
|
||||
// 受保护路由(需要JWT)
|
||||
authRouter := s.router.PathPrefix("").Subrouter()
|
||||
authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))
|
||||
authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST")
|
||||
authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST")
|
||||
authRouter.HandleFunc("/otp/delete", DeleteTokenHandler).Methods("POST")
|
||||
authRouter.HandleFunc("/otp/clear_all", ClearAllHandler).Methods("POST")
|
||||
|
||||
// 添加CORS中间件
|
||||
s.router.Use(s.corsMiddleware)
|
||||
|
@ -56,22 +122,45 @@ func (s *Server) setupRoutes() {
|
|||
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
referer := r.Header.Get("Referer")
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
|
||||
// Check if the origin is allowed
|
||||
allowed := false
|
||||
for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
|
||||
if origin == allowedOrigin {
|
||||
allowed = true
|
||||
break
|
||||
// 检查是否是微信小程序请求
|
||||
isWechatMiniProgram := false
|
||||
if userAgent != "" && (regexp.MustCompile(`MicroMessenger`).MatchString(userAgent) ||
|
||||
regexp.MustCompile(`miniProgram`).MatchString(userAgent)) {
|
||||
isWechatMiniProgram = true
|
||||
}
|
||||
|
||||
// 从Referer中检查是否是微信小程序请求
|
||||
if referer != "" && regexp.MustCompile(`^https://servicewechat\.com/`).MatchString(referer) {
|
||||
isWechatMiniProgram = true
|
||||
}
|
||||
|
||||
// 检查Origin是否允许
|
||||
allowed := isWechatMiniProgram // 如果是微信小程序请求,默认允许
|
||||
if !allowed && origin != "" {
|
||||
for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
|
||||
if origin == allowedOrigin || allowedOrigin == "*" {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if allowed {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
// 如果是微信小程序请求且没有Origin头,使用通配符
|
||||
if isWechatMiniProgram && origin == "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Methods",
|
||||
joinStrings(s.config.CORS.AllowedMethods))
|
||||
w.Header().Set("Access-Control-Allow-Headers",
|
||||
joinStrings(s.config.CORS.AllowedHeaders))
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
|
@ -96,13 +185,14 @@ func joinStrings(slice []string) string {
|
|||
}
|
||||
|
||||
func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
|
@ -127,7 +217,7 @@ func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Wechat service error", http.StatusInternalServerError)
|
||||
return
|
||||
|
@ -144,25 +234,91 @@ func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
token, err := s.generateSessionToken(wechatResp.OpenID)
|
||||
// 生成访问令牌和刷新令牌
|
||||
accessToken, refreshToken, err := s.generateSessionTokens(wechatResp.OpenID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate token", http.StatusInternalServerError)
|
||||
http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
response := map[string]interface{}{
|
||||
"token": token,
|
||||
"openid": wechatResp.OpenID,
|
||||
"access_token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
"openid": wechatResp.OpenID,
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) generateSessionToken(openid string) (string, error) {
|
||||
return auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
|
||||
func (s *Server) generateSessionTokens(openid string) (accessToken string, refreshToken string, err error) {
|
||||
accessToken, err = auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate access token: %v", err)
|
||||
}
|
||||
|
||||
refreshToken, err = auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.RefreshTokenExpiry)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
func (s *Server) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证刷新令牌
|
||||
claims, err := auth.ValidateToken(req.RefreshToken, s.config.Security.JWTSigningKey)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 从刷新令牌中获取 openid
|
||||
openid := claims.UserID
|
||||
if openid == "" {
|
||||
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新的访问令牌和刷新令牌
|
||||
accessToken, refreshToken, err := s.generateSessionTokens(openid)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
response := map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
"openid": openid,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
|
@ -181,8 +337,12 @@ func (s *Server) Start() error {
|
|||
}
|
||||
|
||||
func main() {
|
||||
// 定义命令行参数
|
||||
configPath := flag.String("config", "config", "Path to configuration file (without extension)")
|
||||
flag.Parse()
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.LoadConfig("config")
|
||||
cfg, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
@ -194,6 +354,13 @@ func main() {
|
|||
}
|
||||
log.Println("Database connection established successfully")
|
||||
|
||||
// 初始化API
|
||||
log.Println("Initializing API...")
|
||||
if err := InitAPI(cfg); err != nil {
|
||||
log.Fatalf("Failed to initialize API: %v", err)
|
||||
}
|
||||
log.Println("API initialized successfully")
|
||||
|
||||
// 创建并启动服务器
|
||||
server := NewServer(cfg)
|
||||
if err := server.Start(); err != nil {
|
||||
|
|
|
@ -96,6 +96,28 @@ func GetUserIDFromContext(ctx context.Context) (string, error) {
|
|||
return userID, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT token and returns the claims
|
||||
func ValidateToken(tokenString string, signingKey string) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(signingKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RequireAuth is a middleware that ensures a valid JWT token is present
|
||||
func RequireAuth(signingKey string, next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
73
auth/refresh.go
Normal file
73
auth/refresh.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// RefreshRequest represents the request body for token refresh
|
||||
type RefreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
// RefreshResponse represents the response body for token refresh
|
||||
type RefreshResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
// HandleRefresh handles the token refresh request
|
||||
func HandleRefresh(w http.ResponseWriter, r *http.Request, signingKey string, accessExpiry, refreshExpiry time.Duration) {
|
||||
// Only accept POST requests
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var req RefreshRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate refresh token
|
||||
claims := &Claims{}
|
||||
token, err := jwt.ParseWithClaims(req.RefreshToken, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(signingKey), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate new access token
|
||||
accessToken, err := GenerateToken(claims.UserID, signingKey, accessExpiry)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate access token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate new refresh token
|
||||
refreshToken, err := GenerateToken(claims.UserID, signingKey, refreshExpiry)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate refresh token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return new tokens
|
||||
resp := RefreshResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
|
@ -18,21 +18,22 @@ database:
|
|||
|
||||
# Security Configuration
|
||||
security:
|
||||
encryption_key: "your-32-byte-encryption-key-here"
|
||||
jwt_signing_key: "your-jwt-signing-key-here"
|
||||
encryption_key: "12345678901234567890123456789012"
|
||||
jwt_signing_key: "jwt_secret_key_for_authentication_12345"
|
||||
token_expiry: 24h
|
||||
refresh_token_expiry: 168h # 7 days
|
||||
|
||||
# WeChat Configuration
|
||||
wechat:
|
||||
app_id: "YOUR_APPID"
|
||||
app_secret: "YOUR_APPSECRET"
|
||||
app_id: "wx57d1033974eb5250"
|
||||
app_secret: "be494c2a81df685a40b9a74e1736b15d"
|
||||
|
||||
# CORS Configuration
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "http://localhost:8080"
|
||||
- "https://yourdomain.com"
|
||||
- "https://servicewechat.com/wx57d1033974eb5250"
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
|
|
|
@ -9,54 +9,54 @@ import (
|
|||
|
||||
// Config holds all configuration for our application
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
Security SecurityConfig
|
||||
CORS CORSConfig
|
||||
Wechat WechatConfig
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Wechat WechatConfig `mapstructure:"wechat"`
|
||||
}
|
||||
|
||||
// ServerConfig holds all server related configuration
|
||||
type ServerConfig struct {
|
||||
Port int
|
||||
Timeout time.Duration
|
||||
Port int `mapstructure:"port"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
}
|
||||
|
||||
// DatabaseConfig holds all database related configuration
|
||||
type DatabaseConfig struct {
|
||||
Driver string
|
||||
SQLite SQLiteConfig
|
||||
Postgres PostgresConfig
|
||||
Driver string `mapstructure:"driver"`
|
||||
SQLite SQLiteConfig `mapstructure:"sqlite"`
|
||||
Postgres PostgresConfig `mapstructure:"postgres"`
|
||||
}
|
||||
|
||||
// SQLiteConfig holds SQLite specific configuration
|
||||
type SQLiteConfig struct {
|
||||
Path string
|
||||
Path string `mapstructure:"path"`
|
||||
}
|
||||
|
||||
// PostgresConfig holds PostgreSQL specific configuration
|
||||
type PostgresConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
SSLMode string
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
DBName string `mapstructure:"dbname"`
|
||||
SSLMode string `mapstructure:"sslmode"`
|
||||
}
|
||||
|
||||
// SecurityConfig holds all security related configuration
|
||||
type SecurityConfig struct {
|
||||
EncryptionKey string
|
||||
JWTSigningKey string
|
||||
TokenExpiry time.Duration
|
||||
RefreshTokenExpiry time.Duration
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
JWTSigningKey string `mapstructure:"jwt_signing_key"`
|
||||
TokenExpiry time.Duration `mapstructure:"token_expiry"`
|
||||
RefreshTokenExpiry time.Duration `mapstructure:"refresh_token_expiry"`
|
||||
}
|
||||
|
||||
// CORSConfig holds CORS related configuration
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string
|
||||
AllowedMethods []string
|
||||
AllowedHeaders []string
|
||||
AllowedOrigins []string `mapstructure:"allowed_origins"`
|
||||
AllowedMethods []string `mapstructure:"allowed_methods"`
|
||||
AllowedHeaders []string `mapstructure:"allowed_headers"`
|
||||
}
|
||||
|
||||
// WechatConfig holds WeChat related configuration
|
||||
|
@ -69,10 +69,16 @@ type WechatConfig struct {
|
|||
func LoadConfig(configPath string) (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
v.SetConfigName("config")
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(configPath)
|
||||
v.AddConfigPath(".")
|
||||
// 检查配置路径是否包含扩展名
|
||||
if len(configPath) > 5 && (configPath[len(configPath)-5:] == ".yaml" || configPath[len(configPath)-4:] == ".yml") {
|
||||
// 如果包含扩展名,直接使用完整路径
|
||||
v.SetConfigFile(configPath)
|
||||
} else {
|
||||
// 否则按照传统方式处理
|
||||
v.SetConfigName(configPath)
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(".")
|
||||
}
|
||||
|
||||
// Read environment variables
|
||||
v.AutomaticEnv()
|
||||
|
|
44
config/config.yaml
Normal file
44
config/config.yaml
Normal file
|
@ -0,0 +1,44 @@
|
|||
# Server Configuration
|
||||
server:
|
||||
port: 8080
|
||||
timeout: 30s
|
||||
|
||||
# Database Configuration
|
||||
database:
|
||||
driver: "sqlite3" # or "postgres"
|
||||
sqlite:
|
||||
path: "./data.db"
|
||||
postgres:
|
||||
host: "localhost"
|
||||
port: 5432
|
||||
user: "postgres"
|
||||
password: "password"
|
||||
dbname: "otpdb"
|
||||
sslmode: "disable"
|
||||
|
||||
# Security Configuration
|
||||
security:
|
||||
encryption_key: "12345678901234567890123456789012"
|
||||
jwt_signing_key: "jwt_secret_key_for_authentication_12345"
|
||||
token_expiry: 24h
|
||||
refresh_token_expiry: 168h # 7 days
|
||||
|
||||
# WeChat Configuration
|
||||
wechat:
|
||||
app_id: "wx57d1033974eb5250"
|
||||
app_secret: "be494c2a81df685a40b9a74e1736b15d"
|
||||
|
||||
# CORS Configuration
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "http://localhost:8080"
|
||||
- "https://yourdomain.com"
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
- "PUT"
|
||||
- "DELETE"
|
||||
- "OPTIONS"
|
||||
allowed_headers:
|
||||
- "Authorization"
|
||||
- "Content-Type"
|
68
crypto/token.go
Normal file
68
crypto/token.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// TokenEncryptor 用于加密和解密令牌的结构体
|
||||
type TokenEncryptor struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewTokenEncryptor 创建一个新的TokenEncryptor实例
|
||||
func NewTokenEncryptor(key []byte) (*TokenEncryptor, error) {
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("key must be 32 bytes")
|
||||
}
|
||||
return &TokenEncryptor{key: key}, nil
|
||||
}
|
||||
|
||||
// EncryptTokenSecret 加密令牌密钥
|
||||
func (te *TokenEncryptor) EncryptTokenSecret(secret string) (string, error) {
|
||||
block, err := aes.NewCipher(te.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating cipher: %v", err)
|
||||
}
|
||||
|
||||
plaintext := []byte(secret)
|
||||
ciphertext := make([]byte, aes.BlockSize+len(plaintext))
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return "", fmt.Errorf("error generating IV: %v", err)
|
||||
}
|
||||
|
||||
stream := cipher.NewCFBEncrypter(block, iv)
|
||||
stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptTokenSecret 解密令牌密钥
|
||||
func (te *TokenEncryptor) DecryptTokenSecret(encrypted string) (string, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64: %v", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(te.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating cipher: %v", err)
|
||||
}
|
||||
|
||||
if len(ciphertext) < aes.BlockSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
ciphertext = ciphertext[aes.BlockSize:]
|
||||
|
||||
stream := cipher.NewCFBDecrypter(block, iv)
|
||||
stream.XORKeyStream(ciphertext, ciphertext)
|
||||
|
||||
return string(ciphertext), nil
|
||||
}
|
52
db/db.go
Normal file
52
db/db.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Config 数据库配置
|
||||
type Config struct {
|
||||
Driver string
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
SSLMode string
|
||||
}
|
||||
|
||||
// DB 封装数据库操作
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
// New 创建数据库连接
|
||||
func New(cfg Config) (*DB, error) {
|
||||
var dsn string
|
||||
switch cfg.Driver {
|
||||
case "postgres":
|
||||
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode)
|
||||
case "sqlite3":
|
||||
dsn = cfg.DBName
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database driver: %s", cfg.Driver)
|
||||
}
|
||||
|
||||
db, err := sql.Open(cfg.Driver, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping database: %v", err)
|
||||
}
|
||||
|
||||
return &DB{db}, nil
|
||||
}
|
224
db/token.go
Normal file
224
db/token.go
Normal file
|
@ -0,0 +1,224 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 令牌类型常量
|
||||
const (
|
||||
TokenTypeHOTP = "hotp"
|
||||
TokenTypeTOTP = "totp"
|
||||
)
|
||||
|
||||
// Token 表示一个 OTP 令牌
|
||||
type Token struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Issuer string `json:"issuer"`
|
||||
Account string `json:"account"`
|
||||
Secret string `json:"secret"`
|
||||
Type string `json:"type"`
|
||||
Counter *int `json:"counter,omitempty"`
|
||||
Period int `json:"period"`
|
||||
Digits int `json:"digits"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
// SaveTokens 保存用户的令牌列表
|
||||
func (db *DB) SaveTokens(userID string, tokens []Token) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 准备插入语句
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO tokens (
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
|
||||
) ON CONFLICT (id, user_id) DO UPDATE SET
|
||||
issuer = EXCLUDED.issuer,
|
||||
account = EXCLUDED.account,
|
||||
secret = EXCLUDED.secret,
|
||||
type = EXCLUDED.type,
|
||||
counter = EXCLUDED.counter,
|
||||
period = EXCLUDED.period,
|
||||
digits = EXCLUDED.digits,
|
||||
algorithm = EXCLUDED.algorithm,
|
||||
timestamp = EXCLUDED.timestamp,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
now := time.Now()
|
||||
for _, token := range tokens {
|
||||
// 确保algorithm字段有默认值
|
||||
algorithm := token.Algorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA1"
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(
|
||||
token.ID,
|
||||
userID,
|
||||
token.Issuer,
|
||||
token.Account,
|
||||
token.Secret,
|
||||
token.Type,
|
||||
token.Counter,
|
||||
token.Period,
|
||||
token.Digits,
|
||||
algorithm,
|
||||
token.Timestamp,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetTokensByUserID 获取用户的所有令牌
|
||||
func (db *DB) GetTokensByUserID(userID string) ([]Token, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
FROM tokens
|
||||
WHERE user_id = $1
|
||||
ORDER BY timestamp DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []Token
|
||||
for rows.Next() {
|
||||
var token Token
|
||||
err := rows.Scan(
|
||||
&token.ID,
|
||||
&token.UserID,
|
||||
&token.Issuer,
|
||||
&token.Account,
|
||||
&token.Secret,
|
||||
&token.Type,
|
||||
&token.Counter,
|
||||
&token.Period,
|
||||
&token.Digits,
|
||||
&token.Algorithm,
|
||||
&token.Timestamp,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// DeleteToken 删除指定的令牌
|
||||
func (db *DB) DeleteToken(userID, tokenID string) error {
|
||||
result, err := db.Exec(`
|
||||
DELETE FROM tokens
|
||||
WHERE user_id = $1 AND id = $2
|
||||
`, userID, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTokenCounter 更新 HOTP 令牌的计数器
|
||||
func (db *DB) UpdateTokenCounter(userID, tokenID string, counter int) error {
|
||||
result, err := db.Exec(`
|
||||
UPDATE tokens
|
||||
SET counter = $1, updated_at = $2
|
||||
WHERE user_id = $3 AND id = $4 AND type = $5
|
||||
`, counter, time.Now(), userID, tokenID, TokenTypeHOTP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTokenByID 获取指定的令牌
|
||||
func (db *DB) GetTokenByID(userID, tokenID string) (*Token, error) {
|
||||
var token Token
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
FROM tokens
|
||||
WHERE user_id = $1 AND id = $2
|
||||
`, userID, tokenID).Scan(
|
||||
&token.ID,
|
||||
&token.UserID,
|
||||
&token.Issuer,
|
||||
&token.Account,
|
||||
&token.Secret,
|
||||
&token.Type,
|
||||
&token.Counter,
|
||||
&token.Period,
|
||||
&token.Digits,
|
||||
&token.Algorithm,
|
||||
&token.Timestamp,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
28
go.mod
Normal file
28
go.mod
Normal file
|
@ -0,0 +1,28 @@
|
|||
module otpm
|
||||
|
||||
go 1.21.1
|
||||
|
||||
require (
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/spf13/viper v1.20.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.12.0 // indirect
|
||||
github.com/spf13/cast v1.7.1 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/sys v0.29.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
60
go.sum
Normal file
60
go.sum
Normal file
|
@ -0,0 +1,60 @@
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
||||
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
||||
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
|
||||
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
|
||||
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
|
||||
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
|
||||
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
||||
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
|
||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
@ -9,7 +9,7 @@ CREATE TABLE IF NOT EXISTS tokens (
|
|||
counter INTEGER, -- HOTP计数器(可选)
|
||||
period INTEGER NOT NULL, -- TOTP周期(秒)
|
||||
digits INTEGER NOT NULL, -- 验证码位数
|
||||
algo VARCHAR(10) NOT NULL, -- 使用的哈希算法
|
||||
algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1', -- 使用的哈希算法
|
||||
timestamp BIGINT NOT NULL, -- 最后更新时间戳
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
|
@ -45,7 +45,7 @@ COMMENT ON COLUMN tokens.type IS '令牌类型(totp/hotp)';
|
|||
COMMENT ON COLUMN tokens.counter IS 'HOTP计数器(可选)';
|
||||
COMMENT ON COLUMN tokens.period IS 'TOTP周期(秒)';
|
||||
COMMENT ON COLUMN tokens.digits IS '验证码位数';
|
||||
COMMENT ON COLUMN tokens.algo IS '使用的哈希算法';
|
||||
COMMENT ON COLUMN tokens.algorithm IS '使用的哈希算法';
|
||||
COMMENT ON COLUMN tokens.timestamp IS '最后更新时间戳';
|
||||
COMMENT ON COLUMN tokens.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN tokens.updated_at IS '最后更新时间';
|
|
@ -9,22 +9,22 @@ PRAGMA foreign_keys = ON;
|
|||
|
||||
-- 创建tokens表
|
||||
CREATE TABLE IF NOT EXISTS tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
issuer TEXT NOT NULL,
|
||||
account TEXT NOT NULL,
|
||||
secret TEXT NOT NULL CHECK (length(secret) >= 16 AND secret REGEXP '^[A-Z2-7]+=*$'),
|
||||
type TEXT NOT NULL CHECK (type IN ('HOTP', 'TOTP')),
|
||||
secret TEXT NOT NULL CHECK (length(secret) >= 16),
|
||||
type TEXT NOT NULL CHECK (type IN ('hotp', 'totp')),
|
||||
counter INTEGER CHECK (
|
||||
(type = 'HOTP' AND counter >= 0) OR
|
||||
(type = 'TOTP' AND counter IS NULL)
|
||||
(type = 'hotp' AND counter >= 0) OR
|
||||
(type = 'totp' AND counter IS NULL)
|
||||
),
|
||||
period INTEGER DEFAULT 30 CHECK (
|
||||
(type = 'TOTP' AND period >= 30) OR
|
||||
(type = 'HOTP' AND period IS NULL)
|
||||
(type = 'totp' AND period >= 30) OR
|
||||
(type = 'hotp' AND period IS NULL)
|
||||
),
|
||||
digits INTEGER NOT NULL DEFAULT 6 CHECK (digits IN (6, 8)),
|
||||
algo TEXT NOT NULL DEFAULT 'SHA1' CHECK (algo IN ('SHA1', 'SHA256', 'SHA512')),
|
||||
algorithm TEXT NOT NULL DEFAULT 'SHA1' CHECK (algorithm IN ('SHA1', 'SHA256', 'SHA512')),
|
||||
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
updated_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
UNIQUE(user_id, issuer, account)
|
||||
|
@ -33,16 +33,16 @@ CREATE TABLE IF NOT EXISTS tokens (
|
|||
-- 基本索引
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_lookup ON tokens(user_id, issuer, account);
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_hotp ON tokens(user_id) WHERE type = 'HOTP';
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_totp ON tokens(user_id) WHERE type = 'TOTP';
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_hotp ON tokens(user_id) WHERE type = 'hotp';
|
||||
CREATE INDEX IF NOT EXISTS idx_tokens_totp ON tokens(user_id) WHERE type = 'totp';
|
||||
|
||||
-- 简化统计视图
|
||||
CREATE VIEW IF NOT EXISTS v_token_stats AS
|
||||
SELECT
|
||||
user_id,
|
||||
COUNT(*) as total_tokens,
|
||||
SUM(type = 'HOTP') as hotp_count,
|
||||
SUM(type = 'TOTP') as totp_count
|
||||
SUM(type = 'hotp') as hotp_count,
|
||||
SUM(type = 'totp') as totp_count
|
||||
FROM tokens
|
||||
GROUP BY user_id;
|
||||
|
329
otp_api.go
329
otp_api.go
|
@ -1,65 +1,36 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"otpm/config"
|
||||
"otpm/crypto"
|
||||
"otpm/db"
|
||||
"otpm/utils"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// 加密密钥(32字节AES-256)
|
||||
var encryptionKey = []byte("example-key-32-bytes-long!1234") // 实际应用中应从安全配置获取
|
||||
var (
|
||||
tokenEncryptor *crypto.TokenEncryptor
|
||||
)
|
||||
|
||||
// encryptTokenSecret 加密令牌密钥
|
||||
func encryptTokenSecret(secret string) (string, error) {
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
// InitAPI 初始化API相关配置
|
||||
func InitAPI(cfg *config.Config) error {
|
||||
var err error
|
||||
|
||||
// 初始化token加密器
|
||||
tokenEncryptor, err = crypto.NewTokenEncryptor([]byte(cfg.Security.EncryptionKey))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return fmt.Errorf("failed to initialize token encryptor: %v", err)
|
||||
}
|
||||
|
||||
ciphertext := make([]byte, aes.BlockSize+len(secret))
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
stream := cipher.NewCFBEncrypter(block, iv)
|
||||
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(secret))
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// decryptTokenSecret 解密令牌密钥
|
||||
func decryptTokenSecret(encrypted string) (string, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(ciphertext) < aes.BlockSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
ciphertext = ciphertext[aes.BlockSize:]
|
||||
|
||||
stream := cipher.NewCFBDecrypter(block, iv)
|
||||
stream.XORKeyStream(ciphertext, ciphertext)
|
||||
|
||||
return string(ciphertext), nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveRequest 保存请求的数据结构
|
||||
|
@ -71,15 +42,15 @@ type SaveRequest struct {
|
|||
|
||||
// TokenData token数据结构
|
||||
type TokenData struct {
|
||||
ID string `json:"id"`
|
||||
Issuer string `json:"issuer"`
|
||||
Account string `json:"account"`
|
||||
Secret string `json:"secret"`
|
||||
Type string `json:"type"`
|
||||
Counter int `json:"counter,omitempty"`
|
||||
Period int `json:"period"`
|
||||
Digits int `json:"digits"`
|
||||
Algo string `json:"algo"`
|
||||
ID string `json:"id"`
|
||||
Issuer string `json:"issuer"`
|
||||
Account string `json:"account"`
|
||||
Secret string `json:"secret"`
|
||||
Type string `json:"type"`
|
||||
Counter *int `json:"counter,omitempty"`
|
||||
Period int `json:"period"`
|
||||
Digits int `json:"digits"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
}
|
||||
|
||||
// SaveResponse 保存响应的数据结构
|
||||
|
@ -101,37 +72,10 @@ type RecoverResponse struct {
|
|||
} `json:"data"`
|
||||
}
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
// InitDB 初始化数据库连接
|
||||
func InitDB() error {
|
||||
connStr := "postgres://postgres:postgres@localhost/otp_db?sslmode=disable"
|
||||
var err error
|
||||
db, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening database: %v", err)
|
||||
}
|
||||
|
||||
if err = db.Ping(); err != nil {
|
||||
return fmt.Errorf("error connecting to the database: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
// 使用api_server.go中定义的全局数据库连接
|
||||
|
||||
// SaveHandler 保存token的接口处理函数
|
||||
func SaveHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 设置CORS头
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
// 处理OPTIONS请求
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查请求方法
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
@ -142,26 +86,26 @@ func SaveHandler(w http.ResponseWriter, r *http.Request) {
|
|||
var req SaveRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
log.Printf("Error decoding request: %v", err)
|
||||
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
utils.SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求数据
|
||||
if req.UserID == "" {
|
||||
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
||||
utils.SendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Tokens) == 0 {
|
||||
sendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
|
||||
utils.SendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 开始数据库事务
|
||||
tx, err := db.Begin()
|
||||
tx, err := tokenDB.Begin()
|
||||
if err != nil {
|
||||
log.Printf("Error starting transaction: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
@ -170,47 +114,68 @@ func SaveHandler(w http.ResponseWriter, r *http.Request) {
|
|||
_, err = tx.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("Error deleting existing tokens: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 插入新的tokens
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algorithm)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
`)
|
||||
if err != nil {
|
||||
log.Printf("Error preparing statement: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, token := range req.Tokens {
|
||||
// 加密secret
|
||||
encryptedSecret, err := encryptTokenSecret(token.Secret)
|
||||
encryptedSecret, err := tokenEncryptor.EncryptTokenSecret(token.Secret)
|
||||
if err != nil {
|
||||
log.Printf("Error encrypting token secret: %v", err)
|
||||
sendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 确定令牌类型
|
||||
tokenType := token.Type
|
||||
if tokenType != db.TokenTypeHOTP && tokenType != db.TokenTypeTOTP {
|
||||
log.Printf("Invalid token type: %s", tokenType)
|
||||
utils.SendErrorResponse(w, "Invalid token type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 对于TOTP类型的令牌,counter必须为NULL
|
||||
var counterValue interface{}
|
||||
if tokenType == db.TokenTypeTOTP {
|
||||
counterValue = nil
|
||||
} else {
|
||||
counterValue = token.Counter
|
||||
}
|
||||
|
||||
// 确保algo字段有默认值
|
||||
algorithm := token.Algorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA1"
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(
|
||||
token.ID,
|
||||
req.UserID,
|
||||
token.Issuer,
|
||||
token.Account,
|
||||
encryptedSecret,
|
||||
token.Type,
|
||||
token.Counter,
|
||||
tokenType,
|
||||
counterValue,
|
||||
token.Period,
|
||||
token.Digits,
|
||||
token.Algo,
|
||||
req.Timestamp,
|
||||
algorithm,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Error inserting token: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -218,7 +183,7 @@ func SaveHandler(w http.ResponseWriter, r *http.Request) {
|
|||
// 提交事务
|
||||
if err = tx.Commit(); err != nil {
|
||||
log.Printf("Error committing transaction: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -229,22 +194,11 @@ func SaveHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
resp.Data.ID = req.UserID
|
||||
|
||||
sendJSONResponse(w, resp, http.StatusOK)
|
||||
utils.SendJSONResponse(w, resp)
|
||||
}
|
||||
|
||||
// RecoverHandler 恢复token的接口处理函数
|
||||
func RecoverHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 设置CORS头
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
// 处理OPTIONS请求
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查请求方法
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
@ -257,33 +211,33 @@ func RecoverHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
log.Printf("Error decoding request: %v", err)
|
||||
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
utils.SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证用户ID
|
||||
if req.UserID == "" {
|
||||
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
||||
utils.SendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 查询数据库
|
||||
rows, err := db.Query(`
|
||||
SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp
|
||||
rows, err := tokenDB.Query(`
|
||||
SELECT id, issuer, account, secret, type, counter, period, digits, algorithm, updated_at
|
||||
FROM tokens
|
||||
WHERE user_id = $1
|
||||
ORDER BY timestamp DESC
|
||||
ORDER BY updated_at DESC
|
||||
`, req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("Error querying database: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 读取查询结果
|
||||
var tokens []TokenData
|
||||
var timestamp int64
|
||||
var updatedAt string
|
||||
for rows.Next() {
|
||||
var token TokenData
|
||||
var encryptedSecret string
|
||||
|
@ -296,20 +250,20 @@ func RecoverHandler(w http.ResponseWriter, r *http.Request) {
|
|||
&token.Counter,
|
||||
&token.Period,
|
||||
&token.Digits,
|
||||
&token.Algo,
|
||||
×tamp,
|
||||
&token.Algorithm,
|
||||
&updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Error scanning row: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 解密secret
|
||||
token.Secret, err = decryptTokenSecret(encryptedSecret)
|
||||
token.Secret, err = tokenEncryptor.DecryptTokenSecret(encryptedSecret)
|
||||
if err != nil {
|
||||
log.Printf("Error decrypting token secret: %v", err)
|
||||
sendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
|
@ -317,7 +271,22 @@ func RecoverHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
if err = rows.Err(); err != nil {
|
||||
log.Printf("Error iterating rows: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理无数据情况
|
||||
if len(tokens) == 0 {
|
||||
resp := RecoverResponse{
|
||||
Success: true,
|
||||
Message: "No tokens found in cloud",
|
||||
}
|
||||
resp.Data.Tokens = []TokenData{}
|
||||
resp.Data.Timestamp = time.Now().Unix()
|
||||
|
||||
// 返回404状态码但保持success:true
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
utils.SendJSONResponse(w, resp)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -327,34 +296,78 @@ func RecoverHandler(w http.ResponseWriter, r *http.Request) {
|
|||
Message: "Tokens recovered successfully",
|
||||
}
|
||||
resp.Data.Tokens = tokens
|
||||
|
||||
// 将updatedAt转换为时间戳
|
||||
var timestamp int64
|
||||
// 首先尝试将updatedAt解析为整数时间戳
|
||||
if ts, err := utils.ParseInt64(updatedAt); err == nil {
|
||||
timestamp = ts
|
||||
} else if t, err := time.Parse(time.RFC3339, updatedAt); err == nil {
|
||||
// 尝试解析为RFC3339格式
|
||||
timestamp = t.Unix()
|
||||
} else if t, err := time.Parse("2006-01-02 15:04:05", updatedAt); err == nil {
|
||||
// 尝试解析为标准日期时间格式
|
||||
timestamp = t.Unix()
|
||||
} else {
|
||||
// 如果所有解析方法都失败,使用当前时间
|
||||
timestamp = time.Now().Unix()
|
||||
log.Printf("Error parsing time string: %v", err)
|
||||
}
|
||||
|
||||
resp.Data.Timestamp = timestamp
|
||||
|
||||
sendJSONResponse(w, resp, http.StatusOK)
|
||||
utils.SendJSONResponse(w, resp)
|
||||
}
|
||||
|
||||
// sendErrorResponse 发送错误响应
|
||||
func sendErrorResponse(w http.ResponseWriter, message string, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"success": false,
|
||||
"message": message,
|
||||
// ClearAllHandler 删除用户所有token数据的接口处理函数
|
||||
func ClearAllHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 检查请求方法
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var req struct {
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
log.Printf("Error decoding request: %v", err)
|
||||
utils.SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求数据
|
||||
if req.UserID == "" {
|
||||
utils.SendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 执行删除操作
|
||||
result, err := tokenDB.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("Error deleting tokens: %v", err)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否真的删除了记录
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
log.Printf("Error getting rows affected: %v", err)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
utils.SendJSONResponse(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("Deleted %d tokens successfully", rowsAffected),
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteTokenHandler 删除单个token的接口处理函数
|
||||
func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 设置CORS头
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
// 处理OPTIONS请求
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查请求方法
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
@ -368,21 +381,21 @@ func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
log.Printf("Error decoding request: %v", err)
|
||||
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
utils.SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求数据
|
||||
if req.UserID == "" || req.TokenID == "" {
|
||||
sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
|
||||
utils.SendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 执行删除操作
|
||||
result, err := db.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
|
||||
result, err := tokenDB.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
|
||||
if err != nil {
|
||||
log.Printf("Error deleting token: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -390,28 +403,18 @@ func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
|
|||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
log.Printf("Error getting rows affected: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
sendErrorResponse(w, "Token not found", http.StatusNotFound)
|
||||
utils.SendErrorResponse(w, "Token not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
sendJSONResponse(w, map[string]interface{}{
|
||||
utils.SendJSONResponse(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Token deleted successfully",
|
||||
}, http.StatusOK)
|
||||
}
|
||||
|
||||
// sendJSONResponse 发送JSON响应
|
||||
func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
log.Printf("Error encoding response: %v", err)
|
||||
sendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -15,17 +15,27 @@ func TestSaveHandler(t *testing.T) {
|
|||
defer srv.Close()
|
||||
|
||||
// 准备测试数据
|
||||
counter := 0 // 创建一个int变量
|
||||
testData := SaveRequest{
|
||||
UserID: "test_user_123",
|
||||
Tokens: []TokenData{
|
||||
{
|
||||
Issuer: "TestOrg",
|
||||
Account: "user@test.com",
|
||||
Secret: "JBSWY3DPEHPK3PXP",
|
||||
Type: "totp",
|
||||
Period: 30,
|
||||
Digits: 6,
|
||||
Algo: "SHA1",
|
||||
Issuer: "TestOrg",
|
||||
Account: "user@test.com",
|
||||
Secret: "JBSWY3DPEHPK3PXP",
|
||||
Type: "totp",
|
||||
Period: 30,
|
||||
Digits: 6,
|
||||
Algorithm: "SHA1",
|
||||
},
|
||||
{
|
||||
Issuer: "TestOrgHOTP",
|
||||
Account: "user@test.com",
|
||||
Secret: "JBSWY3DPEHPK3PXP",
|
||||
Type: "hotp",
|
||||
Counter: &counter, // 使用指针
|
||||
Digits: 6,
|
||||
Algorithm: "SHA1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -105,7 +115,7 @@ func TestRecoverHandler(t *testing.T) {
|
|||
if recoverResp.Message != "Tokens recovered successfully" {
|
||||
t.Errorf("Expected message to be 'Tokens recovered successfully', got '%s'\n", recoverResp.Message)
|
||||
}
|
||||
if len(recoverResp.Tokens) != 1 {
|
||||
t.Errorf("Expected 1 token, got %d\n", len(recoverResp.Tokens))
|
||||
if len(recoverResp.Data.Tokens) != 1 {
|
||||
t.Errorf("Expected 1 token, got %d\n", len(recoverResp.Data.Tokens))
|
||||
}
|
||||
}
|
||||
|
|
34
utils/http.go
Normal file
34
utils/http.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// SendErrorResponse 发送错误响应
|
||||
func SendErrorResponse(w http.ResponseWriter, message string, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"success": false,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// SendJSONResponse 发送JSON响应
|
||||
func SendJSONResponse(w http.ResponseWriter, data interface{}, status ...int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// 默认状态码为200 OK
|
||||
statusCode := http.StatusOK
|
||||
if len(status) > 0 {
|
||||
statusCode = status[0]
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
log.Printf("Error encoding response: %v", err)
|
||||
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
57
utils/validation.go
Normal file
57
utils/validation.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ParseInt64 将字符串解析为int64类型
|
||||
func ParseInt64(s string) (int64, error) {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
|
||||
// ValidateToken 验证令牌数据的有效性
|
||||
func ValidateToken(id, issuer, account, secret, tokenType, algo string, counter, period, digits int) error {
|
||||
if id == "" {
|
||||
return errors.New("token ID is required")
|
||||
}
|
||||
if issuer == "" {
|
||||
return errors.New("issuer is required")
|
||||
}
|
||||
if account == "" {
|
||||
return errors.New("account is required")
|
||||
}
|
||||
if secret == "" {
|
||||
return errors.New("secret is required")
|
||||
}
|
||||
|
||||
// 验证令牌类型
|
||||
if tokenType != "totp" && tokenType != "hotp" {
|
||||
return errors.New("invalid token type: must be 'totp' or 'hotp'")
|
||||
}
|
||||
|
||||
// 验证HOTP计数器
|
||||
if tokenType == "hotp" && counter < 0 {
|
||||
return errors.New("counter must be non-negative for HOTP tokens")
|
||||
}
|
||||
|
||||
// 验证周期
|
||||
if period <= 0 {
|
||||
return errors.New("period must be positive")
|
||||
}
|
||||
|
||||
// 验证位数
|
||||
if digits < 6 || digits > 8 {
|
||||
return errors.New("digits must be between 6 and 8")
|
||||
}
|
||||
|
||||
// 验证算法
|
||||
switch algo {
|
||||
case "SHA1", "SHA256", "SHA512":
|
||||
// 有效的算法
|
||||
default:
|
||||
return errors.New("invalid algorithm: must be SHA1, SHA256, or SHA512")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue