From 079542e43166694bfb2503b1521282ce8078a634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CxHuPo=E2=80=9D?= <7513325+vrocwang@users.noreply.github.com> Date: Thu, 22 May 2025 12:06:34 +0800 Subject: [PATCH] alpha --- cmd/root.go | 88 ++++++++++++++++++++++++++++++++++++++----- config.yaml | 5 +++ database/database.go | 15 ++++---- database/init/otp.sql | 3 +- go.mod | 1 + go.sum | 2 + handlers/handler.go | 18 +++++++++ handlers/login.go | 83 +++++++++++++++++++++++++++------------- handlers/otp.go | 33 ++++++++++------ 9 files changed, 191 insertions(+), 57 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 3752544..dbf3e25 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,13 +1,18 @@ package cmd import ( + "context" + "errors" "fmt" "log" "net/http" "os" + "os/signal" "otpm/database" "otpm/handlers" "otpm/utils" + "syscall" + "time" "github.com/jmoiron/sqlx" "github.com/julienschmidt/httprouter" @@ -57,20 +62,50 @@ func initConfig() { } } -func initApp(db *sqlx.DB) { - if err := database.MigrateDB(db); err != nil { - log.Fatalf("Error migrating the database: %v", err) - } +type App struct { + db *sqlx.DB + router http.Handler + port int } -func startApp() { +func NewApp() (*App, error) { + db, err := connectDB() + if err != nil { + return nil, err + } + + if err := runMigrations(db); err != nil { + return nil, err + } + + router := setupRouter(db) + port := viper.GetInt("port") + + return &App{ + db: db, + router: router, + port: port, + }, nil +} + +func connectDB() (*sqlx.DB, error) { db, err := database.InitDB() if err != nil { - log.Fatalf("Error connecting to the database: %v", err) + return nil, fmt.Errorf("failed to connect to the database: %v", err) } - defer db.Close() - initApp(db) + return db, nil +} + +func runMigrations(db *sqlx.DB) error { + if err := database.MigrateDB(db); err != nil { + log.Fatalf("Error migrating the database: %v", err) + return fmt.Errorf("error migrating the database: %w", err) + } + return nil +} + +func setupRouter(db *sqlx.DB) http.Handler { handler := &handlers.Handler{DB: db} router := httprouter.New() @@ -78,6 +113,39 @@ func startApp() { router.POST("/set", utils.AdaptHandler(handler.UpdateOrCreateOtp)) router.GET("/get", utils.AdaptHandler(handler.GetOtp)) - log.Println("Starting server on :8080") - log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), router)) + return router +} + +func (a *App) Start() error { + server := &http.Server{Addr: fmt.Sprintf(":%d", a.port), Handler: a.router} + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("Failed to start server: %v", err) + } + }() + + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("Shutting down server...") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + return server.Shutdown(ctx) +} + +func startApp() { + app, err := NewApp() + if err != nil { + log.Fatalf("Failed to initialize application: %v", err) + } + defer func() { + if err := app.db.Close(); err != nil { + log.Printf("Failed to close database connection: %v", err) + } + }() + if err := app.Start(); err != nil { + log.Fatalf("Failed to start application: %v", err) + } } diff --git a/config.yaml b/config.yaml index f23f8fb..a61e877 100644 --- a/config.yaml +++ b/config.yaml @@ -1,8 +1,13 @@ +server: + name: "otpm" database: driver: sqlite dsn: otpm.sqlite port: 8080 +auth: + secret: "secret" + ttl: 3600 wechat: appid: "wx57d1033974eb5250" secret: "be494c2a81df685a40b9a74e1736b15d" diff --git a/database/database.go b/database/database.go index bf4ade8..be8aa03 100644 --- a/database/database.go +++ b/database/database.go @@ -2,6 +2,7 @@ package database import ( _ "embed" + "fmt" "log" _ "github.com/go-sql-driver/mysql" @@ -24,11 +25,11 @@ func InitDB() (*sqlx.DB, error) { db, err := sqlx.Open(driver, dsn) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to connect to database: %w", err) } if err := db.Ping(); err != nil { - return nil, err + return nil, fmt.Errorf("failed to ping database: %w", err) } log.Println("Connected to database!") @@ -36,14 +37,12 @@ func InitDB() (*sqlx.DB, error) { } func MigrateDB(db *sqlx.DB) error { - _, err := db.Exec(userTable) - if err != nil { - return err + if _, err := db.Exec(userTable); err != nil { + return fmt.Errorf("failed to create user migration: %w", err) } - _, err = db.Exec(otpTable) - if err != nil { - return err + if _, err := db.Exec(otpTable); err != nil { + return fmt.Errorf("failed to create otp migration: %w", err) } return nil } diff --git a/database/init/otp.sql b/database/init/otp.sql index a9b2442..db2e481 100644 --- a/database/init/otp.sql +++ b/database/init/otp.sql @@ -2,5 +2,6 @@ CREATE TABLE IF NOT EXISTS otp ( id SERIAL PRIMARY KEY, openid VARCHAR(255), num INTEGER, - token VARCHAR(255) + token VARCHAR(255), + createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); \ No newline at end of file diff --git a/go.mod b/go.mod index 57bb068..b0da37e 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/hcl v1.0.0 // indirect diff --git a/go.sum b/go.sum index 5dc30bc..1c01a15 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= diff --git a/handlers/handler.go b/handlers/handler.go index db15938..96e97c0 100644 --- a/handlers/handler.go +++ b/handlers/handler.go @@ -1,9 +1,27 @@ package handlers import ( + "encoding/json" + "net/http" + "github.com/jmoiron/sqlx" ) type Handler struct { DB *sqlx.DB } + +type Response struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +func WriteJSON(w http.ResponseWriter, data interface{}, code int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(data) +} +func WriteError(w http.ResponseWriter, message string, code int) { + WriteJSON(w, Response{Code: code, Message: message}, code) +} diff --git a/handlers/login.go b/handlers/login.go index 30812d7..ca41147 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -5,11 +5,18 @@ import ( "fmt" "io" "net/http" + "time" + "github.com/golang-jwt/jwt" "github.com/spf13/viper" ) -var code2Session = "https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code" +var wxClient = &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + MaxIdleConnsPerHost: 10, + }, +} type LoginRequest struct { Code string `json:"code"` @@ -27,8 +34,8 @@ type LoginResponse struct { func getLoginResponse(code string) (*LoginResponse, error) { appid := viper.GetString("wechat.appid") secret := viper.GetString("wechat.secret") - url := fmt.Sprintf(code2Session, appid, secret, code) - resp, err := http.Get(url) + url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", appid, secret, code) + resp, err := wxClient.Get(url) if err != nil { return nil, err } @@ -45,49 +52,71 @@ func getLoginResponse(code string) (*LoginResponse, error) { return &loginResponse, nil } -func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { +func generateJWT(openid string) (string, error) { + tokenTTL := viper.GetDuration("auth.ttl") + secret := viper.GetString("auth.secret") + if secret == "" { + return "", fmt.Errorf("auth secret not set") + } + claims := jwt.MapClaims{ + "openid": openid, + "exp": time.Now().Add(tokenTTL).Unix(), + "iat": time.Now().Unix(), + "iss": viper.GetString("server.name"), + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedToken, err := token.SignedString([]byte(secret)) + if err != nil { + return "", err + } + return signedToken, nil +} + +func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { var req LoginRequest body, err := io.ReadAll(r.Body) if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) + WriteError(w, "Failed to read request body", http.StatusBadRequest) return } if err := json.Unmarshal(body, &req); err != nil { - http.Error(w, "Failed to parse request body", http.StatusBadRequest) + WriteError(w, "Failed to parse request body", http.StatusBadRequest) return } loginResponse, err := getLoginResponse(req.Code) if err != nil { - http.Error(w, "Failed to get session key", http.StatusInternalServerError) + WriteError(w, "Failed to get login response", http.StatusInternalServerError) return } - // // 插入或更新用户的openid和sessionid - // query := ` - // INSERT INTO users (openid, sessionid) - // VALUES ($1, $2) - // ON CONFLICT (openid) DO UPDATE SET sessionid = $2 - // RETURNING id; - // ` + // 插入或更新用户的openid和session_key + query := ` + INSERT INTO users (openid, session_key) + VALUES ($1, $2) + ON CONFLICT (openid) DO UPDATE SET session_key = $2 + RETURNING id; + ` - // var ID int - // if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil { - // http.Error(w, "Failed to log in user", http.StatusInternalServerError) - // return - // } + var ID int + if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil { + WriteError(w, "Failed to log in user", http.StatusInternalServerError) + return + } + + token, err := generateJWT(loginResponse.OpenId) + if err != nil { + WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError) + return + } data := map[string]interface{}{ + "token": token, "openid": loginResponse.OpenId, "session_key": loginResponse.SessionKey, } - respData, err := json.Marshal(data) - if err != nil { - http.Error(w, "Failed to marshal response data", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(respData)) + + WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK) } diff --git a/handlers/otp.go b/handlers/otp.go index 61230ec..9895e76 100644 --- a/handlers/otp.go +++ b/handlers/otp.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "log" "net/http" ) @@ -19,43 +20,53 @@ type OTP struct { func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) { var req OtpRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request payload", http.StatusBadRequest) + WriteError(w, "Failed to parse request body", http.StatusBadRequest) return } - num := len(*req.Token) + if req.OpenID == "" { + WriteError(w, "OpenID is required", http.StatusBadRequest) + return + } + + if req.Token == nil || len(*req.Token) == 0 { + WriteError(w, "Token is required", http.StatusBadRequest) + return + } + + log.Printf("Saving OTP for user: %s token count:: %d", req.OpenID, len(*req.Token)) // 插入或更新 OTP 记录 query := ` INSERT INTO otp (openid, num, token) VALUES ($1, $2, $3) + ON CONFLICT (openid) DO UPDATE SET num = EXCLUDED.num, token = EXCLUDED.token + } ` - _, err := h.DB.Exec(query, req.OpenID, req.Token, num) + _, err := h.DB.Exec(query, req.OpenID, len(*req.Token), req.Token) if err != nil { - http.Error(w, "Failed to update or create OTP", http.StatusInternalServerError) + WriteError(w, "Failed to update or create OTP", http.StatusInternalServerError) return } - w.WriteHeader(http.StatusOK) - w.Write([]byte("OTP updated or created successfully")) + WriteJSON(w, Response{Code: 0, Message: "Success"}, http.StatusOK) } func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) { openid := r.URL.Query().Get("openid") if openid == "" { - http.Error(w, "未登录", http.StatusBadRequest) + WriteError(w, "OpenID is required", http.StatusBadRequest) return } var otp OtpRequest - err := h.DB.Get(&otp, "SELECT token, num, openid FROM otp WHERE openid=$1", openid) + err := h.DB.Get(&otp, "SELECT openid, token, num FROM otp WHERE openid=$1", openid) if err != nil { - http.Error(w, "OTP not found", http.StatusNotFound) + WriteError(w, "Failed to get OTP", http.StatusInternalServerError) return } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(otp) + WriteJSON(w, Response{Code: 0, Message: "Success", Data: otp}, http.StatusOK) }