package main import ( "crypto/aes" "crypto/cipher" "crypto/rand" "database/sql" "encoding/base64" "encoding/json" "fmt" "io" "log" "net/http" _ "github.com/lib/pq" ) // 加密密钥(32字节AES-256) var encryptionKey = []byte("example-key-32-bytes-long!1234") // 实际应用中应从安全配置获取 // encryptTokenSecret 加密令牌密钥 func encryptTokenSecret(secret string) (string, error) { block, err := aes.NewCipher(encryptionKey) if err != nil { return "", err } ciphertext := make([]byte, aes.BlockSize+len(secret)) iv := ciphertext[:aes.BlockSize] if _, err := io.ReadFull(rand.Reader, iv); err != nil { return "", err } stream := cipher.NewCFBEncrypter(block, iv) stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(secret)) return base64.StdEncoding.EncodeToString(ciphertext), nil } // decryptTokenSecret 解密令牌密钥 func decryptTokenSecret(encrypted string) (string, error) { ciphertext, err := base64.StdEncoding.DecodeString(encrypted) if err != nil { return "", err } block, err := aes.NewCipher(encryptionKey) if err != nil { return "", err } if len(ciphertext) < aes.BlockSize { return "", fmt.Errorf("ciphertext too short") } iv := ciphertext[:aes.BlockSize] ciphertext = ciphertext[aes.BlockSize:] stream := cipher.NewCFBDecrypter(block, iv) stream.XORKeyStream(ciphertext, ciphertext) return string(ciphertext), nil } // SaveRequest 保存请求的数据结构 type SaveRequest struct { Tokens []TokenData `json:"tokens"` UserID string `json:"userId"` Timestamp int64 `json:"timestamp"` } // TokenData token数据结构 type TokenData struct { ID string `json:"id"` Issuer string `json:"issuer"` Account string `json:"account"` Secret string `json:"secret"` Type string `json:"type"` Counter int `json:"counter,omitempty"` Period int `json:"period"` Digits int `json:"digits"` Algo string `json:"algo"` } // SaveResponse 保存响应的数据结构 type SaveResponse struct { Success bool `json:"success"` Message string `json:"message"` Data struct { ID string `json:"id"` } `json:"data"` } // RecoverResponse 恢复响应的数据结构 type RecoverResponse struct { Success bool `json:"success"` Message string `json:"message"` Data struct { Tokens []TokenData `json:"tokens"` Timestamp int64 `json:"timestamp"` } `json:"data"` } var db *sql.DB // InitDB 初始化数据库连接 func InitDB() error { connStr := "postgres://postgres:postgres@localhost/otp_db?sslmode=disable" var err error db, err = sql.Open("postgres", connStr) if err != nil { return fmt.Errorf("error opening database: %v", err) } if err = db.Ping(); err != nil { return fmt.Errorf("error connecting to the database: %v", err) } return nil } // SaveHandler 保存token的接口处理函数 func SaveHandler(w http.ResponseWriter, r *http.Request) { // 设置CORS头 w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type") // 处理OPTIONS请求 if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } // 检查请求方法 if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 解析请求 var req SaveRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { log.Printf("Error decoding request: %v", err) sendErrorResponse(w, "Invalid request body", http.StatusBadRequest) return } // 验证请求数据 if req.UserID == "" { sendErrorResponse(w, "Missing user ID", http.StatusBadRequest) return } if len(req.Tokens) == 0 { sendErrorResponse(w, "No tokens provided", http.StatusBadRequest) return } // 开始数据库事务 tx, err := db.Begin() if err != nil { log.Printf("Error starting transaction: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } defer tx.Rollback() // 删除用户现有的tokens _, err = tx.Exec("DELETE FROM tokens WHERE user_id = $1", req.UserID) if err != nil { log.Printf("Error deleting existing tokens: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } // 插入新的tokens stmt, err := tx.Prepare(` INSERT INTO tokens (id, user_id, issuer, account, secret, type, counter, period, digits, algo, timestamp) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) `) if err != nil { log.Printf("Error preparing statement: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } defer stmt.Close() for _, token := range req.Tokens { // 加密secret encryptedSecret, err := encryptTokenSecret(token.Secret) if err != nil { log.Printf("Error encrypting token secret: %v", err) sendErrorResponse(w, "Encryption error", http.StatusInternalServerError) return } _, err = stmt.Exec( token.ID, req.UserID, token.Issuer, token.Account, encryptedSecret, token.Type, token.Counter, token.Period, token.Digits, token.Algo, req.Timestamp, ) if err != nil { log.Printf("Error inserting token: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } } // 提交事务 if err = tx.Commit(); err != nil { log.Printf("Error committing transaction: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } // 返回成功响应 resp := SaveResponse{ Success: true, Message: "Tokens saved successfully", } resp.Data.ID = req.UserID sendJSONResponse(w, resp, http.StatusOK) } // RecoverHandler 恢复token的接口处理函数 func RecoverHandler(w http.ResponseWriter, r *http.Request) { // 设置CORS头 w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type") // 处理OPTIONS请求 if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } // 检查请求方法 if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 解析请求体 var req struct { UserID string `json:"userId"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { log.Printf("Error decoding request: %v", err) sendErrorResponse(w, "Invalid request body", http.StatusBadRequest) return } // 验证用户ID if req.UserID == "" { sendErrorResponse(w, "Missing user ID", http.StatusBadRequest) return } // 查询数据库 rows, err := db.Query(` SELECT id, issuer, account, secret, type, counter, period, digits, algo, timestamp FROM tokens WHERE user_id = $1 ORDER BY timestamp DESC `, req.UserID) if err != nil { log.Printf("Error querying database: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } defer rows.Close() // 读取查询结果 var tokens []TokenData var timestamp int64 for rows.Next() { var token TokenData var encryptedSecret string err := rows.Scan( &token.ID, &token.Issuer, &token.Account, &encryptedSecret, &token.Type, &token.Counter, &token.Period, &token.Digits, &token.Algo, ×tamp, ) if err != nil { log.Printf("Error scanning row: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } // 解密secret token.Secret, err = decryptTokenSecret(encryptedSecret) if err != nil { log.Printf("Error decrypting token secret: %v", err) sendErrorResponse(w, "Decryption error", http.StatusInternalServerError) return } tokens = append(tokens, token) } if err = rows.Err(); err != nil { log.Printf("Error iterating rows: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } // 返回响应 resp := RecoverResponse{ Success: true, Message: "Tokens recovered successfully", } resp.Data.Tokens = tokens resp.Data.Timestamp = timestamp sendJSONResponse(w, resp, http.StatusOK) } // sendErrorResponse 发送错误响应 func sendErrorResponse(w http.ResponseWriter, message string, status int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(map[string]interface{}{ "success": false, "message": message, }) } // DeleteTokenHandler 删除单个token的接口处理函数 func DeleteTokenHandler(w http.ResponseWriter, r *http.Request) { // 设置CORS头 w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type") // 处理OPTIONS请求 if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } // 检查请求方法 if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 解析请求 var req struct { UserID string `json:"userId"` TokenID string `json:"tokenId"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { log.Printf("Error decoding request: %v", err) sendErrorResponse(w, "Invalid request body", http.StatusBadRequest) return } // 验证请求数据 if req.UserID == "" || req.TokenID == "" { sendErrorResponse(w, "Missing user ID or token ID", http.StatusBadRequest) return } // 执行删除操作 result, err := db.Exec("DELETE FROM tokens WHERE user_id = $1 AND id = $2", req.UserID, req.TokenID) if err != nil { log.Printf("Error deleting token: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } // 检查是否真的删除了记录 rowsAffected, err := result.RowsAffected() if err != nil { log.Printf("Error getting rows affected: %v", err) sendErrorResponse(w, "Database error", http.StatusInternalServerError) return } if rowsAffected == 0 { sendErrorResponse(w, "Token not found", http.StatusNotFound) return } // 返回成功响应 sendJSONResponse(w, map[string]interface{}{ "success": true, "message": "Token deleted successfully", }, http.StatusOK) } // sendJSONResponse 发送JSON响应 func sendJSONResponse(w http.ResponseWriter, data interface{}, status int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(data); err != nil { log.Printf("Error encoding response: %v", err) sendErrorResponse(w, "Internal server error", http.StatusInternalServerError) } }