105 lines
2.8 KiB
Go
105 lines
2.8 KiB
Go
package utils
|
||
|
||
import (
|
||
"context"
|
||
"crypto/aes"
|
||
"crypto/cipher"
|
||
"encoding/base64"
|
||
"fmt"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/golang-jwt/jwt"
|
||
"github.com/julienschmidt/httprouter"
|
||
"github.com/spf13/viper"
|
||
)
|
||
|
||
// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
|
||
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
|
||
// 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数
|
||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||
// 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递
|
||
h(w, r)
|
||
}
|
||
}
|
||
|
||
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
authHeader := r.Header.Get("Authorization")
|
||
if authHeader == "" {
|
||
http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||
secret := viper.GetString("auth.secret")
|
||
|
||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
return nil, fmt.Errorf("unexpected signing method")
|
||
}
|
||
return []byte(secret), nil
|
||
})
|
||
|
||
if err != nil || !token.Valid {
|
||
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
claims, ok := token.Claims.(jwt.MapClaims)
|
||
if !ok {
|
||
http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
type contextKey string
|
||
// 将 openid 存入上下文
|
||
ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
|
||
next.ServeHTTP(w, r.WithContext(ctx))
|
||
}
|
||
}
|
||
|
||
// AesDecrypt decrypts AES-CBC ciphertext and removes its PKCS#7 padding.
//
// encryptedData, sessionKey and iv are all standard (padded) base64 strings.
// The decoded key must be 16, 24 or 32 bytes (AES-128/192/256). The decoded IV
// must be exactly one AES block (16 bytes). The decoded ciphertext must be a
// non-zero multiple of the block size.
//
// Because these values usually come from clients, every length and the
// padding are validated and reported as errors. Without these checks,
// crypto/cipher would panic or garbage would be returned silently.
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
	keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
	if err != nil {
		return nil, fmt.Errorf("decode session key: %w", err)
	}
	ivBytes, err := base64.StdEncoding.DecodeString(iv)
	if err != nil {
		return nil, fmt.Errorf("decode iv: %w", err)
	}
	cryptData, err := base64.StdEncoding.DecodeString(encryptedData)
	if err != nil {
		return nil, fmt.Errorf("decode encrypted data: %w", err)
	}

	block, err := aes.NewCipher(keyBytes)
	if err != nil {
		return nil, fmt.Errorf("create aes cipher: %w", err)
	}

	// cipher.NewCBCDecrypter panics on a wrong-sized IV, and CryptBlocks
	// panics on a partial block, so reject both up front.
	blockSize := block.BlockSize()
	if len(ivBytes) != blockSize {
		return nil, fmt.Errorf("invalid iv length %d, want %d", len(ivBytes), blockSize)
	}
	if len(cryptData) == 0 || len(cryptData)%blockSize != 0 {
		return nil, fmt.Errorf("invalid ciphertext length %d: must be a non-zero multiple of %d", len(cryptData), blockSize)
	}

	origData := make([]byte, len(cryptData))
	cipher.NewCBCDecrypter(block, ivBytes).CryptBlocks(origData, cryptData)

	return aesDecryptUnpad(origData)
}

// aesDecryptUnpad strips PKCS#7 padding from decrypted data and returns an
// error when the padding is malformed, which usually means a wrong key or IV.
// The padding length may be anything from 1 up to len(data). This matches
// what the previous unchecked unpadding accepted. Every padding byte must
// equal the padding length.
func aesDecryptUnpad(data []byte) ([]byte, error) {
	n := len(data)
	if n == 0 {
		return nil, fmt.Errorf("invalid padding: empty plaintext")
	}
	pad := int(data[n-1])
	if pad == 0 || pad > n {
		return nil, fmt.Errorf("invalid padding length %d for %d bytes", pad, n)
	}
	for _, b := range data[n-pad:] {
		if int(b) != pad {
			return nil, fmt.Errorf("invalid padding bytes")
		}
	}
	return data[:n-pad], nil
}
|
||
|
||
// PKCS7UnPadding strips PKCS#7 padding from decrypted plaintext.
//
// The value of the last byte is taken as the number of padding bytes to
// remove. An empty slice is returned unchanged. If the padding length exceeds
// the slice length (malformed data, e.g. from a wrong key or IV), the input is
// also returned unchanged; previously this case panicked with a
// slice-bounds error. The padding bytes themselves are not verified.
func PKCS7UnPadding(plainText []byte) []byte {
	length := len(plainText)
	if length == 0 {
		return plainText
	}

	unPadding := int(plainText[length-1])
	if unPadding > length {
		// Slicing to a negative length would panic; treat as "no padding".
		return plainText
	}
	return plainText[:length-unPadding]
}
|