This commit is contained in:
“xHuPo” 2025-05-23 13:45:37 +08:00
parent 25c5f530b8
commit 2d3698716e
8 changed files with 109 additions and 19 deletions

View file

@ -11,13 +11,6 @@ import (
"github.com/spf13/viper"
)
var wxClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConnsPerHost: 10,
},
}
type LoginRequest struct {
Code string `json:"code"`
}
@ -31,6 +24,13 @@ type LoginResponse struct {
ErrMsg string `json:"errmsg"`
}
var wxClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConnsPerHost: 10,
},
}
func getLoginResponse(code string) (*LoginResponse, error) {
appid := viper.GetString("wechat.appid")
secret := viper.GetString("wechat.secret")
@ -47,16 +47,27 @@ func getLoginResponse(code string) (*LoginResponse, error) {
}
if loginResponse.ErrCode != 0 {
return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
switch loginResponse.ErrCode {
case 40029:
return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg)
case 45011:
return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg)
default:
return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg)
}
}
return &loginResponse, nil
}
func generateJWT(openid string) (string, error) {
tokenTTL := viper.GetDuration("auth.ttl")
if tokenTTL <= 0 {
tokenTTL = 24 * time.Hour
}
secret := viper.GetString("auth.secret")
if secret == "" {
return "", fmt.Errorf("auth secret not set")
secret = "default_auth_secret_otpm"
}
claims := jwt.MapClaims{
@ -81,6 +92,8 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
WriteError(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()
if err := json.Unmarshal(body, &req); err != nil {
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
return
@ -88,7 +101,14 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
loginResponse, err := getLoginResponse(req.Code)
if err != nil {
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
switch {
case err.Error() == "invalid code":
WriteError(w, "Invalid code", http.StatusUnauthorized)
case err.Error() == "api limit exceeded":
WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests)
default:
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
}
return
}
@ -113,10 +133,26 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
}
data := map[string]interface{}{
"token": token,
"openid": loginResponse.OpenId,
"session_key": loginResponse.SessionKey,
"t": token,
"openid": loginResponse.OpenId,
}
WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
}
func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) {
userid := r.Context().Value("openid").(string)
token, err := generateJWT(userid)
if err != nil {
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
return
}
WriteJSON(w, Response{
Code: 0,
Message: "Token refreshed successfully",
Data: map[string]string{
"token": token,
},
}, http.StatusOK)
}

View file

@ -8,7 +8,6 @@ import (
type OtpRequest struct {
OpenID string `json:"openid"`
Num int `json:"num"`
Token *[]OTP `json:"token"`
}
type OTP struct {
@ -61,7 +60,7 @@ func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
var otp OtpRequest
err := h.DB.Get(&otp, "SELECT openid, token FROM otp WHERE openid=$1", openid)
err := h.DB.Get(&otp, "SELECT token FROM otp WHERE openid=$1", openid)
if err != nil {
WriteError(w, "Failed to get OTP", http.StatusInternalServerError)
return