121 lines
3 KiB
Go
121 lines
3 KiB
Go
package main
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"testing"
|
||
)
|
||
|
||
func TestSaveHandler(t *testing.T) {
|
||
// 创建测试服务器
|
||
srv := httptest.NewServer(http.HandlerFunc(SaveHandler))
|
||
defer srv.Close()
|
||
|
||
// 准备测试数据
|
||
counter := 0 // 创建一个int变量
|
||
testData := SaveRequest{
|
||
UserID: "test_user_123",
|
||
Tokens: []TokenData{
|
||
{
|
||
Issuer: "TestOrg",
|
||
Account: "user@test.com",
|
||
Secret: "JBSWY3DPEHPK3PXP",
|
||
Type: "totp",
|
||
Period: 30,
|
||
Digits: 6,
|
||
Algorithm: "SHA1",
|
||
},
|
||
{
|
||
Issuer: "TestOrgHOTP",
|
||
Account: "user@test.com",
|
||
Secret: "JBSWY3DPEHPK3PXP",
|
||
Type: "hotp",
|
||
Counter: &counter, // 使用指针
|
||
Digits: 6,
|
||
Algorithm: "SHA1",
|
||
},
|
||
},
|
||
}
|
||
|
||
// 序列化请求体
|
||
body, _ := json.Marshal(testData)
|
||
|
||
// 发送请求
|
||
resp, err := http.Post(srv.URL, "application/json", bytes.NewBuffer(body))
|
||
if err != nil {
|
||
t.Fatalf("Error making request to server: %v\n", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 检查响应状态码
|
||
if resp.StatusCode != http.StatusOK {
|
||
t.Errorf("Expected status: %d, got: %d\n", http.StatusOK, resp.StatusCode)
|
||
}
|
||
|
||
// 解析响应
|
||
var saveResp SaveResponse
|
||
err = json.NewDecoder(resp.Body).Decode(&saveResp)
|
||
if err != nil {
|
||
t.Errorf("Error decoding response: %v\n", err)
|
||
}
|
||
|
||
// 验证响应数据
|
||
if !saveResp.Success {
|
||
t.Errorf("Expected success to be true, got false\n")
|
||
}
|
||
if saveResp.Message != "Tokens saved successfully" {
|
||
t.Errorf("Expected message to be 'Tokens saved successfully', got '%s'\n", saveResp.Message)
|
||
}
|
||
}
|
||
|
||
func TestRecoverHandler(t *testing.T) {
|
||
// 创建测试服务器
|
||
srv := httptest.NewServer(http.HandlerFunc(RecoverHandler))
|
||
defer srv.Close()
|
||
|
||
// 发送请求(没有user_id参数)
|
||
resp, err := http.Get(srv.URL)
|
||
if err != nil {
|
||
t.Fatalf("Error making request to server: %v\n", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 检查响应状态码(应该返回错误)
|
||
if resp.StatusCode != http.StatusBadRequest {
|
||
t.Errorf("Expected status: %d, got: %d\n", http.StatusBadRequest, resp.StatusCode)
|
||
}
|
||
|
||
// 发送带user_id的请求
|
||
urlWithID := fmt.Sprintf("%s?user_id=test_user_123", srv.URL)
|
||
respWithID, err := http.Get(urlWithID)
|
||
if err != nil {
|
||
t.Fatalf("Error making request to server: %v\n", err)
|
||
}
|
||
defer respWithID.Body.Close()
|
||
|
||
// 检查响应状态码
|
||
if respWithID.StatusCode != http.StatusOK {
|
||
t.Errorf("Expected status: %d, got: %d\n", http.StatusOK, respWithID.StatusCode)
|
||
}
|
||
|
||
// 解析响应
|
||
var recoverResp RecoverResponse
|
||
err = json.NewDecoder(respWithID.Body).Decode(&recoverResp)
|
||
if err != nil {
|
||
t.Errorf("Error decoding response: %v\n", err)
|
||
}
|
||
|
||
// 验证响应数据
|
||
if !recoverResp.Success {
|
||
t.Errorf("Expected success to be true, got false\n")
|
||
}
|
||
if recoverResp.Message != "Tokens recovered successfully" {
|
||
t.Errorf("Expected message to be 'Tokens recovered successfully', got '%s'\n", recoverResp.Message)
|
||
}
|
||
if len(recoverResp.Data.Tokens) != 1 {
|
||
t.Errorf("Expected 1 token, got %d\n", len(recoverResp.Data.Tokens))
|
||
}
|
||
}
|