package main import ( "embed" "encoding/json" "flag" "fmt" "io" "log" "net/http" "regexp" "otpm/auth" "otpm/config" "otpm/db" "github.com/gorilla/mux" _ "github.com/lib/pq" ) //go:embed init/sqlite/init.sql //go:embed init/postgresql/init.sql var initScripts embed.FS // 全局数据库连接 var tokenDB *db.DB // InitDB 初始化数据库连接 func InitDB(dbConfig config.DatabaseConfig) error { var err error // 创建数据库配置 dbCfg := db.Config{ Driver: dbConfig.Driver, } // 根据驱动类型设置配置 if dbConfig.Driver == "postgres" { dbCfg.Host = dbConfig.Postgres.Host dbCfg.Port = dbConfig.Postgres.Port dbCfg.User = dbConfig.Postgres.User dbCfg.Password = dbConfig.Postgres.Password dbCfg.DBName = dbConfig.Postgres.DBName dbCfg.SSLMode = dbConfig.Postgres.SSLMode } else if dbConfig.Driver == "sqlite3" { dbCfg.DBName = dbConfig.SQLite.Path } // 使用db包创建数据库连接 tokenDB, err = db.New(dbCfg) if err != nil { return fmt.Errorf("error opening database: %v", err) } // 读取并执行初始化脚本 var scriptPath string if dbConfig.Driver == "sqlite3" { scriptPath = "init/sqlite/init.sql" } else if dbConfig.Driver == "postgres" { scriptPath = "init/postgresql/init.sql" } else { return fmt.Errorf("unsupported database driver: %s", dbConfig.Driver) } script, err := initScripts.ReadFile(scriptPath) if err != nil { return fmt.Errorf("error reading init script: %v", err) } _, err = tokenDB.Exec(string(script)) if err != nil { return fmt.Errorf("error executing init script: %v", err) } return nil } // Server represents the API server type Server struct { config *config.Config router *mux.Router } type WechatLoginResponse struct { OpenID string `json:"openid"` SessionKey string `json:"session_key"` UnionID string `json:"unionid,omitempty"` ErrCode int `json:"errcode,omitempty"` ErrMsg string `json:"errmsg,omitempty"` } // NewServer creates a new instance of Server func NewServer(cfg *config.Config) *Server { s := &Server{ config: cfg, router: mux.NewRouter(), } s.setupRoutes() return s } // setupRoutes configures all the routes for the server func (s *Server) setupRoutes() { // 公开路由 s.router.HandleFunc("/auth/login", s.WechatLoginHandler) s.router.HandleFunc("/auth/refresh", s.RefreshTokenHandler) // 受保护路由(需要JWT) authRouter := s.router.PathPrefix("").Subrouter() authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey)) authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST") authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST") authRouter.HandleFunc("/otp/delete", DeleteTokenHandler).Methods("POST") authRouter.HandleFunc("/otp/clear_all", ClearAllHandler).Methods("POST") // 添加CORS中间件 s.router.Use(s.corsMiddleware) } // corsMiddleware handles CORS func (s *Server) corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") referer := r.Header.Get("Referer") userAgent := r.Header.Get("User-Agent") // 检查是否是微信小程序请求 isWechatMiniProgram := false if userAgent != "" && (regexp.MustCompile(`MicroMessenger`).MatchString(userAgent) || regexp.MustCompile(`miniProgram`).MatchString(userAgent)) { isWechatMiniProgram = true } // 从Referer中检查是否是微信小程序请求 if referer != "" && regexp.MustCompile(`^https://servicewechat\.com/`).MatchString(referer) { isWechatMiniProgram = true } // 检查Origin是否允许 allowed := isWechatMiniProgram // 如果是微信小程序请求,默认允许 if !allowed && origin != "" { for _, allowedOrigin := range s.config.CORS.AllowedOrigins { if origin == allowedOrigin || allowedOrigin == "*" { allowed = true break } } } if allowed { // 如果是微信小程序请求且没有Origin头,使用通配符 if isWechatMiniProgram && origin == "" { w.Header().Set("Access-Control-Allow-Origin", "*") } else { w.Header().Set("Access-Control-Allow-Origin", origin) } w.Header().Set("Access-Control-Allow-Methods", joinStrings(s.config.CORS.AllowedMethods)) w.Header().Set("Access-Control-Allow-Headers", joinStrings(s.config.CORS.AllowedHeaders)) w.Header().Set("Access-Control-Allow-Credentials", "true") } if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) } // joinStrings joins string slice with commas func joinStrings(slice []string) string { result := "" for i, s := range slice { if i > 0 { result += ", " } result += s } return result } func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 读取请求体 body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } var req struct { Code string `json:"code"` } if err := json.Unmarshal(body, &req); err != nil { http.Error(w, "Invalid request", http.StatusBadRequest) return } // 向微信服务器请求session_key url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", s.config.Wechat.AppID, s.config.Wechat.AppSecret, req.Code) resp, err := http.Get(url) if err != nil { http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable) return } defer resp.Body.Close() body, err = io.ReadAll(resp.Body) if err != nil { http.Error(w, "Wechat service error", http.StatusInternalServerError) return } var wechatResp WechatLoginResponse if err := json.Unmarshal(body, &wechatResp); err != nil { http.Error(w, "Wechat response parse error", http.StatusInternalServerError) return } if wechatResp.ErrCode != 0 { http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized) return } // 生成访问令牌和刷新令牌 accessToken, refreshToken, err := s.generateSessionTokens(wechatResp.OpenID) if err != nil { http.Error(w, "Failed to generate tokens", http.StatusInternalServerError) return } // 返回响应 response := map[string]interface{}{ "access_token": accessToken, "refresh_token": refreshToken, "openid": wechatResp.OpenID, } if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) } } func (s *Server) generateSessionTokens(openid string) (accessToken string, refreshToken string, err error) { accessToken, err = auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry) if err != nil { return "", "", fmt.Errorf("failed to generate access token: %v", err) } refreshToken, err = auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.RefreshTokenExpiry) if err != nil { return "", "", fmt.Errorf("failed to generate refresh token: %v", err) } return accessToken, refreshToken, nil } func (s *Server) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 读取请求体 body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } var req struct { RefreshToken string `json:"refresh_token"` } if err := json.Unmarshal(body, &req); err != nil { http.Error(w, "Invalid request", http.StatusBadRequest) return } // 验证刷新令牌 claims, err := auth.ValidateToken(req.RefreshToken, s.config.Security.JWTSigningKey) if err != nil { http.Error(w, "Invalid refresh token", http.StatusUnauthorized) return } // 从刷新令牌中获取 openid openid := claims.UserID if openid == "" { http.Error(w, "Invalid token claims", http.StatusUnauthorized) return } // 生成新的访问令牌和刷新令牌 accessToken, refreshToken, err := s.generateSessionTokens(openid) if err != nil { http.Error(w, "Failed to generate tokens", http.StatusInternalServerError) return } // 返回响应 response := map[string]interface{}{ "access_token": accessToken, "refresh_token": refreshToken, "openid": openid, } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) } } // Start starts the HTTP server func (s *Server) Start() error { addr := fmt.Sprintf(":%d", s.config.Server.Port) log.Printf("Starting server on %s", addr) srv := &http.Server{ Handler: s.router, Addr: addr, WriteTimeout: s.config.Server.Timeout, ReadTimeout: s.config.Server.Timeout, } return srv.ListenAndServe() } func main() { // 定义命令行参数 configPath := flag.String("config", "config", "Path to configuration file (without extension)") flag.Parse() // 加载配置 cfg, err := config.LoadConfig(*configPath) if err != nil { log.Fatalf("Failed to load config: %v", err) } // 初始化数据库连接 log.Println("Initializing database connection...") if err := InitDB(cfg.Database); err != nil { log.Fatalf("Failed to initialize database: %v", err) } log.Println("Database connection established successfully") // 初始化API log.Println("Initializing API...") if err := InitAPI(cfg); err != nil { log.Fatalf("Failed to initialize API: %v", err) } log.Println("API initialized successfully") // 创建并启动服务器 server := NewServer(cfg) if err := server.Start(); err != nil { log.Fatalf("Server failed to start: %v", err) } }