This commit is contained in:
“xHuPo” 2025-05-23 13:45:37 +08:00
parent 25c5f530b8
commit 2d3698716e
8 changed files with 109 additions and 19 deletions

View file

@ -110,8 +110,9 @@ func setupRouter(db *sqlx.DB) http.Handler {
router := httprouter.New()
router.POST("/login", utils.AdaptHandler(handler.Login))
router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp))
router.GET("/get", utils.AdaptHandler(handler.GetOtp))
router.POST("/refresh", utils.AdaptHandler(utils.AuthMiddleware(handler.RefreshToken)))
router.POST("/set", utils.AdaptHandler(utils.AuthMiddleware(handler.UpdateOrCreateOtp)))
router.GET("/get", utils.AdaptHandler(utils.AuthMiddleware(handler.GetOtp)))
return router
}

View file

@ -8,6 +8,7 @@ port: 8080
auth:
secret: "secret"
ttl: 3600
wechat:
appid: "wx57d1033974eb5250"
secret: "be494c2a81df685a40b9a74e1736b15d"

BIN
database.zip Normal file

Binary file not shown.

View file

@ -1,5 +1,6 @@
CREATE TABLE IF NOT EXISTS otp (
openid VARCHAR(255) PRIMARY KEY,
id SERIAL PRIMARY KEY,
openid VARCHAR(255) UNIQUE NOT NULL,
token VARCHAR(255),
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

View file

@ -2,4 +2,5 @@ CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
openid VARCHAR(255) UNIQUE NOT NULL,
session_key VARCHAR(255) UNIQUE NOT NULL
);
);
CREATE UNIQUE INDEX idx_users_openid ON users(openid);

View file

@ -11,13 +11,6 @@ import (
"github.com/spf13/viper"
)
var wxClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConnsPerHost: 10,
},
}
type LoginRequest struct {
Code string `json:"code"`
}
@ -31,6 +24,13 @@ type LoginResponse struct {
ErrMsg string `json:"errmsg"`
}
var wxClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConnsPerHost: 10,
},
}
func getLoginResponse(code string) (*LoginResponse, error) {
appid := viper.GetString("wechat.appid")
secret := viper.GetString("wechat.secret")
@ -47,16 +47,27 @@ func getLoginResponse(code string) (*LoginResponse, error) {
}
if loginResponse.ErrCode != 0 {
return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
switch loginResponse.ErrCode {
case 40029:
return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg)
case 45011:
return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg)
default:
return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg)
}
}
return &loginResponse, nil
}
func generateJWT(openid string) (string, error) {
tokenTTL := viper.GetDuration("auth.ttl")
if tokenTTL <= 0 {
tokenTTL = 24 * time.Hour
}
secret := viper.GetString("auth.secret")
if secret == "" {
return "", fmt.Errorf("auth secret not set")
secret = "default_auth_secret_otpm"
}
claims := jwt.MapClaims{
@ -81,6 +92,8 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
WriteError(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()
if err := json.Unmarshal(body, &req); err != nil {
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
return
@ -88,7 +101,14 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
loginResponse, err := getLoginResponse(req.Code)
if err != nil {
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
switch {
case err.Error() == "invalid code":
WriteError(w, "Invalid code", http.StatusUnauthorized)
case err.Error() == "api limit exceeded":
WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests)
default:
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
}
return
}
@ -113,10 +133,26 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
}
data := map[string]interface{}{
"token": token,
"openid": loginResponse.OpenId,
"session_key": loginResponse.SessionKey,
"t": token,
"openid": loginResponse.OpenId,
}
WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
}
func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) {
userid := r.Context().Value("openid").(string)
token, err := generateJWT(userid)
if err != nil {
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
return
}
WriteJSON(w, Response{
Code: 0,
Message: "Token refreshed successfully",
Data: map[string]string{
"token": token,
},
}, http.StatusOK)
}

View file

@ -8,7 +8,6 @@ import (
type OtpRequest struct {
OpenID string `json:"openid"`
Num int `json:"num"`
Token *[]OTP `json:"token"`
}
type OTP struct {
@ -61,7 +60,7 @@ func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
var otp OtpRequest
err := h.DB.Get(&otp, "SELECT openid, token FROM otp WHERE openid=$1", openid)
err := h.DB.Get(&otp, "SELECT token FROM otp WHERE openid=$1", openid)
if err != nil {
WriteError(w, "Failed to get OTP", http.StatusInternalServerError)
return

View file

@ -1,20 +1,65 @@
package utils
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
"net/http"
"strings"
"github.com/golang-jwt/jwt"
"github.com/julienschmidt/httprouter"
"github.com/spf13/viper"
)
// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
// 返回一个httprouter.Handle函数该函数接受http.ResponseWriter和*http.Request作为参数
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// 调用传入的http.Handler函数将http.ResponseWriter和*http.Request作为参数传递
h(w, r)
}
}
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
return
}
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
secret := viper.GetString("auth.secret")
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method")
}
return []byte(secret), nil
})
if err != nil || !token.Valid {
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
return
}
type contextKey string
// 将 openid 存入上下文
ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
next.ServeHTTP(w, r.WithContext(ctx))
}
}
// AesDecrypt 函数用于AES解密
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
//Base64解码
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
return origData, nil
}
// PKCS7UnPadding 函数用于去除PKCS7填充的密文
func PKCS7UnPadding(plantText []byte) []byte {
// 获取密文的长度
length := len(plantText)
// 如果密文长度大于0
if length > 0 {
// 获取最后一个字节的值,即填充的位数
unPadding := int(plantText[length-1])
// 返回去除填充后的密文
return plantText[:(length - unPadding)]
}
// 如果密文长度为0则返回原密文
return plantText
}