otpm/api_server.go
“xHuPo” 10ebc59ffb fix api
2025-06-17 14:46:09 +08:00

369 lines
9.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
	"embed"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"regexp"
	"strings"

	"otpm/auth"
	"otpm/config"
	"otpm/db"

	"github.com/gorilla/mux"
	_ "github.com/lib/pq"
)
// initScripts holds the embedded schema initialization scripts, one per
// supported database driver. InitDB selects one by driver name and executes
// it right after connecting.
//
//go:embed init/sqlite/init.sql
//go:embed init/postgresql/init.sql
var initScripts embed.FS

// Global database connection, assigned by InitDB.
var tokenDB *db.DB
// InitDB opens the package-level tokenDB connection described by dbConfig and
// executes the embedded initialization script for the selected driver.
//
// Supported drivers are "postgres" and "sqlite3". The driver is validated,
// and its init script loaded, before any connection is attempted. An
// unsupported driver therefore fails fast without opening a connection.
// Underlying errors are wrapped with %w, so callers can inspect them with
// errors.Is / errors.As.
//
// NOTE(review): if the init script fails to execute, the connection stays
// open in tokenDB. main in this file exits on any InitDB error, so this
// is harmless there; other callers may want to close it.
func InitDB(dbConfig config.DatabaseConfig) error {
	dbCfg := db.Config{Driver: dbConfig.Driver}

	// Resolve the driver-specific connection settings and the matching
	// embedded init script in one place.
	var scriptPath string
	switch dbConfig.Driver {
	case "postgres":
		dbCfg.Host = dbConfig.Postgres.Host
		dbCfg.Port = dbConfig.Postgres.Port
		dbCfg.User = dbConfig.Postgres.User
		dbCfg.Password = dbConfig.Postgres.Password
		dbCfg.DBName = dbConfig.Postgres.DBName
		dbCfg.SSLMode = dbConfig.Postgres.SSLMode
		scriptPath = "init/postgresql/init.sql"
	case "sqlite3":
		// For SQLite the database "name" is the file path.
		dbCfg.DBName = dbConfig.SQLite.Path
		scriptPath = "init/sqlite/init.sql"
	default:
		return fmt.Errorf("unsupported database driver: %s", dbConfig.Driver)
	}

	// Load the script before connecting, so a packaging problem never leaves
	// a dangling connection behind.
	script, err := initScripts.ReadFile(scriptPath)
	if err != nil {
		return fmt.Errorf("error reading init script: %w", err)
	}

	// Plain assignment (not :=) so the package-level tokenDB is set.
	tokenDB, err = db.New(dbCfg)
	if err != nil {
		return fmt.Errorf("error opening database: %w", err)
	}

	if _, err = tokenDB.Exec(string(script)); err != nil {
		return fmt.Errorf("error executing init script: %w", err)
	}
	return nil
}
// Server represents the API server: the loaded configuration plus the
// gorilla/mux router that NewServer populates via setupRoutes.
type Server struct {
	config *config.Config // application configuration (security, CORS, WeChat and server settings are read from it)
	router *mux.Router    // root router; Start passes it to http.Server as the Handler
}
// WechatLoginResponse is the JSON reply of WeChat's jscode2session endpoint,
// as decoded by WechatLoginHandler.
type WechatLoginResponse struct {
	OpenID     string `json:"openid"`      // user identifier; becomes the subject of the issued tokens
	SessionKey string `json:"session_key"` // not used by the handlers in this file
	UnionID    string `json:"unionid,omitempty"`
	ErrCode    int    `json:"errcode,omitempty"` // non-zero makes WechatLoginHandler answer 401
	ErrMsg     string `json:"errmsg,omitempty"`  // text sent back to the client alongside a non-zero ErrCode
}
// NewServer returns a Server bound to cfg. It has a fresh gorilla/mux router,
// and all routes and middleware are already registered on it.
func NewServer(cfg *config.Config) *Server {
	srv := new(Server)
	srv.config = cfg
	srv.router = mux.NewRouter()
	srv.setupRoutes()
	return srv
}
// setupRoutes registers every HTTP route and middleware on s.router.
//
// Public routes are /auth/login and /auth/refresh. The /otp/* routes live on
// a subrouter guarded by auth.NewAuthMiddleware (JWT). s.corsMiddleware is
// installed on the root router.
//
// gorilla/mux only runs router middleware for requests that match a route.
// With Methods("POST") alone, a browser CORS preflight (OPTIONS) to a
// protected route got a bare 405 from mux and never reached corsMiddleware,
// which is the component that answers preflights. The protected routes
// therefore also accept OPTIONS. The root router's middleware wraps the
// subrouter from the outside, so corsMiddleware answers the preflight before
// the JWT middleware or the handler would see it.
func (s *Server) setupRoutes() {
	// Public routes.
	s.router.HandleFunc("/auth/login", s.WechatLoginHandler)
	s.router.HandleFunc("/auth/refresh", s.RefreshTokenHandler)

	// Protected routes (JWT required).
	authRouter := s.router.PathPrefix("").Subrouter()
	authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))

	// OPTIONS is matched only so CORS preflights reach corsMiddleware.
	protectedMethods := []string{http.MethodPost, http.MethodOptions}
	authRouter.HandleFunc("/otp/save", SaveHandler).Methods(protectedMethods...)
	authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods(protectedMethods...)
	authRouter.HandleFunc("/otp/delete", DeleteTokenHandler).Methods(protectedMethods...)
	authRouter.HandleFunc("/otp/clear_all", ClearAllHandler).Methods(protectedMethods...)

	// CORS middleware; applied at match time, so registering it last is fine.
	s.router.Use(s.corsMiddleware)
}
// Patterns corsMiddleware uses to recognise WeChat mini-program traffic.
// They are compiled once at package initialisation instead of on every request.
var (
	corsWechatUserAgentRe = regexp.MustCompile(`MicroMessenger|miniProgram`)
	corsWechatRefererRe   = regexp.MustCompile(`^https://servicewechat\.com/`)
)

