122 lines
3 KiB
Go
122 lines
3 KiB
Go
package handlers
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"

	"github.com/golang-jwt/jwt"
	"github.com/spf13/viper"
)
|
|
|
|
// wxClient is the single HTTP client used for outbound WeChat API calls.
// Sharing one client lets its transport keep up to 10 idle connections per
// host for reuse, and Timeout caps each request at 10 seconds end to end.
var wxClient = &http.Client{
	Transport: &http.Transport{MaxIdleConnsPerHost: 10},
	Timeout:   10 * time.Second,
}
|
|
|
|
// LoginRequest is the JSON body accepted by the login endpoint.
type LoginRequest struct {
	// Code is the client's temporary login code; it is forwarded to WeChat
	// as the js_code parameter of code2session.
	Code string `json:"code"`
}
|
|
|
|
// LoginResponse wraps the JSON payload returned by WeChat's code2session
// (jscode2session) endpoint.
type LoginResponse struct {
	// OpenId identifies the user; Login uses it as the users-table key.
	OpenId string `json:"openid"`
	// SessionKey is the per-session key WeChat issues alongside the openid.
	SessionKey string `json:"session_key"`
	// UnionId is decoded from the "unionid" field when the payload has one.
	UnionId string `json:"unionid"`
	// ErrCode is treated as success only when 0 (see getLoginResponse).
	ErrCode int `json:"errcode"`
	// ErrMsg is WeChat's human-readable description accompanying ErrCode.
	ErrMsg string `json:"errmsg"`
}
|
|
|
|
func getLoginResponse(code string) (*LoginResponse, error) {
|
|
appid := viper.GetString("wechat.appid")
|
|
secret := viper.GetString("wechat.secret")
|
|
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", appid, secret, code)
|
|
resp, err := wxClient.Get(url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var loginResponse LoginResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if loginResponse.ErrCode != 0 {
|
|
return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
|
|
}
|
|
return &loginResponse, nil
|
|
}
|
|
|
|
func generateJWT(openid string) (string, error) {
|
|
tokenTTL := viper.GetDuration("auth.ttl")
|
|
secret := viper.GetString("auth.secret")
|
|
if secret == "" {
|
|
return "", fmt.Errorf("auth secret not set")
|
|
}
|
|
|
|
claims := jwt.MapClaims{
|
|
"openid": openid,
|
|
"exp": time.Now().Add(tokenTTL).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"iss": viper.GetString("server.name"),
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signedToken, err := token.SignedString([]byte(secret))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return signedToken, nil
|
|
}
|
|
|
|
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
|
var req LoginRequest
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
WriteError(w, "Failed to read request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
loginResponse, err := getLoginResponse(req.Code)
|
|
if err != nil {
|
|
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 插入或更新用户的openid和session_key
|
|
query := `
|
|
INSERT INTO users (openid, session_key)
|
|
VALUES ($1, $2)
|
|
ON CONFLICT (openid) DO UPDATE SET session_key = $2
|
|
RETURNING id;
|
|
`
|
|
|
|
var ID int
|
|
if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil {
|
|
WriteError(w, "Failed to log in user", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
token, err := generateJWT(loginResponse.OpenId)
|
|
if err != nil {
|
|
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
data := map[string]interface{}{
|
|
"token": token,
|
|
"openid": loginResponse.OpenId,
|
|
"session_key": loginResponse.SessionKey,
|
|
}
|
|
|
|
WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
|
|
}
|