This commit is contained in:
“xHuPo” 2025-06-17 14:46:09 +08:00
parent 01b8951dd5
commit 10ebc59ffb
17 changed files with 1087 additions and 238 deletions

View file

@ -1,18 +1,80 @@
package main
import (
"embed"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"regexp"
"auth"
"config"
"otpm/auth"
"otpm/config"
"otpm/db"
"github.com/gorilla/mux"
_ "github.com/lib/pq"
)
//go:embed init/sqlite/init.sql
//go:embed init/postgresql/init.sql
var initScripts embed.FS
// 全局数据库连接
var tokenDB *db.DB
// InitDB 初始化数据库连接
func InitDB(dbConfig config.DatabaseConfig) error {
var err error
// 创建数据库配置
dbCfg := db.Config{
Driver: dbConfig.Driver,
}
// 根据驱动类型设置配置
if dbConfig.Driver == "postgres" {
dbCfg.Host = dbConfig.Postgres.Host
dbCfg.Port = dbConfig.Postgres.Port
dbCfg.User = dbConfig.Postgres.User
dbCfg.Password = dbConfig.Postgres.Password
dbCfg.DBName = dbConfig.Postgres.DBName
dbCfg.SSLMode = dbConfig.Postgres.SSLMode
} else if dbConfig.Driver == "sqlite3" {
dbCfg.DBName = dbConfig.SQLite.Path
}
// 使用db包创建数据库连接
tokenDB, err = db.New(dbCfg)
if err != nil {
return fmt.Errorf("error opening database: %v", err)
}
// 读取并执行初始化脚本
var scriptPath string
if dbConfig.Driver == "sqlite3" {
scriptPath = "init/sqlite/init.sql"
} else if dbConfig.Driver == "postgres" {
scriptPath = "init/postgresql/init.sql"
} else {
return fmt.Errorf("unsupported database driver: %s", dbConfig.Driver)
}
script, err := initScripts.ReadFile(scriptPath)
if err != nil {
return fmt.Errorf("error reading init script: %v", err)
}
_, err = tokenDB.Exec(string(script))
if err != nil {
return fmt.Errorf("error executing init script: %v", err)
}
return nil
}
// Server represents the API server
type Server struct {
config *config.Config
@ -33,6 +95,7 @@ func NewServer(cfg *config.Config) *Server {
config: cfg,
router: mux.NewRouter(),
}
s.setupRoutes()
return s
}
@ -41,12 +104,15 @@ func NewServer(cfg *config.Config) *Server {
func (s *Server) setupRoutes() {
// 公开路由
s.router.HandleFunc("/auth/login", s.WechatLoginHandler)
s.router.HandleFunc("/auth/refresh", s.RefreshTokenHandler)
// 受保护路由(需要JWT)
authRouter := s.router.PathPrefix("").Subrouter()
authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))
authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST")
authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST")
authRouter.HandleFunc("/otp/delete", DeleteTokenHandler).Methods("POST")
authRouter.HandleFunc("/otp/clear_all", ClearAllHandler).Methods("POST")
// 添加CORS中间件
s.router.Use(s.corsMiddleware)
@ -56,22 +122,45 @@ func (s *Server) setupRoutes() {
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
referer := r.Header.Get("Referer")
userAgent := r.Header.Get("User-Agent")
// Check if the origin is allowed
allowed := false
for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
if origin == allowedOrigin {
allowed = true
break
// 检查是否是微信小程序请求
isWechatMiniProgram := false
if userAgent != "" && (regexp.MustCompile(`MicroMessenger`).MatchString(userAgent) ||
regexp.MustCompile(`miniProgram`).MatchString(userAgent)) {
isWechatMiniProgram = true
}
// 从Referer中检查是否是微信小程序请求
if referer != "" && regexp.MustCompile(`^https://servicewechat\.com/`).MatchString(referer) {
isWechatMiniProgram = true
}
// 检查Origin是否允许
allowed := isWechatMiniProgram // 如果是微信小程序请求,默认允许
if !allowed && origin != "" {
for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
if origin == allowedOrigin || allowedOrigin == "*" {
allowed = true
break
}
}
}
if allowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
// 如果是微信小程序请求且没有Origin头使用通配符
if isWechatMiniProgram && origin == "" {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
w.Header().Set("Access-Control-Allow-Methods",
joinStrings(s.config.CORS.AllowedMethods))
w.Header().Set("Access-Control-Allow-Headers",
joinStrings(s.config.CORS.AllowedHeaders))
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
if r.Method == "OPTIONS" {
@ -96,13 +185,14 @@ func joinStrings(slice []string) string {
}
func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 读取请求体
body, err := ioutil.ReadAll(r.Body)
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
@ -127,7 +217,7 @@ func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
}
defer resp.Body.Close()
body, err = ioutil.ReadAll(resp.Body)
body, err = io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "Wechat service error", http.StatusInternalServerError)
return
@ -144,25 +234,91 @@ func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
return
}
// 生成JWT token
token, err := s.generateSessionToken(wechatResp.OpenID)
// 生成访问令牌和刷新令牌
accessToken, refreshToken, err := s.generateSessionTokens(wechatResp.OpenID)
if err != nil {
http.Error(w, "Failed to generate token", http.StatusInternalServerError)
http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
return
}
// 返回响应
response := map[string]interface{}{
"token": token,
"openid": wechatResp.OpenID,
"access_token": accessToken,
"refresh_token": refreshToken,
"openid": wechatResp.OpenID,
}
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
}
}
func (s *Server) generateSessionToken(openid string) (string, error) {
return auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
func (s *Server) generateSessionTokens(openid string) (accessToken string, refreshToken string, err error) {
accessToken, err = auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
if err != nil {
return "", "", fmt.Errorf("failed to generate access token: %v", err)
}
refreshToken, err = auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.RefreshTokenExpiry)
if err != nil {
return "", "", fmt.Errorf("failed to generate refresh token: %v", err)
}
return accessToken, refreshToken, nil
}
func (s *Server) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 读取请求体
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
var req struct {
RefreshToken string `json:"refresh_token"`
}
if err := json.Unmarshal(body, &req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
// 验证刷新令牌
claims, err := auth.ValidateToken(req.RefreshToken, s.config.Security.JWTSigningKey)
if err != nil {
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
return
}
// 从刷新令牌中获取 openid
openid := claims.UserID
if openid == "" {
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
return
}
// 生成新的访问令牌和刷新令牌
accessToken, refreshToken, err := s.generateSessionTokens(openid)
if err != nil {
http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
return
}
// 返回响应
response := map[string]interface{}{
"access_token": accessToken,
"refresh_token": refreshToken,
"openid": openid,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
}
}
// Start starts the HTTP server
@ -181,8 +337,12 @@ func (s *Server) Start() error {
}
func main() {
// 定义命令行参数
configPath := flag.String("config", "config", "Path to configuration file (without extension)")
flag.Parse()
// 加载配置
cfg, err := config.LoadConfig("config")
cfg, err := config.LoadConfig(*configPath)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
@ -194,6 +354,13 @@ func main() {
}
log.Println("Database connection established successfully")
// 初始化API
log.Println("Initializing API...")
if err := InitAPI(cfg); err != nil {
log.Fatalf("Failed to initialize API: %v", err)
}
log.Println("API initialized successfully")
// 创建并启动服务器
server := NewServer(cfg)
if err := server.Start(); err != nil {

View file

@ -96,6 +96,28 @@ func GetUserIDFromContext(ctx context.Context) (string, error) {
return userID, nil
}
// ValidateToken validates a JWT token and returns the claims
func ValidateToken(tokenString string, signingKey string) (*Claims, error) {
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
// Validate signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(signingKey), nil
})
if err != nil {
return nil, err
}
if !token.Valid {
return nil, fmt.Errorf("invalid token")
}
return claims, nil
}
// RequireAuth is a middleware that ensures a valid JWT token is present
func RequireAuth(signingKey string, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {

73
auth/refresh.go Normal file
View file

@ -0,0 +1,73 @@
package auth
import (
"encoding/json"
"net/http"
"time"
"github.com/golang-jwt/jwt/v5"
)
// RefreshRequest represents the request body for token refresh
type RefreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
// RefreshResponse represents the response body for token refresh
type RefreshResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// HandleRefresh handles the token refresh request
func HandleRefresh(w http.ResponseWriter, r *http.Request, signingKey string, accessExpiry, refreshExpiry time.Duration) {
// Only accept POST requests
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse request body
var req RefreshRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// Validate refresh token
claims := &Claims{}
token, err := jwt.ParseWithClaims(req.RefreshToken, claims, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(signingKey), nil
})
if err != nil || !token.Valid {
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
return
}
// Generate new access token
accessToken, err := GenerateToken(claims.UserID, signingKey, accessExpiry)
if err != nil {
http.Error(w, "Failed to generate access token", http.StatusInternalServerError)
return
}
// Generate new refresh token
refreshToken, err := GenerateToken(claims.UserID, signingKey, refreshExpiry)
if err != nil {
http.Error(w, "Failed to generate refresh token", http.StatusInternalServerError)
return
}
// Return new tokens
resp := RefreshResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}

View file

@ -18,21 +18,22 @@ database:
# Security Configuration
security:
encryption_key: "your-32-byte-encryption-key-here"
jwt_signing_key: "your-jwt-signing-key-here"
encryption_key: "12345678901234567890123456789012"
jwt_signing_key: "jwt_secret_key_for_authentication_12345"
token_expiry: 24h
refresh_token_expiry: 168h # 7 days
# WeChat Configuration
wechat:
app_id: "YOUR_APPID"
app_secret: "YOUR_APPSECRET"
app_id: "wx57d1033974eb5250"
app_secret: "be494c2a81df685a40b9a74e1736b15d"
# CORS Configuration
cors:
allowed_origins:
- "http://localhost:8080"
- "https://yourdomain.com"
- "https://servicewechat.com/wx57d1033974eb5250"
allowed_methods:
- "GET"
- "POST"

View file

@ -9,54 +9,54 @@ import (
// Config holds all configuration for our application
type Config struct {
Server ServerConfig
Database DatabaseConfig
Security SecurityConfig
CORS CORSConfig
Wechat WechatConfig
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Security SecurityConfig `mapstructure:"security"`
CORS CORSConfig `mapstructure:"cors"`
Wechat WechatConfig `mapstructure:"wechat"`
}
// ServerConfig holds all server related configuration
type ServerConfig struct {
Port int
Timeout time.Duration
Port int `mapstructure:"port"`
Timeout time.Duration `mapstructure:"timeout"`
}
// DatabaseConfig holds all database related configuration
type DatabaseConfig struct {
Driver string
SQLite SQLiteConfig
Postgres PostgresConfig
Driver string `mapstructure:"driver"`
SQLite SQLiteConfig `mapstructure:"sqlite"`
Postgres PostgresConfig `mapstructure:"postgres"`
}
// SQLiteConfig holds SQLite specific configuration
type SQLiteConfig struct {
Path string
Path string `mapstructure:"path"`
}
// PostgresConfig holds PostgreSQL specific configuration
type PostgresConfig struct {
Host string
Port int
User string
Password string
DBName string
SSLMode string
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
}
// SecurityConfig holds all security related configuration
type SecurityConfig struct {
EncryptionKey string
JWTSigningKey string
TokenExpiry time.Duration
RefreshTokenExpiry time.Duration
EncryptionKey string `mapstructure:"encryption_key"`
JWTSigningKey string `mapstructure:"jwt_signing_key"`
TokenExpiry time.Duration `mapstructure:"token_expiry"`
RefreshTokenExpiry time.Duration `mapstructure:"refresh_token_expiry"`
}
// CORSConfig holds CORS related configuration
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowedMethods []string `mapstructure:"allowed_methods"`
AllowedHeaders []string `mapstructure:"allowed_headers"`
}
// WechatConfig holds WeChat related configuration
@ -69,10 +69,16 @@ type WechatConfig struct {
func LoadConfig(configPath string) (*Config, error) {
v := viper.New()
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(configPath)
v.AddConfigPath(".")
// 检查配置路径是否包含扩展名
if len(configPath) > 5 && (configPath[len(configPath)-5:] == ".yaml" || configPath[len(configPath)-4:] == ".yml") {
// 如果包含扩展名,直接使用完整路径
v.SetConfigFile(configPath)
} else {
// 否则按照传统方式处理
v.SetConfigName(configPath)
v.SetConfigType("yaml")
v.AddConfigPath(".")
}
// Read environment variables
v.AutomaticEnv()

44
config/config.yaml Normal file
View file

@ -0,0 +1,44 @@
# Server Configuration
server:
port: 8080
timeout: 30s
# Database Configuration
database:
driver: "sqlite3" # or "postgres"
sqlite:
path: "./data.db"
postgres:
host: "localhost"
port: 5432
user: "postgres"
password: "password"
dbname: "otpdb"
sslmode: "disable"
# Security Configuration
security:
encryption_key: "12345678901234567890123456789012"
jwt_signing_key: "jwt_secret_key_for_authentication_12345"
token_expiry: 24h
refresh_token_expiry: 168h # 7 days
# WeChat Configuration
wechat:
app_id: "wx57d1033974eb5250"
app_secret: "be494c2a81df685a40b9a74e1736b15d"
# CORS Configuration
cors:
allowed_origins:
- "http://localhost:8080"
- "https://yourdomain.com"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "Authorization"
- "Content-Type"

68
crypto/token.go Normal file
View file

@ -0,0 +1,68 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)
// TokenEncryptor 用于加密和解密令牌的结构体
type TokenEncryptor struct {
key []byte
}
// NewTokenEncryptor 创建一个新的TokenEncryptor实例
func NewTokenEncryptor(key []byte) (*TokenEncryptor, error) {
if len(key) != 32 {
return nil, fmt.Errorf("key must be 32 bytes")
}
return &TokenEncryptor{key: key}, nil
}
// EncryptTokenSecret 加密令牌密钥
func (te *TokenEncryptor) EncryptTokenSecret(secret string) (string, error) {
block, err := aes.NewCipher(te.key)
if err != nil {
return "", fmt.Errorf("error creating cipher: %v", err)
}
plaintext := []byte(secret)
ciphertext := make([]byte, aes.BlockSize+len(plaintext))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", fmt.Errorf("error generating IV: %v", err)
}
stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptTokenSecret 解密令牌密钥
func (te *TokenEncryptor) DecryptTokenSecret(encrypted string) (string, error) {
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
if err != nil {
return "", fmt.Errorf("error decoding base64: %v", err)
}
block, err := aes.NewCipher(te.key)
if err != nil {
return "", fmt.Errorf("error creating cipher: %v", err)
}
if len(ciphertext) < aes.BlockSize {
return "", fmt.Errorf("ciphertext too short")
}
iv := ciphertext[:aes.BlockSize]
ciphertext = ciphertext[aes.BlockSize:]
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(ciphertext, ciphertext)
return string(ciphertext), nil
}

52
db/db.go Normal file
View file

@ -0,0 +1,52 @@
package db
import (
"database/sql"
"fmt"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
// Config 数据库配置
type Config struct {
Driver string
Host string
Port int
User string
Password string
DBName string
SSLMode string
}
// DB 封装数据库操作
type DB struct {
*sql.DB
}
// New 创建数据库连接
func New(cfg Config) (*DB, error) {
var dsn string
switch cfg.Driver {
case "postgres":
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode)
case "sqlite3":
dsn = cfg.DBName
default:
return nil, fmt.Errorf("unsupported database driver: %s", cfg.Driver)
}
db, err := sql.Open(cfg.Driver, dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %v", err)
}
// 测试连接
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping database: %v", err)
}
return &DB{db}, nil
}

224
db/token.go Normal file
View file

@ -0,0 +1,224 @@
package db
import (
"database/sql"
"errors"
"time"
)
// 令牌类型常量
const (
TokenTypeHOTP = "hotp"
TokenTypeTOTP = "totp"
)
// Token 表示一个 OTP 令牌
type Token struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Issuer string `json:"issuer"`
Account string `json:"account"`
Secret string `json:"secret"`
Type string `json:"type"`
Counter *int `json:"counter,omitempty"`
Period int `json:"period"`
Digits int `json:"digits"`
Algorithm string `json:"algorithm"`
Timestamp int64 `json:"timestamp"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
var (
ErrTokenNotFound = errors.New("token not found")
ErrInvalidToken = errors.New("invalid token")
)
// SaveTokens 保存用户的令牌列表
func (db *DB) SaveTokens(userID string, tokens []Token) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// 准备插入语句
stmt, err := tx.Prepare(`
INSERT INTO tokens (
id, user_id, issuer, account, secret, type, counter,
period, digits, algorithm, timestamp, created_at, updated_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
) ON CONFLICT (id, user_id) DO UPDATE SET
issuer = EXCLUDED.issuer,
account = EXCLUDED.account,
secret = EXCLUDED.secret,
type = EXCLUDED.type,
counter = EXCLUDED.counter,
period = EXCLUDED.period,
digits = EXCLUDED.digits,
algorithm = EXCLUDED.algorithm,
timestamp = EXCLUDED.timestamp,
updated_at = EXCLUDED.updated_at
`)
if err != nil {
return err
}
defer stmt.Close()
now := time.Now()
for _, token := range tokens {
// 确保algorithm字段有默认值
algorithm := token.Algorithm
if algorithm == "" {
algorithm = "SHA1"
}
_, err = stmt.Exec(
token.ID,
userID,
token.Issuer,
token.Account,
token.Secret,
token.Type,
token.Counter,
token.Period,
token.Digits,
algorithm,
token.Timestamp,
now,
now,
)
if err != nil {
return err
}
}
return tx.Commit()
}
// GetTokensByUserID 获取用户的所有令牌
func (db *DB) GetTokensByUserID(userID string) ([]Token, error) {
rows, err := db.Query(`
SELECT
id, user_id, issuer, account, secret, type, counter,
period, digits, algorithm, timestamp, created_at, updated_at
FROM tokens
WHERE user_id = $1
ORDER BY timestamp DESC
`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var tokens []Token
for rows.Next() {
var token Token
err := rows.Scan(
&token.ID,
&token.UserID,
&token.Issuer,
&token.Account,
&token.Secret,
&token.Type,
&token.Counter,
&token.Period,
&token.Digits,
&token.Algorithm,
&token.Timestamp,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
return nil, err
}
tokens = append(tokens, token)
}
if err = rows.Err(); err != nil {
return nil, err
}
return tokens, nil
}
// DeleteToken 删除指定的令牌
func (db *DB) DeleteToken(userID, tokenID string) error {
result, err := db.Exec(`
DELETE FROM tokens
WHERE user_id = $1 AND id = $2
`, userID, tokenID)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return ErrTokenNotFound
}
return nil
}
// UpdateTokenCounter 更新 HOTP 令牌的计数器
func (db *DB) UpdateTokenCounter(userID, tokenID string, counter int) error {
result, err := db.Exec(`
UPDATE tokens
SET counter = $1, updated_at = $2
WHERE user_id = $3 AND id = $4 AND type = $5
`, counter, time.Now(), userID, tokenID, TokenTypeHOTP)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return ErrTokenNotFound
}
return nil
}
// GetTokenByID 获取指定的令牌
func (db *DB) GetTokenByID(userID, tokenID string) (*Token, error) {
var token Token
err := db.QueryRow(`
SELECT
id, user_id, issuer, account, secret, type, counter,
period, digits, algorithm, timestamp, created_at, updated_at
FROM tokens
WHERE user_id = $1 AND id = $2
`, userID, tokenID).Scan(
&token.ID,
&token.UserID,
&token.Issuer,
&token.Account,
&token.Secret,
&token.Type,
&token.Counter,
&token.Period,
&token.Digits,
&token.Algorithm,
&token.Timestamp,
&token.CreatedAt,
&token.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, ErrTokenNotFound
}
if err != nil {
return nil, err
}
return &token, nil
}

28
go.mod Normal file
View file

@ -0,0 +1,28 @@
module otpm
go 1.21.1
require (
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/gorilla/mux v1.8.1
github.com/lib/pq v1.10.9
github.com/spf13/viper v1.20.1
)
require (
github.com/fsnotify/fsnotify v1.8.0 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.12.0 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

60
go.sum Normal file
View file

@ -0,0 +1,60 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -9,7 +9,7 @@ CREATE TABLE IF NOT EXISTS tokens (
counter INTEGER, -- HOTP计数器可选
period INTEGER NOT NULL, -- TOTP周期
digits INTEGER NOT NULL, -- 验证码位数
algo VARCHAR(10) NOT NULL, -- 使用的哈希算法
algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1', -- 使用的哈希算法
timestamp BIGINT NOT NULL, -- 最后更新时间戳
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
@ -45,7 +45,7 @@ COMMENT ON COLUMN tokens.type IS '令牌类型totp/hotp';
COMMENT ON COLUMN tokens.counter IS 'HOTP计数器可选';
COMMENT ON COLUMN tokens.period IS 'TOTP周期';
COMMENT ON COLUMN tokens.digits IS '验证码位数';
COMMENT ON COLUMN tokens.algo IS '使用的哈希算法';
COMMENT ON COLUMN tokens.algorithm IS '使用的哈希算法';
COMMENT ON COLUMN tokens.timestamp IS '最后更新时间戳';
COMMENT ON COLUMN tokens.created_at IS '创建时间';
COMMENT ON COLUMN tokens.updated_at IS '最后更新时间';

View file

@ -9,22 +9,22 @@ PRAGMA foreign_keys = ON;
-- 创建tokens表
CREATE TABLE IF NOT EXISTS tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
issuer TEXT NOT NULL,
account TEXT NOT NULL,
secret TEXT NOT NULL CHECK (length(secret) >= 16 AND secret REGEXP '^[A-Z2-7]+=*$'),
type TEXT NOT NULL CHECK (type IN ('HOTP', 'TOTP')),
secret TEXT NOT NULL CHECK (length(secret) >= 16),
type TEXT NOT NULL CHECK (type IN ('hotp', 'totp')),
counter INTEGER CHECK (
(type = 'HOTP' AND counter >= 0) OR
(type = 'TOTP' AND counter IS NULL)
(type = 'hotp' AND counter >= 0) OR
(type = 'totp' AND counter IS NULL)
),
period INTEGER DEFAULT 30 CHECK (
(type = 'TOTP' AND period >= 30) OR
(type = 'HOTP' AND period IS NULL)
(type = 'totp' AND period >= 30) OR
(type = 'hotp' AND period IS NULL)
),
digits INTEGER NOT NULL DEFAULT 6 CHECK (digits IN (6, 8)),
algo TEXT NOT NULL DEFAULT 'SHA1' CHECK (algo IN ('SHA1', 'SHA256', 'SHA512')),
algorithm TEXT NOT NULL DEFAULT 'SHA1' CHECK (algorithm IN ('SHA1', 'SHA256', 'SHA512')),
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
updated_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
UNIQUE(user_id, issuer, account)
@ -33,16 +33,16 @@ CREATE TABLE IF NOT EXISTS tokens (
-- 基本索引
CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);
CREATE INDEX IF NOT EXISTS idx_tokens_lookup ON tokens(user_id, issuer, account);
CREATE INDEX IF NOT EXISTS idx_tokens_hotp ON tokens(user_id) WHERE type = 'HOTP';
CREATE INDEX IF NOT EXISTS idx_tokens_totp ON tokens(user_id) WHERE type = 'TOTP';
CREATE INDEX IF NOT EXISTS idx_tokens_hotp ON tokens(user_id) WHERE type = 'hotp';
CREATE INDEX IF NOT EXISTS idx_tokens_totp ON tokens(user_id) WHERE type = 'totp';
-- 简化统计视图
CREATE VIEW IF NOT EXISTS v_token_stats AS
SELECT
user_id,
COUNT(*) as total_tokens,
SUM(type = 'HOTP') as hotp_count,
SUM(type = 'TOTP') as totp_count
SUM(type = 'hotp') as hotp_count,
SUM(type = 'totp') as totp_count
FROM tokens
GROUP BY user_id;

View file

@ -1,65 +1,36 @@
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"otpm/config"
"otpm/crypto"
"otpm/db"
"otpm/utils"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
// 加密密钥(32字节AES-256)
var encryptionKey = []byte("example-key-32-bytes-long!1234") // 实际应用中应从安全配置获取
var (
tokenEncryptor *crypto.TokenEncryptor
)
// encryptTokenSecret 加密令牌密钥
func encryptTokenSecret(secret string) (string, error) {
block, err := aes.NewCipher(encryptionKey)
// InitAPI 初始化API相关配置
func InitAPI(cfg *config.Config) error {
var err error
// 初始化token加密器
tokenEncryptor, err = crypto.NewTokenEncryptor([]byte(cfg.Security.EncryptionKey))
if err != nil {
return "", err
return fmt.Errorf("failed to initialize token encryptor: %v", err)
}
ciphertext := make([]byte, aes.BlockSize+len(secret))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", err
}
stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(secret))
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// decryptTokenSecret 解密令牌密钥
func decryptTokenSecret(encrypted string) (string, error) {
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
if err != nil {
return "", err
}
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
if len(ciphertext) < aes.BlockSize {
return "", fmt.Errorf("ciphertext too short")
}
iv := ciphertext[:aes.BlockSize]
ciphertext = ciphertext[aes.BlockSize:]
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(ciphertext, ciphertext)
return string(ciphertext), nil
return nil
}
// SaveRequest 保存请求的数据结构
@ -71,15 +42,15 @@ type SaveRequest struct {
// TokenData token数据结构
type TokenData struct {
ID string `json:"id"`
Issuer string `json:"issuer"`
Account string `json:"account"`
Secret string `json:"secret"`
Type string `json:"type"`
Counter int `json:"counter,omitempty"`
Period int `json:"period"`
Digits int `json:"digits"`
Algo string `json:"algo"`
ID string `json:"id"`
Issuer string `json:"issuer"`
Account string `json:"account"`
Secret string `json:"secret"`
Type string `json:"type"`
Counter *int `json:"counter,omitempty"`
Period int `json:"period"`
Digits int `json:"digits"`
Algorithm string `json:"algorithm"`
}
// SaveResponse 保存响应的数据结构
@ -101,37 +72,10 @@ type RecoverResponse struct {
} `json:"data"`
}
var db *sql.DB
// InitDB 初始化数据库连接
func InitDB() error {
connStr := "postgres://postgres:postgres@localhost/otp_db?sslmode=disable"
var err error
db, err = sql.Open("postgres", connStr)
if err != nil {
return fmt.Errorf("error opening database: %v", err)
}
if err = db.Ping(); err != nil {
return fmt.Errorf("error connecting to the database: %v", err)
}
return nil
}
// 使用api_server.go中定义的全局数据库连接
// SaveHandler 保存token的接口处理函数
func SaveHandler(w http.ResponseWriter, r *http.Request) {
// 设置CORS头
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
// 处理OPTIONS请求
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 检查请求方法
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@ -142,26 +86,26 @@ func SaveHandler(w http.ResponseWriter, r *http.Request) {
var req SaveRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Printf("Error decoding request: %v", err)
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
utils.SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
return
}
// 验证请求数据
if req.UserID == "" {
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
utils.SendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
return
}
if len(req.Tokens) == 0 {
sendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
utils.SendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
return
}
// 开始数据库事务
tx, err := db.Begin()
tx, err := tokenDB.Begin()
if err != nil {
log.Printf("Error starting transaction: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
defer tx.Rollback()
@ -170,47 +114,68 @@ func SaveHandler(w http.ResponseWriter, r *http.Request) {
_, err = tx.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID)
if err != nil {
log.Printf("Error deleting existing tokens: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 插入新的tokens
stmt, err := tx.Prepare(`
INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algorithm)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`)
if err != nil {
log.Printf("Error preparing statement: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
defer stmt.Close()
for _, token := range req.Tokens {
// 加密secret
encryptedSecret, err := encryptTokenSecret(token.Secret)
encryptedSecret, err := tokenEncryptor.EncryptTokenSecret(token.Secret)
if err != nil {
log.Printf("Error encrypting token secret: %v", err)
sendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
return
}
// 确定令牌类型
tokenType := token.Type
if tokenType != db.TokenTypeHOTP && tokenType != db.TokenTypeTOTP {
log.Printf("Invalid token type: %s", tokenType)
utils.SendErrorResponse(w, "Invalid token type", http.StatusBadRequest)
return
}
// 对于TOTP类型的令牌counter必须为NULL
var counterValue interface{}
if tokenType == db.TokenTypeTOTP {
counterValue = nil
} else {
counterValue = token.Counter
}
// 确保algo字段有默认值
algorithm := token.Algorithm
if algorithm == "" {
algorithm = "SHA1"
}
_, err = stmt.Exec(
token.ID,
req.UserID,
token.Issuer,
token.Account,
encryptedSecret,
token.Type,
token.Counter,
tokenType,
counterValue,
token.Period,
token.Digits,
token.Algo,
req.Timestamp,
algorithm,
)
if err != nil {
log.Printf("Error inserting token: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
}
@ -218,7 +183,7 @@ func SaveHandler(w http.ResponseWriter, r *http.Request) {
// 提交事务
if err = tx.Commit(); err != nil {
log.Printf("Error committing transaction: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
@ -229,22 +194,11 @@ func SaveHandler(w http.ResponseWriter, r *http.Request) {
}
resp.Data.ID = req.UserID
sendJSONResponse(w, resp, http.StatusOK)
utils.SendJSONResponse(w, resp)
}
// RecoverHandler 恢复token的接口处理函数
func RecoverHandler(w http.ResponseWriter, r *http.Request) {
// 设置CORS头
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
// 处理OPTIONS请求
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 检查请求方法
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@ -257,33 +211,33 @@ func RecoverHandler(w http.ResponseWriter, r *http.Request) {
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Printf("Error decoding request: %v", err)
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
utils.SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
return
}
// 验证用户ID
if req.UserID == "" {
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
utils.SendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
return
}
// 查询数据库
rows, err := db.Query(`
SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp
rows, err := tokenDB.Query(`
SELECT id, issuer, account, secret, type, counter, period, digits, algorithm, updated_at
FROM tokens
WHERE user_id = $1
ORDER BY timestamp DESC
ORDER BY updated_at DESC
`, req.UserID)
if err != nil {
log.Printf("Error querying database: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
defer rows.Close()
// 读取查询结果
var tokens []TokenData
var timestamp int64
var updatedAt string
for rows.Next() {
var token TokenData
var encryptedSecret string
@ -296,20 +250,20 @@ func RecoverHandler(w http.ResponseWriter, r *http.Request) {
&token.Counter,
&token.Period,
&token.Digits,
&token.Algo,
&timestamp,
&token.Algorithm,
&updatedAt,
)
if err != nil {
log.Printf("Error scanning row: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 解密secret
token.Secret, err = decryptTokenSecret(encryptedSecret)
token.Secret, err = tokenEncryptor.DecryptTokenSecret(encryptedSecret)
if err != nil {
log.Printf("Error decrypting token secret: %v", err)
sendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
return
}
tokens = append(tokens, token)
@ -317,7 +271,22 @@ func RecoverHandler(w http.ResponseWriter, r *http.Request) {
if err = rows.Err(); err != nil {
log.Printf("Error iterating rows: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 处理无数据情况
if len(tokens) == 0 {
resp := RecoverResponse{
Success: true,
Message: "No tokens found in cloud",
}
resp.Data.Tokens = []TokenData{}
resp.Data.Timestamp = time.Now().Unix()
// 返回404状态码但保持success:true
w.WriteHeader(http.StatusNotFound)
utils.SendJSONResponse(w, resp)
return
}
@ -327,34 +296,78 @@ func RecoverHandler(w http.ResponseWriter, r *http.Request) {
Message: "Tokens recovered successfully",
}
resp.Data.Tokens = tokens
// 将updatedAt转换为时间戳
var timestamp int64
// 首先尝试将updatedAt解析为整数时间戳
if ts, err := utils.ParseInt64(updatedAt); err == nil {
timestamp = ts
} else if t, err := time.Parse(time.RFC3339, updatedAt); err == nil {
// 尝试解析为RFC3339格式
timestamp = t.Unix()
} else if t, err := time.Parse("2006-01-02 15:04:05", updatedAt); err == nil {
// 尝试解析为标准日期时间格式
timestamp = t.Unix()
} else {
// 如果所有解析方法都失败,使用当前时间
timestamp = time.Now().Unix()
log.Printf("Error parsing time string: %v", err)
}
resp.Data.Timestamp = timestamp
sendJSONResponse(w, resp, http.StatusOK)
utils.SendJSONResponse(w, resp)
}
// sendErrorResponse 发送错误响应
func sendErrorResponse(w http.ResponseWriter, message string, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(map[string]interface{}{
"success": false,
"message": message,
// ClearAllHandler 删除用户所有token数据的接口处理函数
func ClearAllHandler(w http.ResponseWriter, r *http.Request) {
// 检查请求方法
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 解析请求
var req struct {
UserID string `json:"userId"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Printf("Error decoding request: %v", err)
utils.SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
return
}
// 验证请求数据
if req.UserID == "" {
utils.SendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
return
}
// 执行删除操作
result, err := tokenDB.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID)
if err != nil {
log.Printf("Error deleting tokens: %v", err)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 检查是否真的删除了记录
rowsAffected, err := result.RowsAffected()
if err != nil {
log.Printf("Error getting rows affected: %v", err)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
// 返回成功响应
utils.SendJSONResponse(w, map[string]interface{}{
"success": true,
"message": fmt.Sprintf("Deleted %d tokens successfully", rowsAffected),
})
}
// DeleteTokenHandler 删除单个token的接口处理函数
func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
// 设置CORS头
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
// 处理OPTIONS请求
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 检查请求方法
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@ -368,21 +381,21 @@ func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
log.Printf("Error decoding request: %v", err)
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
utils.SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
return
}
// 验证请求数据
if req.UserID == "" || req.TokenID == "" {
sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
utils.SendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
return
}
// 执行删除操作
result, err := db.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
result, err := tokenDB.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
if err != nil {
log.Printf("Error deleting token: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
@ -390,28 +403,18 @@ func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
rowsAffected, err := result.RowsAffected()
if err != nil {
log.Printf("Error getting rows affected: %v", err)
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
utils.SendErrorResponse(w, "Database error", http.StatusInternalServerError)
return
}
if rowsAffected == 0 {
sendErrorResponse(w, "Token not found", http.StatusNotFound)
utils.SendErrorResponse(w, "Token not found", http.StatusNotFound)
return
}
// 返回成功响应
sendJSONResponse(w, map[string]interface{}{
utils.SendJSONResponse(w, map[string]interface{}{
"success": true,
"message": "Token deleted successfully",
}, http.StatusOK)
}
// sendJSONResponse 发送JSON响应
func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Printf("Error encoding response: %v", err)
sendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
}
})
}

View file

@ -15,17 +15,27 @@ func TestSaveHandler(t *testing.T) {
defer srv.Close()
// 准备测试数据
counter := 0 // 创建一个int变量
testData := SaveRequest{
UserID: "test_user_123",
Tokens: []TokenData{
{
Issuer: "TestOrg",
Account: "user@test.com",
Secret: "JBSWY3DPEHPK3PXP",
Type: "totp",
Period: 30,
Digits: 6,
Algo: "SHA1",
Issuer: "TestOrg",
Account: "user@test.com",
Secret: "JBSWY3DPEHPK3PXP",
Type: "totp",
Period: 30,
Digits: 6,
Algorithm: "SHA1",
},
{
Issuer: "TestOrgHOTP",
Account: "user@test.com",
Secret: "JBSWY3DPEHPK3PXP",
Type: "hotp",
Counter: &counter, // 使用指针
Digits: 6,
Algorithm: "SHA1",
},
},
}
@ -105,7 +115,7 @@ func TestRecoverHandler(t *testing.T) {
if recoverResp.Message != "Tokens recovered successfully" {
t.Errorf("Expected message to be 'Tokens recovered successfully', got '%s'\n", recoverResp.Message)
}
if len(recoverResp.Tokens) != 1 {
t.Errorf("Expected 1 token, got %d\n", len(recoverResp.Tokens))
if len(recoverResp.Data.Tokens) != 1 {
t.Errorf("Expected 1 token, got %d\n", len(recoverResp.Data.Tokens))
}
}

34
utils/http.go Normal file
View file

@ -0,0 +1,34 @@
package utils
import (
"encoding/json"
"log"
"net/http"
)
// SendErrorResponse 发送错误响应
func SendErrorResponse(w http.ResponseWriter, message string, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(map[string]interface{}{
"success": false,
"message": message,
})
}
// SendJSONResponse 发送JSON响应
func SendJSONResponse(w http.ResponseWriter, data interface{}, status ...int) {
w.Header().Set("Content-Type", "application/json")
// 默认状态码为200 OK
statusCode := http.StatusOK
if len(status) > 0 {
statusCode = status[0]
}
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Printf("Error encoding response: %v", err)
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
}
}

57
utils/validation.go Normal file
View file

@ -0,0 +1,57 @@
package utils
import (
"errors"
"strconv"
)
// ParseInt64 将字符串解析为int64类型
func ParseInt64(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
// ValidateToken 验证令牌数据的有效性
func ValidateToken(id, issuer, account, secret, tokenType, algo string, counter, period, digits int) error {
if id == "" {
return errors.New("token ID is required")
}
if issuer == "" {
return errors.New("issuer is required")
}
if account == "" {
return errors.New("account is required")
}
if secret == "" {
return errors.New("secret is required")
}
// 验证令牌类型
if tokenType != "totp" && tokenType != "hotp" {
return errors.New("invalid token type: must be 'totp' or 'hotp'")
}
// 验证HOTP计数器
if tokenType == "hotp" && counter < 0 {
return errors.New("counter must be non-negative for HOTP tokens")
}
// 验证周期
if period <= 0 {
return errors.New("period must be positive")
}
// 验证位数
if digits < 6 || digits > 8 {
return errors.New("digits must be between 6 and 8")
}
// 验证算法
switch algo {
case "SHA1", "SHA256", "SHA512":
// 有效的算法
default:
return errors.New("invalid algorithm: must be SHA1, SHA256, or SHA512")
}
return nil
}