alpha
This commit is contained in:
parent
25c5f530b8
commit
2d3698716e
8 changed files with 109 additions and 19 deletions
|
@ -110,8 +110,9 @@ func setupRouter(db *sqlx.DB) http.Handler {
|
||||||
|
|
||||||
router := httprouter.New()
|
router := httprouter.New()
|
||||||
router.POST("/login", utils.AdaptHandler(handler.Login))
|
router.POST("/login", utils.AdaptHandler(handler.Login))
|
||||||
router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp))
|
router.POST("/refresh", utils.AdaptHandler(utils.AuthMiddleware(handler.RefreshToken)))
|
||||||
router.GET("/get", utils.AdaptHandler(handler.GetOtp))
|
router.POST("/set", utils.AdaptHandler(utils.AuthMiddleware(handler.UpdateOrCreateOtp)))
|
||||||
|
router.GET("/get", utils.AdaptHandler(utils.AuthMiddleware(handler.GetOtp)))
|
||||||
|
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ port: 8080
|
||||||
auth:
|
auth:
|
||||||
secret: "secret"
|
secret: "secret"
|
||||||
ttl: 3600
|
ttl: 3600
|
||||||
|
|
||||||
wechat:
|
wechat:
|
||||||
appid: "wx57d1033974eb5250"
|
appid: "wx57d1033974eb5250"
|
||||||
secret: "be494c2a81df685a40b9a74e1736b15d"
|
secret: "be494c2a81df685a40b9a74e1736b15d"
|
||||||
|
|
BIN
database.zip
Normal file
BIN
database.zip
Normal file
Binary file not shown.
|
@ -1,5 +1,6 @@
|
||||||
CREATE TABLE IF NOT EXISTS otp (
|
CREATE TABLE IF NOT EXISTS otp (
|
||||||
openid VARCHAR(255) PRIMARY KEY,
|
id SERIAL PRIMARY KEY,
|
||||||
|
openid VARCHAR(255) UNIQUE NOT NULL,
|
||||||
token VARCHAR(255),
|
token VARCHAR(255),
|
||||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
|
@ -2,4 +2,5 @@ CREATE TABLE IF NOT EXISTS users (
|
||||||
id SERIAL PRIMARY KEY,
|
id SERIAL PRIMARY KEY,
|
||||||
openid VARCHAR(255) UNIQUE NOT NULL,
|
openid VARCHAR(255) UNIQUE NOT NULL,
|
||||||
session_key VARCHAR(255) UNIQUE NOT NULL
|
session_key VARCHAR(255) UNIQUE NOT NULL
|
||||||
);
|
);
|
||||||
|
CREATE UNIQUE INDEX idx_users_openid ON users(openid);
|
|
@ -11,13 +11,6 @@ import (
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
var wxClient = &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: &http.Transport{
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
type LoginRequest struct {
|
type LoginRequest struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
}
|
}
|
||||||
|
@ -31,6 +24,13 @@ type LoginResponse struct {
|
||||||
ErrMsg string `json:"errmsg"`
|
ErrMsg string `json:"errmsg"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var wxClient = &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func getLoginResponse(code string) (*LoginResponse, error) {
|
func getLoginResponse(code string) (*LoginResponse, error) {
|
||||||
appid := viper.GetString("wechat.appid")
|
appid := viper.GetString("wechat.appid")
|
||||||
secret := viper.GetString("wechat.secret")
|
secret := viper.GetString("wechat.secret")
|
||||||
|
@ -47,16 +47,27 @@ func getLoginResponse(code string) (*LoginResponse, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if loginResponse.ErrCode != 0 {
|
if loginResponse.ErrCode != 0 {
|
||||||
return nil, fmt.Errorf("code2session error: %s", loginResponse.ErrMsg)
|
switch loginResponse.ErrCode {
|
||||||
|
case 40029:
|
||||||
|
return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg)
|
||||||
|
case 45011:
|
||||||
|
return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &loginResponse, nil
|
return &loginResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateJWT(openid string) (string, error) {
|
func generateJWT(openid string) (string, error) {
|
||||||
tokenTTL := viper.GetDuration("auth.ttl")
|
tokenTTL := viper.GetDuration("auth.ttl")
|
||||||
|
if tokenTTL <= 0 {
|
||||||
|
tokenTTL = 24 * time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
secret := viper.GetString("auth.secret")
|
secret := viper.GetString("auth.secret")
|
||||||
if secret == "" {
|
if secret == "" {
|
||||||
return "", fmt.Errorf("auth secret not set")
|
secret = "default_auth_secret_otpm"
|
||||||
}
|
}
|
||||||
|
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
|
@ -81,6 +92,8 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||||
WriteError(w, "Failed to read request body", http.StatusBadRequest)
|
WriteError(w, "Failed to read request body", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
|
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
@ -88,7 +101,14 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
loginResponse, err := getLoginResponse(req.Code)
|
loginResponse, err := getLoginResponse(req.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
|
switch {
|
||||||
|
case err.Error() == "invalid code":
|
||||||
|
WriteError(w, "Invalid code", http.StatusUnauthorized)
|
||||||
|
case err.Error() == "api limit exceeded":
|
||||||
|
WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests)
|
||||||
|
default:
|
||||||
|
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,10 +133,26 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
data := map[string]interface{}{
|
data := map[string]interface{}{
|
||||||
"token": token,
|
"t": token,
|
||||||
"openid": loginResponse.OpenId,
|
"openid": loginResponse.OpenId,
|
||||||
"session_key": loginResponse.SessionKey,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
|
WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userid := r.Context().Value("openid").(string)
|
||||||
|
|
||||||
|
token, err := generateJWT(userid)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteJSON(w, Response{
|
||||||
|
Code: 0,
|
||||||
|
Message: "Token refreshed successfully",
|
||||||
|
Data: map[string]string{
|
||||||
|
"token": token,
|
||||||
|
},
|
||||||
|
}, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
|
|
||||||
type OtpRequest struct {
|
type OtpRequest struct {
|
||||||
OpenID string `json:"openid"`
|
OpenID string `json:"openid"`
|
||||||
Num int `json:"num"`
|
|
||||||
Token *[]OTP `json:"token"`
|
Token *[]OTP `json:"token"`
|
||||||
}
|
}
|
||||||
type OTP struct {
|
type OTP struct {
|
||||||
|
@ -61,7 +60,7 @@ func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
var otp OtpRequest
|
var otp OtpRequest
|
||||||
|
|
||||||
err := h.DB.Get(&otp, "SELECT openid, token FROM otp WHERE openid=$1", openid)
|
err := h.DB.Get(&otp, "SELECT token FROM otp WHERE openid=$1", openid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, "Failed to get OTP", http.StatusInternalServerError)
|
WriteError(w, "Failed to get OTP", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,20 +1,65 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AdaptHandler函数将一个http.Handler转换为httprouter.Handle
|
||||||
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
|
func AdaptHandler(h func(http.ResponseWriter, *http.Request)) httprouter.Handle {
|
||||||
|
// 返回一个httprouter.Handle函数,该函数接受http.ResponseWriter和*http.Request作为参数
|
||||||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||||
|
// 调用传入的http.Handler函数,将http.ResponseWriter和*http.Request作为参数传递
|
||||||
h(w, r)
|
h(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
http.Error(w, `{"error": "missing authorization token"}`, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
secret := viper.GetString("auth.secret")
|
||||||
|
|
||||||
|
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method")
|
||||||
|
}
|
||||||
|
return []byte(secret), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil || !token.Valid {
|
||||||
|
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, `{"error": "invalid claims"}`, http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
// 将 openid 存入上下文
|
||||||
|
ctx := context.WithValue(r.Context(), contextKey("openid"), claims["openid"])
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AesDecrypt 函数用于AES解密
|
||||||
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
|
func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
|
||||||
//Base64解码
|
//Base64解码
|
||||||
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
|
keyBytes, err := base64.StdEncoding.DecodeString(sessionKey)
|
||||||
|
@ -44,11 +89,17 @@ func AesDecrypt(encryptedData, sessionKey, iv string) ([]byte, error) {
|
||||||
return origData, nil
|
return origData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PKCS7UnPadding 函数用于去除PKCS7填充的密文
|
||||||
func PKCS7UnPadding(plantText []byte) []byte {
|
func PKCS7UnPadding(plantText []byte) []byte {
|
||||||
|
// 获取密文的长度
|
||||||
length := len(plantText)
|
length := len(plantText)
|
||||||
|
// 如果密文长度大于0
|
||||||
if length > 0 {
|
if length > 0 {
|
||||||
|
// 获取最后一个字节的值,即填充的位数
|
||||||
unPadding := int(plantText[length-1])
|
unPadding := int(plantText[length-1])
|
||||||
|
// 返回去除填充后的密文
|
||||||
return plantText[:(length - unPadding)]
|
return plantText[:(length - unPadding)]
|
||||||
}
|
}
|
||||||
|
// 如果密文长度为0,则返回原密文
|
||||||
return plantText
|
return plantText
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue