// otpm/otp_api.go
// Snapshot: 2025-06-09 11:20:07 +08:00 — 417 lines, 10 KiB, Go.

package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
_ "github.com/lib/pq"
)
// encryptionKey is the AES-256 key used to encrypt OTP secrets before they are
// stored. aes.NewCipher only accepts 16-, 24- or 32-byte keys (AES-128/192/256),
// so this value must be exactly 32 bytes long. Any other length makes every
// encrypt/decrypt call fail with aes.KeySizeError.
//
// TODO: in production, load this key from secure configuration (environment,
// secret manager) instead of compiling it into the binary.
var encryptionKey = []byte("example-key-32-bytes-long!123456") // exactly 32 bytes
// encryptTokenSecret encrypts an OTP secret for storage. It generates a
// random IV of aes.BlockSize bytes and encrypts the secret with AES in CFB
// mode under encryptionKey. It returns base64.StdEncoding(IV || ciphertext).
// Any cipher-construction or randomness error is returned unchanged.
//
// NOTE(review): CFB provides confidentiality only. There is no
// authentication tag, so modified ciphertext is not detected on decryption.
// Moving to an AEAD mode (e.g. AES-GCM) would require migrating the secrets
// that are already stored.
func encryptTokenSecret(secret string) (string, error) {
	plaintext := []byte(secret)

	blk, err := aes.NewCipher(encryptionKey)
	if err != nil {
		return "", err
	}

	// The IV occupies the front of the output buffer. Capacity for the
	// ciphertext is reserved up front so the final append does not reallocate.
	out := make([]byte, aes.BlockSize, aes.BlockSize+len(plaintext))
	if _, err := io.ReadFull(rand.Reader, out); err != nil {
		return "", err
	}

	sealed := make([]byte, len(plaintext))
	cipher.NewCFBEncrypter(blk, out).XORKeyStream(sealed, plaintext)
	out = append(out, sealed...)

	return base64.StdEncoding.EncodeToString(out), nil
}
// decryptTokenSecret reverses encryptTokenSecret. It decodes the standard
// base64 input, treats the first aes.BlockSize bytes as the IV, and
// CFB-decrypts the remainder under encryptionKey.
//
// It returns an error when the input is not valid base64, when the cipher
// cannot be built, or when the decoded data is shorter than one AES block.
//
// NOTE(review): there is no integrity check. A wrong key or tampered data
// yields garbage plaintext rather than an error.
func decryptTokenSecret(encrypted string) (string, error) {
	raw, err := base64.StdEncoding.DecodeString(encrypted)
	if err != nil {
		return "", err
	}

	blk, err := aes.NewCipher(encryptionKey)
	if err != nil {
		return "", err
	}

	if len(raw) < aes.BlockSize {
		return "", fmt.Errorf("ciphertext too short")
	}
	iv, body := raw[:aes.BlockSize], raw[aes.BlockSize:]

	plain := make([]byte, len(body))
	cipher.NewCFBDecrypter(blk, iv).XORKeyStream(plain, body)
	return string(plain), nil
}
// JSON wire types exchanged by the token save/recover HTTP handlers.
type (
	// SaveRequest is the request body of a save (backup) call. It carries the
	// complete token list for one user and a client-supplied timestamp that
	// is stored with every token.
	SaveRequest struct {
		Tokens    []TokenData `json:"tokens"`
		UserID    string      `json:"userId"`
		Timestamp int64       `json:"timestamp"`
	}

	// TokenData describes a single OTP token as exchanged with clients.
	// Secret is plaintext on the wire; the save handler encrypts it before
	// writing it to the database.
	TokenData struct {
		ID      string `json:"id"`
		Issuer  string `json:"issuer"`
		Account string `json:"account"`
		Secret  string `json:"secret"`
		Type    string `json:"type"`
		Counter int    `json:"counter,omitempty"` // left out of the JSON when zero
		Period  int    `json:"period"`
		Digits  int    `json:"digits"`
		Algo    string `json:"algo"`
	}

	// SaveResponse is the body returned after a successful save.
	SaveResponse struct {
		Success bool   `json:"success"`
		Message string `json:"message"`
		Data    struct {
			ID string `json:"id"`
		} `json:"data"`
	}

	// RecoverResponse is the body returned by a recover call: the stored
	// tokens together with a save timestamp.
	RecoverResponse struct {
		Success bool   `json:"success"`
		Message string `json:"message"`
		Data    struct {
			Tokens    []TokenData `json:"tokens"`
			Timestamp int64       `json:"timestamp"`
		} `json:"data"`
	}
)
// db is the PostgreSQL connection pool used by the HTTP handlers.
// InitDB assigns it once a connection has been verified.
var db *sql.DB

