alpha
This commit is contained in:
parent
25c5f530b8
commit
2d3698716e
8 changed files with 109 additions and 19 deletions
|
@ -110,8 +110,9 @@ func setupRouter(db *sqlx.DB) http.Handler {
|
|||
|
||||
router := httprouter.New()
|
||||
router.POST("/login", utils.AdaptHandler(handler.Login))
|
||||
router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp))
|
||||
router.GET("/get", utils.AdaptHandler(handler.GetOtp))
|
||||
router.POST("/refresh", utils.AdaptHandler(utils.AuthMiddleware(handler.RefreshToken)))
|
||||
router.POST("/set", utils.AdaptHandler(utils.AuthMiddleware(handler.UpdateOrCreateOtp)))
|
||||
router.GET("/get", utils.AdaptHandler(utils.AuthMiddleware(handler.GetOtp)))
|
||||
|
||||
return router
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ port: 8080
|
|||
auth:
|
||||
secret: "secret"
|
||||
ttl: 3600
|
||||
|
||||
wechat:
|
||||
appid: "wx57d1033974eb5250"
|
||||
secret: "be494c2a81df685a40b9a74e1736b15d"
|
||||
|
|
BIN
database.zip
Normal file
BIN
database.zip
Normal file
Binary file not shown.
|
@ -1,5 +1,6 @@
|
|||
CREATE TABLE IF NOT EXISTS otp (
|
||||
openid VARCHAR(255) PRIMARY KEY,
|
||||
id SERIAL PRIMARY KEY,
|
||||
openid VARCHAR(255) UNIQUE NOT NULL,
|
||||
token VARCHAR(255),
|
||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
|
@ -2,4 +2,5 @@ CREATE TABLE IF NOT EXISTS users (
|
|||
id SERIAL PRIMARY KEY,
|
||||
openid VARCHAR(255) UNIQUE NOT NULL,
|
||||
session_key VARCHAR(255) UNIQUE NOT NULL
|
||||
);
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_users_openid ON users(openid);
|
|
@ -11,13 +11,6 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var wxClient = &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConnsPerHost: 10,
|
||||
},
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
@ -31,6 +24,13 @@ type LoginResponse struct {
|
|||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
var wxClient = &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConnsPerHost: 10,
|
||||
},
|
||||
}
|
||||
|
||||
func getLoginResponse(code string) (*LoginResponse, error) {
|
||||
appid := viper.GetString("wechat.appid")
|
||||
secret := viper.GetString("wechat.secret")
|
||||
|
@ -47,16 +47,27 @@ func getLoginResponse(code string) (*LoginResponse, error) {
|
|||
}
|
||||
|
||||
if loginResponse.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
|
||||
switch loginResponse.ErrCode {
|
||||
case 40029:
|
||||
return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg)
|
||||
case 45011:
|
||||
return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg)
|
||||
default:
|
||||
return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg)
|
||||
}
|
||||
}
|
||||
return &loginResponse, nil
|
||||
}
|
||||
|
||||
func generateJWT(openid string) (string, error) {
|
||||
tokenTTL := viper.GetDuration("auth.ttl")
|
||||
if tokenTTL <= 0 {
|
||||
tokenTTL = 24 * time.Hour
|
||||
}
|
||||
|
||||
secret := viper.GetString("auth.secret")
|
||||
if secret == "" {
|
||||
return "", fmt.Errorf("auth secret not set")
|
||||
secret = "default_auth_secret_otpm"
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
|
@ -81,6 +92,8 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
|||
WriteError(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
|
||||
return
|
||||
|
@ -88,7 +101,14 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
loginResponse, err := getLoginResponse(req.Code)
|
||||
if err != nil {
|
||||
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
|
||||
switch {
|
||||
case err.Error() == "invalid code":
|
||||
WriteError(w, "Invalid code", http.StatusUnauthorized)
|
||||
case err.Error() == "api limit exceeded":
|
||||
WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests)
|
||||
default:
|
||||
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -113,10 +133,26 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"token": token,
|
||||
"openid": loginResponse.OpenId,
|
||||
"session_key": loginResponse.SessionKey,
|
||||
"t": token,
|
||||
"openid": loginResponse.OpenId,
|
||||
}
|
||||
|
||||
WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||
userid := r.Context().Value("openid").(string)
|
||||
|
||||
token, err := generateJWT(userid)
|
||||
if err != nil {
|
||||
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
WriteJSON(w, Response{
|
||||
Code: 0,
|
||||
Message: "Token refreshed successfully",
|
||||
Data: map[string]string{
|
||||
"token": token,
|
||||
},
|
||||
}, http.StatusOK)
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
|
||||
type OtpRequest struct {
|
||||
OpenID string `json:"openid"`
|
||||
Num int `json:"num"`
|
||||
Token *[]OTP `json:"token"`
|
||||
}
|
||||
type OTP struct {
|
||||
|
@ -61,7 +60,7 @@ func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var otp OtpRequest
|
||||
|
||||
err := h.DB.Get(&otp, "SELECT openid, token FROM otp WHERE openid=$1", openid)
|
||||
err := h.DB.Get(&otp, "SELECT token FROM otp WHERE openid=$1", openid)
|
||||
if err != nil {
|
||||
WriteError(w, "Failed to get OTP", http.StatusInternalServerError)
|
||||
return
|
||||
|
|
|
@ -1,20 +1,65 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
|
||||
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
|
||||
// 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数
|
||||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
// 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
secret := viper.GetString("auth.secret")
|
||||
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method")
|
||||
}
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
// 将 openid 存入上下文
|
||||
ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// AesDecrypt 函数用于AES解密
|
||||
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
|
||||
//Base64解码
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
|
||||
|
@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
|
|||
return origData, nil
|
||||
}
|
||||
|
||||
// PKCS7UnPadding 函数用于去除PKCS7填充的密文
|
||||
func PKCS7UnPadding(plantText []byte) []byte {
|
||||
// 获取密文的长度
|
||||
length := len(plantText)
|
||||
// 如果密文长度大于0
|
||||
if length > 0 {
|
||||
// 获取最后一个字节的值,即填充的位数
|
||||
unPadding := int(plantText[length-1])
|
||||
// 返回去除填充后的密文
|
||||
return plantText[:(length - unPadding)]
|
||||
}
|
||||
// 如果密文长度为0,则返回原密文
|
||||
return plantText
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue