package handlers

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"

	"github.com/golang-jwt/jwt"
	"github.com/spf13/viper"
)

// LoginRequest is the JSON body accepted by Login.
type LoginRequest struct {
	// Code is the temporary login code sent by the client; it is forwarded to
	// code2session as js_code.
	Code string `json:"code"`
}

// LoginResponse wraps the data returned by the WeChat code2session API.
type LoginResponse struct {
	OpenId     string `json:"openid"`
	SessionKey string `json:"session_key"`
	UnionId    string `json:"unionid"`
	ErrCode    int    `json:"errcode"`
	ErrMsg     string `json:"errmsg"`
}

// code2SessionURL is the WeChat endpoint that exchanges a login code for the
// user's openid and session key.
const code2SessionURL = "https://api.weixin.qq.com/sns/jscode2session"

// maxLoginBodyBytes caps the size of a Login request body. The payload only
// carries a short code, so anything larger is rejected as a bad request.
const maxLoginBodyBytes = 1 << 16

// Sentinel errors returned (wrapped) by getLoginResponse so callers can map
// them with errors.Is instead of comparing error strings.
var (
	errInvalidCode      = errors.New("invalid code")
	errAPILimitExceeded = errors.New("api limit exceeded")
)

// wxClient is the shared HTTP client used for WeChat API calls.
var wxClient = newWeChatClient()

// newWeChatClient builds a client with a 10s overall timeout. It starts from a
// clone of http.DefaultTransport so proxy settings and dial/TLS timeouts are
// kept, and only raises the per-host idle connection limit.
func newWeChatClient() *http.Client {
	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	} else {
		// DefaultTransport was replaced with a non-*http.Transport; fall back
		// to a bare transport as before.
		transport = &http.Transport{}
	}
	transport.MaxIdleConnsPerHost = 10

	return &http.Client{
		Timeout:   10 * time.Second,
		Transport: transport,
	}
}

// getLoginResponse exchanges a client login code for the user's openid and
// session key via code2session.
//
// The app credentials are read from the viper keys "wechat.appid" and
// "wechat.secret". A non-zero errcode from WeChat is returned as an error:
// 40029 wraps errInvalidCode, 45011 wraps errAPILimitExceeded, and any other
// code yields a generic error. A success response without an openid is also
// treated as an error.
func getLoginResponse(code string) (*LoginResponse, error) {
	params := url.Values{}
	params.Set("appid", viper.GetString("wechat.appid"))
	params.Set("secret", viper.GetString("wechat.secret"))
	// The code is client-supplied; encoding it prevents injecting extra
	// query parameters.
	params.Set("js_code", code)
	params.Set("grant_type", "authorization_code")

	resp, err := wxClient.Get(code2SessionURL + "?" + params.Encode())
	if err != nil {
		// *url.Error embeds the full request URL, which contains the app
		// secret; keep only the underlying cause so it cannot end up in logs.
		var uerr *url.Error
		if errors.As(err, &uerr) {
			err = uerr.Err
		}
		return nil, fmt.Errorf("calling code2session: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("code2session: unexpected HTTP status %d", resp.StatusCode)
	}

	var loginResponse LoginResponse
	if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
		return nil, fmt.Errorf("decoding code2session response: %w", err)
	}

	switch loginResponse.ErrCode {
	case 0:
		// Success; validated below.
	case 40029:
		return nil, fmt.Errorf("%w: %s", errInvalidCode, loginResponse.ErrMsg)
	case 45011:
		return nil, fmt.Errorf("%w: %s", errAPILimitExceeded, loginResponse.ErrMsg)
	default:
		return nil, fmt.Errorf("wechat login error (errcode %d): %s", loginResponse.ErrCode, loginResponse.ErrMsg)
	}

	// Without an openid the upsert in Login would store an empty key.
	if loginResponse.OpenId == "" {
		return nil, errors.New("wechat login error: empty openid in response")
	}
	return &loginResponse, nil
}

// generateJWT issues an HS256-signed token whose claims carry openid.
//
// The token lifetime comes from "auth.ttl" (24h when unset or non-positive),
// the signing key from "auth.secret", and the "iss" claim from "server.name".
func generateJWT(openid string) (string, error) {
	tokenTTL := viper.GetDuration("auth.ttl")
	if tokenTTL <= 0 {
		tokenTTL = 24 * time.Hour
	}

	secret := viper.GetString("auth.secret")
	if secret == "" {
		// NOTE(review): a hard-coded fallback key lets anyone who has seen this
		// source forge tokens whenever auth.secret is unset. Kept only for
		// compatibility with whatever verifies these tokens; consider failing
		// here instead.
		secret = "default_auth_secret_otpm"
	}

	// Take a single timestamp so iat and exp derive from the same instant.
	now := time.Now()
	claims := jwt.MapClaims{
		"openid": openid,
		"exp":    now.Add(tokenTTL).Unix(),
		"iat":    now.Unix(),
		"iss":    viper.GetString("server.name"),
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString([]byte(secret))
}

// Login exchanges a WeChat login code for an application JWT.
//
// The request body must be JSON of the form {"code": "..."}. On success the
// user's openid and session_key are upserted into the users table and the
// response data carries the token under "t" together with the openid.
// Unreadable, malformed or code-less bodies get 400, invalid codes 401,
// WeChat rate limiting 429, and any other failure 500.
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
	// Bound the body: it only needs to carry a short code. The server closes
	// r.Body itself, so no explicit Close is needed.
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxLoginBodyBytes))
	if err != nil {
		WriteError(w, "Failed to read request body", http.StatusBadRequest)
		return
	}

	var req LoginRequest
	if err := json.Unmarshal(body, &req); err != nil {
		WriteError(w, "Failed to parse request body", http.StatusBadRequest)
		return
	}
	if req.Code == "" {
		WriteError(w, "Missing code", http.StatusBadRequest)
		return
	}

	loginResponse, err := getLoginResponse(req.Code)
	if err != nil {
		// Match on sentinels: getLoginResponse appends WeChat's errmsg, so an
		// exact comparison of err.Error() against a fixed string never matches.
		switch {
		case errors.Is(err, errInvalidCode):
			WriteError(w, "Invalid code", http.StatusUnauthorized)
		case errors.Is(err, errAPILimitExceeded):
			WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests)
		default:
			WriteError(w, "Failed to get login response", http.StatusInternalServerError)
		}
		return
	}

	// Insert the user's openid and session_key, or refresh the stored
	// session_key when the openid already exists.
	const query = `
		INSERT INTO users (openid, session_key)
		VALUES ($1, $2)
		ON CONFLICT (openid) DO UPDATE SET session_key = $2
		RETURNING id;
	`
	var userID int
	if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&userID); err != nil {
		WriteError(w, "Failed to log in user", http.StatusInternalServerError)
		return
	}

	token, err := generateJWT(loginResponse.OpenId)
	if err != nil {
		WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
		return
	}

	data := map[string]interface{}{
		"t":      token,
		"openid": loginResponse.OpenId,
	}
	WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
}

// RefreshToken issues a fresh JWT for the openid found in the request context
// under the string key "openid" (presumably set by an upstream auth
// middleware — confirm). Requests without a non-empty openid there get 401
// instead of panicking on the type assertion.
func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) {
	// The plain string key must match whatever stores the value upstream.
	openid, ok := r.Context().Value("openid").(string)
	if !ok || openid == "" {
		WriteError(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	token, err := generateJWT(openid)
	if err != nil {
		WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
		return
	}

	WriteJSON(w, Response{
		Code:    0,
		Message: "Token refreshed successfully",
		Data: map[string]string{
			"token": token,
		},
	}, http.StatusOK)
}