// InitDB opens the PostgreSQL connection pool and verifies it with a ping.
//
// On success the pool is stored in db and nil is returned. On failure db is
// left unchanged and any pool that was opened is closed. The returned error
// wraps the underlying cause with %w, so callers can inspect it with
// errors.Is / errors.As.
//
// NOTE(review): the connection string, including credentials, is hard-coded
// and TLS is disabled (sslmode=disable). Load it from configuration before
// deploying anywhere other than a local development machine.
func InitDB() error {
	const connStr = "postgres://postgres:postgres@localhost/otp_db?sslmode=disable"

	conn, err := sql.Open("postgres", connStr)
	if err != nil {
		return fmt.Errorf("error opening database: %w", err)
	}
	// sql.Open does not connect. Ping forces a real connection so that
	// configuration problems surface here rather than on the first request.
	if err := conn.Ping(); err != nil {
		_ = conn.Close() // best effort; the ping error is the one worth reporting
		return fmt.Errorf("error connecting to the database: %w", err)
	}
	db = conn
	return nil
}
// SaveHandler handles POST requests that back up a user's OTP tokens.
//
// The request body is a JSON SaveRequest. Inside a single transaction, the
// handler deletes every token currently stored for req.UserID and inserts the
// submitted tokens, encrypting each secret with encryptTokenSecret. The
// submitted list therefore replaces the previous backup entirely. On success
// it responds with a SaveResponse whose data.id is the user ID.
// Malformed input gets a 400; database and encryption failures get a 500.
//
// NOTE(review): the caller's identity is never verified. Any client that
// knows a userId can overwrite that user's backup, and with the wildcard CORS
// origin any web page can do so. Add authentication before production use.
func SaveHandler(w http.ResponseWriter, r *http.Request) {
	// Generous cap for a token backup. It stops a client from making the
	// JSON decoder buffer an unbounded request body.
	const maxRequestBodyBytes = 1 << 20

	// CORS headers are set on every response, including preflight replies.
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type")

	// A CORS preflight request needs nothing beyond the headers above.
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)
	var req SaveRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		log.Printf("Error decoding request: %v", err)
		sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if req.UserID == "" {
		sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
		return
	}
	if len(req.Tokens) == 0 {
		sendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
		return
	}

	// Bind the transaction to the request context. If the client goes away
	// before Commit, database/sql rolls the transaction back.
	ctx := r.Context()
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		log.Printf("Error starting transaction: %v", err)
		sendErrorResponse(w, "Database error", http.StatusInternalServerError)
		return
	}
	// Discards partial work on every early return. After a successful Commit
	// it is a no-op, and its sql.ErrTxDone result is intentionally ignored.
	defer tx.Rollback()

	// The submitted list replaces the user's previous backup wholesale.
	if _, err := tx.ExecContext(ctx, "DELETE FROM tokens WHERE user_id = $1", req.UserID); err != nil {
		log.Printf("Error deleting existing tokens: %v", err)
		sendErrorResponse(w, "Database error", http.StatusInternalServerError)
		return
	}

	stmt, err := tx.PrepareContext(ctx, `
		INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
	`)
	if err != nil {
		log.Printf("Error preparing statement: %v", err)
		sendErrorResponse(w, "Database error", http.StatusInternalServerError)
		return
	}
	defer stmt.Close()

	for _, token := range req.Tokens {
		// Secrets are encrypted before they reach the database.
		encryptedSecret, err := encryptTokenSecret(token.Secret)
		if err != nil {
			log.Printf("Error encrypting token secret: %v", err)
			sendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
			return
		}
		if _, err := stmt.ExecContext(ctx,
			token.ID,
			req.UserID,
			token.Issuer,
			token.Account,
			encryptedSecret,
			token.Type,
			token.Counter,
			token.Period,
			token.Digits,
			token.Algo,
			req.Timestamp,
		); err != nil {
			log.Printf("Error inserting token: %v", err)
			sendErrorResponse(w, "Database error", http.StatusInternalServerError)
			return
		}
	}

	if err := tx.Commit(); err != nil {
		log.Printf("Error committing transaction: %v", err)
		sendErrorResponse(w, "Database error", http.StatusInternalServerError)
		return
	}

	resp := SaveResponse{
		Success: true,
		Message: "Tokens saved successfully",
	}
	resp.Data.ID = req.UserID
	sendJSONResponse(w, resp, http.StatusOK)
}
// RecoverHandler handles POST requests that restore a user's backed-up tokens.
//
// The request body is {"userId": "..."}. The response is a RecoverResponse
// that carries every token stored for that user, with secrets decrypted via
// decryptTokenSecret, plus the newest timestamp among the stored rows. If the
// user has no stored tokens the call still succeeds; "tokens" is then encoded
// as null and "timestamp" as 0. Malformed input gets a 400; database and
// decryption failures get a 500.
//
// NOTE(review): the caller's identity is never verified. Anyone who knows a
// userId can read that user's decrypted OTP secrets, from any web origin
// given the wildcard CORS header. Add authentication before production use.
func RecoverHandler(w http.ResponseWriter, r *http.Request) {
	// The payload is a single user ID; cap the body so the JSON decoder
	// cannot be made to buffer unbounded input.
	const maxRequestBodyBytes = 1 << 20

	// CORS headers are set on every response, including preflight replies.
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type")

	// A CORS preflight request needs nothing beyond the headers above.
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)
	var req struct {
		UserID string `json:"userId"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		log.Printf("Error decoding request: %v", err)
		sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if req.UserID == "" {
		sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
		return
	}

	// Tie the query to the request so it is cancelled if the client disconnects.
	rows, err := db.QueryContext(r.Context(), `
		SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp
		FROM tokens
		WHERE user_id = $1
		ORDER BY timestamp DESC
	`, req.UserID)
	if err != nil {
		log.Printf("Error querying database: %v", err)
		sendErrorResponse(w, "Database error", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	var tokens []TokenData
	var latest int64
	for rows.Next() {
		var (
			token           TokenData
			encryptedSecret string
			ts              int64
		)
		if err := rows.Scan(
			&token.ID,
			&token.Issuer,
			&token.Account,
			&encryptedSecret,
			&token.Type,
			&token.Counter,
			&token.Period,
			&token.Digits,
			&token.Algo,
			&ts,
		); err != nil {
			log.Printf("Error scanning row: %v", err)
			sendErrorResponse(w, "Database error", http.StatusInternalServerError)
			return
		}

		// Report the newest save time. Previously this took the last row
		// scanned, which under ORDER BY ... DESC is the oldest. Comparing
		// explicitly keeps the result independent of row order.
		if len(tokens) == 0 || ts > latest {
			latest = ts
		}

		secret, err := decryptTokenSecret(encryptedSecret)
		if err != nil {
			log.Printf("Error decrypting token secret: %v", err)
			sendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
			return
		}
		token.Secret = secret
		tokens = append(tokens, token)
	}
	if err := rows.Err(); err != nil {
		log.Printf("Error iterating rows: %v", err)
		sendErrorResponse(w, "Database error", http.StatusInternalServerError)
		return
	}

	resp := RecoverResponse{
		Success: true,
		Message: "Tokens recovered successfully",
	}
	resp.Data.Tokens = tokens
	resp.Data.Timestamp = latest
	sendJSONResponse(w, resp, http.StatusOK)
}
// sendErrorResponse writes a JSON error body of the form
// {"message": <message>, "success": false} with the given HTTP status code
// and a Content-Type of application/json.
//
// If encoding or writing the body fails, the failure is logged. The status
// line has already been sent by then, so nothing more can be reported to the
// client.
func sendErrorResponse(w http.ResponseWriter, message string, status int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(map[string]interface{}{
		"success": false,
		"message": message,
	}); err != nil {
		log.Printf("Error encoding error response: %v", err)
	}
}
// DeleteTokenHandler handles POST requests that delete one stored token.
//
// The request body is {"userId": "...", "tokenId": "..."}. A row is removed
// only when both the user ID and the token ID match. If no such row exists
// the handler responds 404 "Token not found". Malformed input gets a 400 and
// database failures get a 500.
//
// NOTE(review): the caller's identity is never verified. Anyone who knows a
// userId and tokenId can delete that token, from any web origin given the
// wildcard CORS header. Add authentication before production use.
func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
	// The payload is two IDs; cap the body so the JSON decoder cannot be made
	// to buffer unbounded input.
	const maxRequestBodyBytes = 1 << 20

	// CORS headers are set on every response, including preflight replies.
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type")

	// A CORS preflight request needs nothing beyond the headers above.
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)
	var req struct {
		UserID  string `json:"userId"`
		TokenID string `json:"tokenId"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		log.Printf("Error decoding request: %v", err)
		sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if req.UserID == "" || req.TokenID == "" {
		sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
		return
	}

	// Tie the statement to the request so it is cancelled if the client disconnects.
	result, err := db.ExecContext(r.Context(),
		"DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
	if err != nil {
		log.Printf("Error deleting token: %v", err)
		sendErrorResponse(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Zero affected rows means nothing matched this (user, token) pair.
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		log.Printf("Error getting rows affected: %v", err)
		sendErrorResponse(w, "Database error", http.StatusInternalServerError)
		return
	}
	if rowsAffected == 0 {
		sendErrorResponse(w, "Token not found", http.StatusNotFound)
		return
	}

	sendJSONResponse(w, map[string]interface{}{
		"success": true,
		"message": "Token deleted successfully",
	}, http.StatusOK)
}
// sendJSONResponse serializes data as JSON and writes it with the given HTTP
// status code and a Content-Type of application/json. The body is
// newline-terminated, matching json.Encoder output.
//
// The payload is marshalled before anything is written. If encoding fails,
// a clean 500 error response is sent instead. Previously a success status
// was already on the wire, so the error body was appended to it and a
// second, ignored WriteHeader call was made. Write failures are logged.
func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) {
	body, err := json.Marshal(data)
	if err != nil {
		log.Printf("Error encoding response: %v", err)
		sendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	// json.Encoder terminates each value with '\n'; keep the same framing.
	body = append(body, '\n')

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if _, err := w.Write(body); err != nil {
		log.Printf("Error writing response: %v", err)
	}
}