package utils import ( "context" "crypto/aes" "crypto/cipher" "encoding/base64" "fmt" "net/http" "strings" "github.com/golang-jwt/jwt" "github.com/julienschmidt/httprouter" "github.com/spf13/viper" ) // AdaptHandler函数将一个http.Handler转换为httprouter.Handle func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle { // 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数 return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { // 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递 h(w, r) } } func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") if authHeader == "" { http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized) return } tokenStr := strings.TrimPrefix(authHeader, "Bearer ") secret := viper.GetString("auth.secret") token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method") } return []byte(secret), nil }) if err != nil || !token.Valid { http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized) return } claims, ok := token.Claims.(jwt.MapClaims) if !ok { http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized) return } type contextKey string // 将 openid 存入上下文 ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"]) next.ServeHTTP(w, r.WithContext(ctx)) } } // AesDecrypt 函数用于AES解密 func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) { //Base64解码 keyBytes, err := base64.StdEncoding.DecodeString(sessionKey) if err != nil { return nil, err } ivBytes, err := base64.StdEncoding.DecodeString(iv) if err != nil { return nil, err } cryptData, err := base64.StdEncoding.DecodeString(encryptedData) if err != nil { return nil, err } origData := make([]byte, len(cryptData)) //AES block, err := aes.NewCipher(keyBytes) if err != nil { return nil, err } //CBC mode := cipher.NewCBCDecrypter(block, ivBytes) //解密 mode.CryptBlocks(origData, cryptData) //去除填充位 origData = PKCS7UnPadding(origData) return origData, nil } // PKCS7UnPadding 函数用于去除PKCS7填充的密文 func PKCS7UnPadding(plantText []byte) []byte { // 获取密文的长度 length := len(plantText) // 如果密文长度大于0 if length > 0 { // 获取最后一个字节的值,即填充的位数 unPadding := int(plantText[length-1]) // 返回去除填充后的密文 return plantText[:(length - unPadding)] } // 如果密文长度为0,则返回原密文 return plantText }