diff --git a/cmd/root.go b/cmd/root.go index dbf3e25..9279368 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -110,8 +110,9 @@ func setupRouter(db *sqlx.DB) http.Handler { router := httprouter.New() router.POST("/login", utils.AdaptHandler(handler.Login)) - router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp)) - router.GET("/get", utils.AdaptHandler(handler.GetOtp)) + router.POST("/refresh", utils.AdaptHandler(utils.AuthMiddleware(handler.RefreshToken))) + router.POST("/set", utils.AdaptHandler(utils.AuthMiddleware(handler.UpdateOrCreateOtp))) + router.GET("/get", utils.AdaptHandler(utils.AuthMiddleware(handler.GetOtp))) return router } diff --git a/config.yaml b/config.yaml index a61e877..f411f36 100644 --- a/config.yaml +++ b/config.yaml @@ -8,6 +8,7 @@ port: 8080 auth: secret: "secret" ttl: 3600 + wechat: appid: "wx57d1033974eb5250" secret: "be494c2a81df685a40b9a74e1736b15d" diff --git a/database.zip b/database.zip new file mode 100644 index 0000000..5cf63c3 Binary files /dev/null and b/database.zip differ diff --git a/database/init/otp.sql b/database/init/otp.sql index e1c176c..1cec635 100644 --- a/database/init/otp.sql +++ b/database/init/otp.sql @@ -1,5 +1,6 @@ CREATE TABLE IF NOT EXISTS otp ( - openid VARCHAR(255) PRIMARY KEY, + id SERIAL PRIMARY KEY, + openid VARCHAR(255) UNIQUE NOT NULL, token VARCHAR(255), createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); \ No newline at end of file diff --git a/database/init/users.sql b/database/init/users.sql index 4dbedae..3f1d0f8 100644 --- a/database/init/users.sql +++ b/database/init/users.sql @@ -2,4 +2,5 @@ CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, openid VARCHAR(255) UNIQUE NOT NULL, session_key VARCHAR(255) UNIQUE NOT NULL -); \ No newline at end of file +); +CREATE UNIQUE INDEX idx_users_openid ON users(openid); \ No newline at end of file diff --git a/handlers/login.go b/handlers/login.go index ca41147..f31c1c9 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -11,13 +11,6 @@ import ( "github.com/spf13/viper" ) -var wxClient = &http.Client{ - Timeout: 10 * time.Second, - Transport: &http.Transport{ - MaxIdleConnsPerHost: 10, - }, -} - type LoginRequest struct { Code string `json:"code"` } @@ -31,6 +24,13 @@ type LoginResponse struct { ErrMsg string `json:"errmsg"` } +var wxClient = &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + MaxIdleConnsPerHost: 10, + }, +} + func getLoginResponse(code string) (*LoginResponse, error) { appid := viper.GetString("wechat.appid") secret := viper.GetString("wechat.secret") @@ -47,16 +47,27 @@ func getLoginResponse(code string) (*LoginResponse, error) { } if loginResponse.ErrCode != 0 { - return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg) + switch loginResponse.ErrCode { + case 40029: + return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg) + case 45011: + return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg) + default: + return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg) + } } return &loginResponse, nil } func generateJWT(openid string) (string, error) { tokenTTL := viper.GetDuration("auth.ttl") + if tokenTTL <= 0 { + tokenTTL = 24 * time.Hour + } + secret := viper.GetString("auth.secret") if secret == "" { - return "", fmt.Errorf("auth secret not set") + secret = "default_auth_secret_otpm" } claims := jwt.MapClaims{ @@ -81,6 +92,8 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { WriteError(w, "Failed to read request body", http.StatusBadRequest) return } + defer r.Body.Close() + if err := json.Unmarshal(body, &req); err != nil { WriteError(w, "Failed to parse request body", http.StatusBadRequest) return @@ -88,7 +101,14 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { loginResponse, err := getLoginResponse(req.Code) if err != nil { - WriteError(w, "Failed to get login response", http.StatusInternalServerError) + switch { + case err.Error() == "invalid code": + WriteError(w, "Invalid code", http.StatusUnauthorized) + case err.Error() == "api limit exceeded": + WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests) + default: + WriteError(w, "Failed to get login response", http.StatusInternalServerError) + } return } @@ -113,10 +133,26 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { } data := map[string]interface{}{ - "token": token, - "openid": loginResponse.OpenId, - "session_key": loginResponse.SessionKey, + "t": token, + "openid": loginResponse.OpenId, } WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK) } + +func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) { + userid := r.Context().Value("openid").(string) + + token, err := generateJWT(userid) + if err != nil { + WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError) + return + } + WriteJSON(w, Response{ + Code: 0, + Message: "Token refreshed successfully", + Data: map[string]string{ + "token": token, + }, + }, http.StatusOK) +} diff --git a/handlers/otp.go b/handlers/otp.go index eb5789e..c8a81cd 100644 --- a/handlers/otp.go +++ b/handlers/otp.go @@ -8,7 +8,6 @@ import ( type OtpRequest struct { OpenID string `json:"openid"` - Num int `json:"num"` Token *[]OTP `json:"token"` } type OTP struct { @@ -61,7 +60,7 @@ func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) { var otp OtpRequest - err := h.DB.Get(&otp, "SELECT openid, token FROM otp WHERE openid=$1", openid) + err := h.DB.Get(&otp, "SELECT token FROM otp WHERE openid=$1", openid) if err != nil { WriteError(w, "Failed to get OTP", http.StatusInternalServerError) return diff --git a/utils/utils.go b/utils/utils.go index 113ac27..efd6713 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,20 +1,65 @@ package utils import ( + "context" "crypto/aes" "crypto/cipher" "encoding/base64" + "fmt" "net/http" + "strings" + "github.com/golang-jwt/jwt" "github.com/julienschmidt/httprouter" + "github.com/spf13/viper" ) +// AdaptHandler函数将一个http.Handler转换为httprouter.Handle func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle { + // 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数 return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + // 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递 h(w, r) } } +func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized) + return + } + + tokenStr := strings.TrimPrefix(authHeader, "Bearer ") + secret := viper.GetString("auth.secret") + + token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method") + } + return []byte(secret), nil + }) + + if err != nil || !token.Valid { + http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized) + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized) + return + } + + type contextKey string + // 将 openid 存入上下文 + ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"]) + next.ServeHTTP(w, r.WithContext(ctx)) + } +} + +// AesDecrypt 函数用于AES解密 func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) { //Base64解码 keyBytes, err := base64.StdEncoding.DecodeString(sessionKey) @@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) { return origData, nil } +// PKCS7UnPadding 函数用于去除PKCS7填充的密文 func PKCS7UnPadding(plantText []byte) []byte { + // 获取密文的长度 length := len(plantText) + // 如果密文长度大于0 if length > 0 { + // 获取最后一个字节的值,即填充的位数 unPadding := int(plantText[length-1]) + // 返回去除填充后的密文 return plantText[:(length - unPadding)] } + // 如果密文长度为0,则返回原密文 return plantText }