add branch v1
This commit is contained in:
parent
5d370e1077
commit
01b8951dd5
53 changed files with 1079 additions and 6481 deletions
417
otp_api.go
Normal file
417
otp_api.go
Normal file
|
@ -0,0 +1,417 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// 加密密钥(32字节AES-256)
|
||||
var encryptionKey = []byte("example-key-32-bytes-long!1234") // 实际应用中应从安全配置获取
|
||||
|
||||
// encryptTokenSecret 加密令牌密钥
|
||||
func encryptTokenSecret(secret string) (string, error) {
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := make([]byte, aes.BlockSize+len(secret))
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
stream := cipher.NewCFBEncrypter(block, iv)
|
||||
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(secret))
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// decryptTokenSecret 解密令牌密钥
|
||||
func decryptTokenSecret(encrypted string) (string, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(ciphertext) < aes.BlockSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
ciphertext = ciphertext[aes.BlockSize:]
|
||||
|
||||
stream := cipher.NewCFBDecrypter(block, iv)
|
||||
stream.XORKeyStream(ciphertext, ciphertext)
|
||||
|
||||
return string(ciphertext), nil
|
||||
}
|
||||
|
||||
// SaveRequest 保存请求的数据结构
|
||||
type SaveRequest struct {
|
||||
Tokens []TokenData `json:"tokens"`
|
||||
UserID string `json:"userId"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// TokenData token数据结构
|
||||
type TokenData struct {
|
||||
ID string `json:"id"`
|
||||
Issuer string `json:"issuer"`
|
||||
Account string `json:"account"`
|
||||
Secret string `json:"secret"`
|
||||
Type string `json:"type"`
|
||||
Counter int `json:"counter,omitempty"`
|
||||
Period int `json:"period"`
|
||||
Digits int `json:"digits"`
|
||||
Algo string `json:"algo"`
|
||||
}
|
||||
|
||||
// SaveResponse 保存响应的数据结构
|
||||
type SaveResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// RecoverResponse 恢复响应的数据结构
|
||||
type RecoverResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
Tokens []TokenData `json:"tokens"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
// InitDB 初始化数据库连接
|
||||
func InitDB() error {
|
||||
connStr := "postgres://postgres:postgres@localhost/otp_db?sslmode=disable"
|
||||
var err error
|
||||
db, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening database: %v", err)
|
||||
}
|
||||
|
||||
if err = db.Ping(); err != nil {
|
||||
return fmt.Errorf("error connecting to the database: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveHandler 保存token的接口处理函数
|
||||
func SaveHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 设置CORS头
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
// 处理OPTIONS请求
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查请求方法
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var req SaveRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
log.Printf("Error decoding request: %v", err)
|
||||
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求数据
|
||||
if req.UserID == "" {
|
||||
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Tokens) == 0 {
|
||||
sendErrorResponse(w, "No tokens provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 开始数据库事务
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
log.Printf("Error starting transaction: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 删除用户现有的tokens
|
||||
_, err = tx.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("Error deleting existing tokens: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 插入新的tokens
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
`)
|
||||
if err != nil {
|
||||
log.Printf("Error preparing statement: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, token := range req.Tokens {
|
||||
// 加密secret
|
||||
encryptedSecret, err := encryptTokenSecret(token.Secret)
|
||||
if err != nil {
|
||||
log.Printf("Error encrypting token secret: %v", err)
|
||||
sendErrorResponse(w, "Encryption error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(
|
||||
token.ID,
|
||||
req.UserID,
|
||||
token.Issuer,
|
||||
token.Account,
|
||||
encryptedSecret,
|
||||
token.Type,
|
||||
token.Counter,
|
||||
token.Period,
|
||||
token.Digits,
|
||||
token.Algo,
|
||||
req.Timestamp,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Error inserting token: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err = tx.Commit(); err != nil {
|
||||
log.Printf("Error committing transaction: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
resp := SaveResponse{
|
||||
Success: true,
|
||||
Message: "Tokens saved successfully",
|
||||
}
|
||||
resp.Data.ID = req.UserID
|
||||
|
||||
sendJSONResponse(w, resp, http.StatusOK)
|
||||
}
|
||||
|
||||
// RecoverHandler 恢复token的接口处理函数
|
||||
func RecoverHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 设置CORS头
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
// 处理OPTIONS请求
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查请求方法
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req struct {
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
log.Printf("Error decoding request: %v", err)
|
||||
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证用户ID
|
||||
if req.UserID == "" {
|
||||
sendErrorResponse(w, "Missing user ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 查询数据库
|
||||
rows, err := db.Query(`
|
||||
SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp
|
||||
FROM tokens
|
||||
WHERE user_id = $1
|
||||
ORDER BY timestamp DESC
|
||||
`, req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("Error querying database: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 读取查询结果
|
||||
var tokens []TokenData
|
||||
var timestamp int64
|
||||
for rows.Next() {
|
||||
var token TokenData
|
||||
var encryptedSecret string
|
||||
err := rows.Scan(
|
||||
&token.ID,
|
||||
&token.Issuer,
|
||||
&token.Account,
|
||||
&encryptedSecret,
|
||||
&token.Type,
|
||||
&token.Counter,
|
||||
&token.Period,
|
||||
&token.Digits,
|
||||
&token.Algo,
|
||||
×tamp,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Error scanning row: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 解密secret
|
||||
token.Secret, err = decryptTokenSecret(encryptedSecret)
|
||||
if err != nil {
|
||||
log.Printf("Error decrypting token secret: %v", err)
|
||||
sendErrorResponse(w, "Decryption error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
log.Printf("Error iterating rows: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
resp := RecoverResponse{
|
||||
Success: true,
|
||||
Message: "Tokens recovered successfully",
|
||||
}
|
||||
resp.Data.Tokens = tokens
|
||||
resp.Data.Timestamp = timestamp
|
||||
|
||||
sendJSONResponse(w, resp, http.StatusOK)
|
||||
}
|
||||
|
||||
// sendErrorResponse 发送错误响应
|
||||
func sendErrorResponse(w http.ResponseWriter, message string, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"success": false,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteTokenHandler 删除单个token的接口处理函数
|
||||
func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 设置CORS头
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
// 处理OPTIONS请求
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查请求方法
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var req struct {
|
||||
UserID string `json:"userId"`
|
||||
TokenID string `json:"tokenId"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
log.Printf("Error decoding request: %v", err)
|
||||
sendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求数据
|
||||
if req.UserID == "" || req.TokenID == "" {
|
||||
sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 执行删除操作
|
||||
result, err := db.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID)
|
||||
if err != nil {
|
||||
log.Printf("Error deleting token: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否真的删除了记录
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
log.Printf("Error getting rows affected: %v", err)
|
||||
sendErrorResponse(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
sendErrorResponse(w, "Token not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
sendJSONResponse(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Token deleted successfully",
|
||||
}, http.StatusOK)
|
||||
}
|
||||
|
||||
// sendJSONResponse 发送JSON响应
|
||||
func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
log.Printf("Error encoding response: %v", err)
|
||||
sendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue