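// Package main implements HTTP handlers that save, recover, and delete a
// user's OTP tokens in PostgreSQL, storing each token secret AES-encrypted.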
package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"

	_ "github.com/lib/pq"
)

// encryptionKey is the AES-256 key used to encrypt token secrets at rest.
// aes.NewCipher only accepts 16-, 24- or 32-byte keys; AES-256 needs exactly 32.
//
// NOTE(review): hard-coded for illustration only. A real deployment should
// load this from a secret store or configuration and keep it out of source.
var encryptionKey = []byte("example-key-32-bytes-long!123456")

// newTokenSecretAEAD returns the AES-GCM AEAD, keyed with encryptionKey, that
// encryptTokenSecret and decryptTokenSecret share.
func newTokenSecretAEAD() (cipher.AEAD, error) {
	block, err := aes.NewCipher(encryptionKey)
	if err != nil {
		return nil, fmt.Errorf("creating AES cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("creating GCM: %w", err)
	}
	return aead, nil
}

// encryptTokenSecret encrypts secret with AES-256-GCM and returns the standard
// base64 encoding of nonce || ciphertext || tag.
//
// A fresh random nonce is drawn on every call, so encrypting the same secret
// twice yields different output. GCM authenticates the data, so any later
// modification of the stored value is rejected by decryptTokenSecret.
func encryptTokenSecret(secret string) (string, error) {
	aead, err := newTokenSecretAEAD()
	if err != nil {
		return "", err
	}

	// Reserve room for nonce, ciphertext and tag so Seal appends without reallocating.
	nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+len(secret)+aead.Overhead())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", fmt.Errorf("generating nonce: %w", err)
	}

	// Seal appends ciphertext||tag after the nonce bytes held in dst.
	sealed := aead.Seal(nonce, nonce, []byte(secret), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}

// decryptTokenSecret reverses encryptTokenSecret: it base64-decodes encrypted,
// splits off the leading nonce, then authenticates and decrypts the rest.
//
// It returns an error when encrypted is not valid standard base64, is too short
// to contain a nonce and tag, or fails GCM authentication (wrong key or
// tampered data).
func decryptTokenSecret(encrypted string) (string, error) {
	data, err := base64.StdEncoding.DecodeString(encrypted)
	if err != nil {
		return "", fmt.Errorf("decoding token secret: %w", err)
	}

	aead, err := newTokenSecretAEAD()
	if err != nil {
		return "", err
	}

	nonceSize := aead.NonceSize()
	if len(data) < nonceSize+aead.Overhead() {
		return "", fmt.Errorf("ciphertext too short")
	}

	plaintext, err := aead.Open(nil, data[:nonceSize], data[nonceSize:], nil)
	if err != nil {
		return "", fmt.Errorf("decrypting token secret: %w", err)
	}
	return string(plaintext), nil
}

// JSON wire types for the token API. The field order and JSON tags below
// determine the encoded form, so both must stay as they are.
type (
	// SaveRequest is the JSON body decoded by SaveHandler: the full set of
	// tokens to store for UserID, stamped with Timestamp.
	SaveRequest struct {
		Tokens    []TokenData `json:"tokens"`
		UserID    string      `json:"userId"`
		Timestamp int64       `json:"timestamp"`
	}

	// TokenData is one token entry as exchanged with clients; its fields map
	// onto the columns of the tokens table. Secret is plaintext here (it is
	// encrypted before storage). Counter is omitted from JSON when zero.
	TokenData struct {
		ID      string `json:"id"`
		Issuer  string `json:"issuer"`
		Account string `json:"account"`
		Secret  string `json:"secret"`
		Type    string `json:"type"`
		Counter int    `json:"counter,omitempty"`
		Period  int    `json:"period"`
		Digits  int    `json:"digits"`
		Algo    string `json:"algo"`
	}

	// SaveResponse is the JSON body SaveHandler returns on success; Data.ID
	// echoes the user ID whose tokens were saved.
	SaveResponse struct {
		Success bool   `json:"success"`
		Message string `json:"message"`
		Data    struct {
			ID string `json:"id"`
		} `json:"data"`
	}

	// RecoverResponse is the JSON body RecoverHandler returns on success,
	// carrying the recovered tokens and their timestamp.
	RecoverResponse struct {
		Success bool   `json:"success"`
		Message string `json:"message"`
		Data    struct {
			Tokens    []TokenData `json:"tokens"`
			Timestamp int64       `json:"timestamp"`
		} `json:"data"`
	}
)

// db is the PostgreSQL connection pool shared by the HTTP handlers; it is
// assigned by InitDB.
var db *sql.DB

// InitDB opens the PostgreSQL connection pool (driver "postgres", registered by
// the github.com/lib/pq import) into db and verifies it with a ping.
//
// The connection string is taken from the DATABASE_URL environment variable
// when it is set and non-empty; otherwise it falls back to a local development
// database (postgres:postgres@localhost/otp_db with sslmode=disable). If the
// ping fails, db still holds the opened pool. Returned errors wrap the
// underlying driver error, so callers can use errors.Is / errors.As.
//
// NOTE(review): the fallback embeds credentials and disables TLS; it is only
// suitable for local development.
func InitDB() error {
	const defaultConnStr = "postgres://postgres:postgres@localhost/otp_db?sslmode=disable"

	connStr := os.Getenv("DATABASE_URL")
	if connStr == "" {
		connStr = defaultConnStr
	}

	var err error
	db, err = sql.Open("postgres", connStr)
	if err != nil {
		return fmt.Errorf("error opening database: %w", err)
	}

	// sql.Open may only validate its arguments; Ping forces a real connection.
	if err = db.Ping(); err != nil {
		return fmt.Errorf("error connecting to the database: %w", err)
	}

	return nil
}

// SaveHandler 保存token的接口处理函数
|
|
func SaveHandler(w http.ResponseWriter, r *http.Request) {
|
|
// 设置CORS头
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
|
|
|
// 处理OPTIONS请求
|
|
if r.Method == "OPTIONS" {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
// 检查请求方法
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// 解析请求
|
|
var req SaveRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
log.Printf("Error decoding request: %v", err)
|
|
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// 验证请求数据
|
|
if req.UserID == "" {
|
|
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if len(req.Tokens) == 0 {
|
|
sendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// 开始数据库事务
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
log.Printf("Error starting transaction: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// 删除用户现有的tokens
|
|
_, err = tx.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID)
|
|
if err != nil {
|
|
log.Printf("Error deleting existing tokens: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 插入新的tokens
|
|
stmt, err := tx.Prepare(`
|
|
INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
|
`)
|
|
if err != nil {
|
|
log.Printf("Error preparing statement: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer stmt.Close()
|
|
|
|
for _, token := range req.Tokens {
|
|
// 加密secret
|
|
encryptedSecret, err := encryptTokenSecret(token.Secret)
|
|
if err != nil {
|
|
log.Printf("Error encrypting token secret: %v", err)
|
|
sendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
_, err = stmt.Exec(
|
|
token.ID,
|
|
req.UserID,
|
|
token.Issuer,
|
|
token.Account,
|
|
encryptedSecret,
|
|
token.Type,
|
|
token.Counter,
|
|
token.Period,
|
|
token.Digits,
|
|
token.Algo,
|
|
req.Timestamp,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Error inserting token: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// 提交事务
|
|
if err = tx.Commit(); err != nil {
|
|
log.Printf("Error committing transaction: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 返回成功响应
|
|
resp := SaveResponse{
|
|
Success: true,
|
|
Message: "Tokens saved successfully",
|
|
}
|
|
resp.Data.ID = req.UserID
|
|
|
|
sendJSONResponse(w, resp, http.StatusOK)
|
|
}
|
|
|
|
// RecoverHandler 恢复token的接口处理函数
|
|
func RecoverHandler(w http.ResponseWriter, r *http.Request) {
|
|
// 设置CORS头
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
|
|
|
// 处理OPTIONS请求
|
|
if r.Method == "OPTIONS" {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
// 检查请求方法
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// 解析请求体
|
|
var req struct {
|
|
UserID string `json:"userId"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
log.Printf("Error decoding request: %v", err)
|
|
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// 验证用户ID
|
|
if req.UserID == "" {
|
|
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// 查询数据库
|
|
rows, err := db.Query(`
|
|
SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp
|
|
FROM tokens
|
|
WHERE user_id = $1
|
|
ORDER BY timestamp DESC
|
|
`, req.UserID)
|
|
if err != nil {
|
|
log.Printf("Error querying database: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
// 读取查询结果
|
|
var tokens []TokenData
|
|
var timestamp int64
|
|
for rows.Next() {
|
|
var token TokenData
|
|
var encryptedSecret string
|
|
err := rows.Scan(
|
|
&token.ID,
|
|
&token.Issuer,
|
|
&token.Account,
|
|
&encryptedSecret,
|
|
&token.Type,
|
|
&token.Counter,
|
|
&token.Period,
|
|
&token.Digits,
|
|
&token.Algo,
|
|
×tamp,
|
|
)
|
|
if err != nil {
|
|
log.Printf("Error scanning row: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 解密secret
|
|
token.Secret, err = decryptTokenSecret(encryptedSecret)
|
|
if err != nil {
|
|
log.Printf("Error decrypting token secret: %v", err)
|
|
sendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
tokens = append(tokens, token)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
log.Printf("Error iterating rows: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 返回响应
|
|
resp := RecoverResponse{
|
|
Success: true,
|
|
Message: "Tokens recovered successfully",
|
|
}
|
|
resp.Data.Tokens = tokens
|
|
resp.Data.Timestamp = timestamp
|
|
|
|
sendJSONResponse(w, resp, http.StatusOK)
|
|
}
|
|
|
|
// sendErrorResponse writes a JSON error body with "success": false and the
// given message, using the given HTTP status code and Content-Type
// application/json.
//
// The status line has already been sent by the time the body is encoded, so
// an encoding or write failure can only be logged.
func sendErrorResponse(w http.ResponseWriter, message string, status int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	body := map[string]interface{}{
		"success": false,
		"message": message,
	}
	if err := json.NewEncoder(w).Encode(body); err != nil {
		log.Printf("Error encoding error response: %v", err)
	}
}

// DeleteTokenHandler 删除单个token的接口处理函数
|
|
func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
|
|
// 设置CORS头
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
|
|
|
// 处理OPTIONS请求
|
|
if r.Method == "OPTIONS" {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
// 检查请求方法
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// 解析请求
|
|
var req struct {
|
|
UserID string `json:"userId"`
|
|
TokenID string `json:"tokenId"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
log.Printf("Error decoding request: %v", err)
|
|
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// 验证请求数据
|
|
if req.UserID == "" || req.TokenID == "" {
|
|
sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// 执行删除操作
|
|
result, err := db.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
|
|
if err != nil {
|
|
log.Printf("Error deleting token: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 检查是否真的删除了记录
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
log.Printf("Error getting rows affected: %v", err)
|
|
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
sendErrorResponse(w, "Token not found", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// 返回成功响应
|
|
sendJSONResponse(w, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Token deleted successfully",
|
|
}, http.StatusOK)
|
|
}
|
|
|
|
// sendJSONResponse 发送JSON响应
|
|
func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
if err := json.NewEncoder(w).Encode(data); err != nil {
|
|
log.Printf("Error encoding response: %v", err)
|
|
sendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
|
}
|
|
}
|