// corsMiddleware applies CORS response headers and answers preflight requests.
//
// A request is treated as coming from a WeChat mini program when either:
//   - its User-Agent contains "MicroMessenger" or "miniProgram", or
//   - its Referer starts with https://servicewechat.com/.
//
// Such requests are always allowed. Any other request is allowed only when its
// Origin appears in s.config.CORS.AllowedOrigins, or that list contains "*".
//
// For allowed requests the middleware:
//   - echoes the Origin, or sends "*" for a WeChat request that has no Origin;
//   - adds the configured methods and headers;
//   - sends Access-Control-Allow-Credentials only alongside a concrete origin,
//     because the CORS protocol forbids combining it with "*" and browsers
//     reject such responses.
//
// Vary: Origin is always added because the headers depend on Origin.
//
// Every OPTIONS request is answered with 200 and not forwarded to next,
// whether or not it was allowed.
//
// NOTE(review): User-Agent and Referer are client-controlled, and pages opened
// in WeChat's in-app browser also carry "MicroMessenger". Such a page gets its
// Origin echoed with credentials allowed. Consider checking Origin against
// the allow-list for WeChat traffic too.
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")
		isWechatMiniProgram := corsWechatUserAgentRe.MatchString(r.Header.Get("User-Agent")) ||
			corsWechatRefererRe.MatchString(r.Header.Get("Referer"))

		// WeChat requests are allowed by default; others need a listed Origin.
		allowed := isWechatMiniProgram
		if !allowed && origin != "" {
			for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
				if origin == allowedOrigin || allowedOrigin == "*" {
					allowed = true
					break
				}
			}
		}

		h := w.Header()
		// The CORS headers below depend on Origin; keep shared caches from
		// serving one origin's response to another.
		h.Add("Vary", "Origin")

		if allowed {
			if isWechatMiniProgram && origin == "" {
				// WeChat request without an Origin header: use the wildcard.
				// Credentials must not be combined with "*".
				h.Set("Access-Control-Allow-Origin", "*")
			} else {
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
			}
			h.Set("Access-Control-Allow-Methods", strings.Join(s.config.CORS.AllowedMethods, ", "))
			h.Set("Access-Control-Allow-Headers", strings.Join(s.config.CORS.AllowedHeaders, ", "))
		}

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// joinStrings joins the elements of slice with ", " (comma followed by a
// space). A nil or empty slice yields the empty string.
//
// It delegates to strings.Join, which sizes the result once instead of
// re-concatenating the accumulated string for every element.
func joinStrings(slice []string) string {
	return strings.Join(slice, ", ")
}
// WechatLoginHandler handles POST /auth/login for WeChat mini-program users.
//
// The JSON request body must carry {"code": "..."}. The code is exchanged at
// WeChat's jscode2session endpoint using the configured AppID/AppSecret. On
// success the handler responds with a JSON object holding access_token,
// refresh_token and openid.
//
// Error responses are plain text via http.Error:
//   - 405 for a non-POST method;
//   - 400 for an unreadable or oversized body, invalid JSON or an empty code;
//   - 503 when WeChat cannot be reached;
//   - 502 when WeChat answers with a non-200 HTTP status;
//   - 500 when WeChat's reply cannot be read or parsed, or token generation fails;
//   - 401 when WeChat reports an error code or returns no openid.
func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Both the client body and WeChat's reply are small JSON documents; cap
	// how much of each is read.
	const maxBodyBytes = 1 << 20

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	if err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}
	var req struct {
		Code string `json:"code"`
	}
	if err := json.Unmarshal(body, &req); err != nil || req.Code == "" {
		http.Error(w, "Invalid request", http.StatusBadRequest)
		return
	}

	// Build the upstream request with escaped query parameters. The code comes
	// from the client and must not be able to inject or override parameters
	// such as appid/secret.
	wxReq, err := http.NewRequestWithContext(r.Context(), http.MethodGet,
		"https://api.weixin.qq.com/sns/jscode2session", nil)
	if err != nil {
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	q := wxReq.URL.Query()
	q.Set("appid", s.config.Wechat.AppID)
	q.Set("secret", s.config.Wechat.AppSecret)
	q.Set("js_code", req.Code)
	q.Set("grant_type", "authorization_code")
	wxReq.URL.RawQuery = q.Encode()

	// http.Get uses a client with no timeout. Bound the call by the server's
	// own timeout, after which the response could not be written anyway, and
	// by the inbound request's context. A zero timeout means no limit, as
	// before.
	client := &http.Client{Timeout: s.config.Server.Timeout}
	resp, err := client.Do(wxReq)
	if err != nil {
		http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		http.Error(w, "Wechat service error", http.StatusBadGateway)
		return
	}
	wxBody, err := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes))
	if err != nil {
		http.Error(w, "Wechat service error", http.StatusInternalServerError)
		return
	}
	var wechatResp WechatLoginResponse
	if err := json.Unmarshal(wxBody, &wechatResp); err != nil {
		http.Error(w, "Wechat response parse error", http.StatusInternalServerError)
		return
	}
	if wechatResp.ErrCode != 0 {
		http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized)
		return
	}
	if wechatResp.OpenID == "" {
		// Without an openid there is no subject to issue tokens for.
		http.Error(w, "Wechat returned no openid", http.StatusUnauthorized)
		return
	}

	// Issue the access and refresh tokens.
	accessToken, refreshToken, err := s.generateSessionTokens(wechatResp.OpenID)
	if err != nil {
		http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	response := map[string]string{
		"access_token":  accessToken,
		"refresh_token": refreshToken,
		"openid":        wechatResp.OpenID,
	}
	if err := json.NewEncoder(w).Encode(response); err != nil {
		// The status line and maybe part of the body are already sent, so
		// http.Error can no longer change the response; just record it.
		log.Printf("WechatLoginHandler: encoding response: %v", err)
	}
}
// generateSessionTokens issues an access token and a refresh token for openid.
//
// Both are produced by auth.GenerateToken with the configured JWT signing key:
//   - the access token with Security.TokenExpiry;
//   - the refresh token with Security.RefreshTokenExpiry.
//
// On failure both strings are empty. The returned error wraps the auth error
// with %w and says which of the two tokens failed.
//
// NOTE(review): the two tokens differ only in expiry. Nothing visible here
// marks which one is the refresh token, so an access token would likely also
// pass RefreshTokenHandler's validation. Consider a token-type claim in the
// auth package.
func (s *Server) generateSessionTokens(openid string) (accessToken string, refreshToken string, err error) {
	sec := s.config.Security

	accessToken, err = auth.GenerateToken(openid, sec.JWTSigningKey, sec.TokenExpiry)
	if err != nil {
		return "", "", fmt.Errorf("failed to generate access token: %w", err)
	}

	refreshToken, err = auth.GenerateToken(openid, sec.JWTSigningKey, sec.RefreshTokenExpiry)
	if err != nil {
		return "", "", fmt.Errorf("failed to generate refresh token: %w", err)
	}

	return accessToken, refreshToken, nil
}
// RefreshTokenHandler handles POST /auth/refresh.
//
// The JSON body must carry {"refresh_token": "..."}. The token is validated
// with the configured JWT signing key. Its UserID claim (the openid) is then
// used to issue a fresh access/refresh token pair, returned as JSON with
// access_token, refresh_token and openid.
//
// Error responses are plain text via http.Error:
//   - 405 for a non-POST method;
//   - 400 for an unreadable or oversized body, or invalid JSON;
//   - 401 for an invalid token or an empty UserID claim;
//   - 500 when token generation fails.
func (s *Server) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// The body is untrusted input and a refresh request is tiny; cap it.
	const maxBodyBytes = 1 << 20
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	if err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}
	var req struct {
		RefreshToken string `json:"refresh_token"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		http.Error(w, "Invalid request", http.StatusBadRequest)
		return
	}

	// Validate the refresh token.
	claims, err := auth.ValidateToken(req.RefreshToken, s.config.Security.JWTSigningKey)
	if err != nil {
		http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
		return
	}

	// The openid is carried in the token's UserID claim.
	openid := claims.UserID
	if openid == "" {
		http.Error(w, "Invalid token claims", http.StatusUnauthorized)
		return
	}

	// Issue a new access token and refresh token.
	accessToken, refreshToken, err := s.generateSessionTokens(openid)
	if err != nil {
		http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	response := map[string]string{
		"access_token":  accessToken,
		"refresh_token": refreshToken,
		"openid":        openid,
	}
	if err := json.NewEncoder(w).Encode(response); err != nil {
		// The response has already started; http.Error would only trigger a
		// superfluous WriteHeader, so log instead.
		log.Printf("RefreshTokenHandler: encoding response: %v", err)
	}
}
// Start serves s.router on ":<Server.Port>" and blocks until the listener
// stops. Server.Timeout is used as both the read and the write timeout. The
// error from ListenAndServe is returned unchanged.
func (s *Server) Start() error {
	timeout := s.config.Server.Timeout
	httpSrv := &http.Server{
		Addr:         fmt.Sprintf(":%d", s.config.Server.Port),
		Handler:      s.router,
		ReadTimeout:  timeout,
		WriteTimeout: timeout,
	}

	log.Printf("Starting server on %s", httpSrv.Addr)
	return httpSrv.ListenAndServe()
}
// main runs the startup sequence:
//  1. load the configuration named by -config (a path without extension,
//     default "config");
//  2. initialise the database;
//  3. initialise the API layer;
//  4. serve HTTP.
//
// Any failure is logged and terminates the process.
func main() {
	configPath := flag.String("config", "config", "Path to configuration file (without extension)")
	flag.Parse()

	// exitOn logs err with format and exits the process when err is non-nil.
	exitOn := func(format string, err error) {
		if err != nil {
			log.Fatalf(format, err)
		}
	}

	cfg, err := config.LoadConfig(*configPath)
	exitOn("Failed to load config: %v", err)

	log.Println("Initializing database connection...")
	exitOn("Failed to initialize database: %v", InitDB(cfg.Database))
	log.Println("Database connection established successfully")

	log.Println("Initializing API...")
	exitOn("Failed to initialize API: %v", InitAPI(cfg))
	log.Println("API initialized successfully")

	exitOn("Server failed to start: %v", NewServer(cfg).Start())
}