otpm/handlers/login.go
“xHuPo” 2d3698716e alpha
2025-05-23 13:45:37 +08:00

158 lines
3.9 KiB
Go

package handlers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/golang-jwt/jwt"
"github.com/spf13/viper"
)
// LoginRequest is the JSON request body decoded by the Login handler.
type LoginRequest struct {
// Code is read from the "code" JSON field; Login passes it to
// getLoginResponse to be exchanged for a WeChat session.
Code string `json:"code"`
}
// LoginResponse wraps the data returned by the code2session interface.
// getLoginResponse decodes the upstream JSON body into this struct.
type LoginResponse struct {
// OpenId is decoded from the "openid" field; Login stores it in the users
// table and embeds it in the issued JWT.
OpenId string `json:"openid"`
// SessionKey is decoded from the "session_key" field; Login persists it
// alongside the openid.
SessionKey string `json:"session_key"`
// UnionId is decoded from the "unionid" field. It is not read anywhere in
// this file.
UnionId string `json:"unionid"`
// ErrCode is decoded from the "errcode" field. getLoginResponse treats any
// non-zero value as a failure.
ErrCode int `json:"errcode"`
// ErrMsg is decoded from the "errmsg" field and is included in the error
// text produced by getLoginResponse.
ErrMsg string `json:"errmsg"`
}
// wxClient is the package-level HTTP client that getLoginResponse uses for
// its outbound request. A single shared client lets the transport reuse
// connections between calls.
var wxClient = &http.Client{
// Bound the total time of each request to 10 seconds.
Timeout: 10 * time.Second,
Transport: &http.Transport{
// Keep at most 10 idle connections per host.
MaxIdleConnsPerHost: 10,
},
}
// getLoginResponse exchanges a login code for a WeChat session by issuing a
// GET request to the jscode2session endpoint through wxClient.
//
// The app ID and secret are read from the "wechat.appid" and "wechat.secret"
// configuration keys. Every query parameter is URL-encoded, so a code that
// contains characters such as '&', '#' or '=' cannot add or override request
// parameters.
//
// On success it returns the decoded response. It returns an error when the
// request cannot be built or sent, when the body is not valid JSON, or when
// the response carries a non-zero errcode. For errcode 40029 the error text
// starts with "invalid code", for 45011 with "api limit exceeded", and for
// anything else with "wechat login error"; the upstream errmsg follows.
func getLoginResponse(code string) (*LoginResponse, error) {
	req, err := http.NewRequest(http.MethodGet, "https://api.weixin.qq.com/sns/jscode2session", nil)
	if err != nil {
		return nil, err
	}

	// Build the query through url.Values (obtained from the request URL) so
	// that each value is escaped instead of being spliced in verbatim.
	params := req.URL.Query()
	params.Set("appid", viper.GetString("wechat.appid"))
	params.Set("secret", viper.GetString("wechat.secret"))
	params.Set("js_code", code)
	params.Set("grant_type", "authorization_code")
	req.URL.RawQuery = params.Encode()

	resp, err := wxClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var loginResponse LoginResponse
	if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
		return nil, err
	}

	switch loginResponse.ErrCode {
	case 0:
		return &loginResponse, nil
	case 40029:
		return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg)
	case 45011:
		return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg)
	default:
		return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg)
	}
}
// generateJWT issues an HS256-signed JWT for the given openid.
//
// The token lifetime comes from the "auth.ttl" configuration key and falls
// back to 24 hours when that value is zero or negative. The signing key comes
// from "auth.secret". The claims are "openid", "exp", "iat" and "iss" (taken
// from "server.name").
//
// It returns the signed token string, or an empty string and the signing
// error if signing fails.
func generateJWT(openid string) (string, error) {
	ttl := viper.GetDuration("auth.ttl")
	if ttl <= 0 {
		ttl = 24 * time.Hour
	}

	secret := viper.GetString("auth.secret")
	if secret == "" {
		// NOTE(review): falling back to a constant that lives in source
		// control means tokens are forgeable whenever "auth.secret" is not
		// configured. Kept as-is because the verifying side is not visible
		// from here and presumably relies on the same fallback; consider
		// failing at startup instead.
		secret = "default_auth_secret_otpm"
	}

	// Read the clock once so that exp - iat is exactly the configured TTL.
	now := time.Now()
	claims := jwt.MapClaims{
		"openid": openid,
		"exp":    now.Add(ttl).Unix(),
		"iat":    now.Unix(),
		"iss":    viper.GetString("server.name"),
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
	if err != nil {
		return "", err
	}
	return signed, nil
}
// Login handles a login request whose JSON body carries a "code" field.
//
// It exchanges the code via getLoginResponse, upserts the resulting openid
// and session key into the users table, and responds with a freshly signed
// JWT (under the key "t") together with the openid.
//
// Responses:
//   - 400 when the body cannot be read or parsed, or the code is empty
//   - 401 when the upstream exchange reports an invalid code
//   - 429 when the upstream exchange reports that its API limit was exceeded
//   - 500 for any other upstream, database, or token-signing failure
//   - 200 with the token and openid on success
//
// NOTE(review): the token is returned under "t" here but under "token" in
// RefreshToken; left unchanged because clients may depend on either key.
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
	// Cap the body so an oversized upload cannot exhaust memory; a login
	// payload only carries a short code.
	const maxBodyBytes = 1 << 20

	// Register the close before reading so it also runs on the error path.
	defer r.Body.Close()

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	if err != nil {
		WriteError(w, "Failed to read request body", http.StatusBadRequest)
		return
	}

	var req LoginRequest
	if err := json.Unmarshal(body, &req); err != nil {
		WriteError(w, "Failed to parse request body", http.StatusBadRequest)
		return
	}
	if req.Code == "" {
		// Reject locally instead of spending an upstream call on a request
		// that cannot succeed.
		WriteError(w, "Missing code", http.StatusBadRequest)
		return
	}

	loginResponse, err := getLoginResponse(req.Code)
	if err != nil {
		// getLoginResponse formats its errors as "<kind>: <upstream message>",
		// so the kind has to be matched as a prefix; comparing the whole
		// string for equality never matches.
		hasPrefix := func(s, prefix string) bool {
			return len(s) >= len(prefix) && s[:len(prefix)] == prefix
		}
		msg := err.Error()
		switch {
		case hasPrefix(msg, "invalid code"):
			WriteError(w, "Invalid code", http.StatusUnauthorized)
		case hasPrefix(msg, "api limit exceeded"):
			WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests)
		default:
			WriteError(w, "Failed to get login response", http.StatusInternalServerError)
		}
		return
	}

	// Insert the user, or refresh the stored session_key when the openid
	// already exists.
	query := `
		INSERT INTO users (openid, session_key)
		VALUES ($1, $2)
		ON CONFLICT (openid) DO UPDATE SET session_key = $2
		RETURNING id;
	`
	// The returned id is scanned only to confirm that the statement produced
	// a row; it is not used afterwards.
	var userID int
	if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&userID); err != nil {
		WriteError(w, "Failed to log in user", http.StatusInternalServerError)
		return
	}

	token, err := generateJWT(loginResponse.OpenId)
	if err != nil {
		WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
		return
	}

	data := map[string]interface{}{
		"t":      token,
		"openid": loginResponse.OpenId,
	}
	WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
}
// RefreshToken issues a new JWT for the openid stored in the request context
// under the key "openid".
//
// Responses:
//   - 401 when the context holds no non-empty string under "openid"
//   - 500 when the token cannot be signed
//   - 200 with the new token under the key "token" on success
func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) {
	// Use the comma-ok form: a bare type assertion panics when the value is
	// absent or not a string.
	// NOTE(review): the plain string key must match whatever populates the
	// context (presumably an auth middleware outside this file), so it is
	// kept as-is.
	openid, ok := r.Context().Value("openid").(string)
	if !ok || openid == "" {
		WriteError(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	token, err := generateJWT(openid)
	if err != nil {
		WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
		return
	}

	WriteJSON(w, Response{
		Code:    0,
		Message: "Token refreshed successfully",
		Data: map[string]string{
			"token": token,
		},
	}, http.StatusOK)
}