369 lines
9.7 KiB
Go
369 lines
9.7 KiB
Go
package main
|
||
|
||
import (
	"embed"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"regexp"
	"strings"
	"time"

	"otpm/auth"
	"otpm/config"
	"otpm/db"

	"github.com/gorilla/mux"
	_ "github.com/lib/pq"
)
|
||
|
||
//go:embed init/sqlite/init.sql
|
||
//go:embed init/postgresql/init.sql
|
||
var initScripts embed.FS
|
||
|
||
// 全局数据库连接
|
||
var tokenDB *db.DB
|
||
|
||
// InitDB 初始化数据库连接
|
||
func InitDB(dbConfig config.DatabaseConfig) error {
|
||
var err error
|
||
|
||
// 创建数据库配置
|
||
dbCfg := db.Config{
|
||
Driver: dbConfig.Driver,
|
||
}
|
||
|
||
// 根据驱动类型设置配置
|
||
if dbConfig.Driver == "postgres" {
|
||
dbCfg.Host = dbConfig.Postgres.Host
|
||
dbCfg.Port = dbConfig.Postgres.Port
|
||
dbCfg.User = dbConfig.Postgres.User
|
||
dbCfg.Password = dbConfig.Postgres.Password
|
||
dbCfg.DBName = dbConfig.Postgres.DBName
|
||
dbCfg.SSLMode = dbConfig.Postgres.SSLMode
|
||
} else if dbConfig.Driver == "sqlite3" {
|
||
dbCfg.DBName = dbConfig.SQLite.Path
|
||
}
|
||
|
||
// 使用db包创建数据库连接
|
||
tokenDB, err = db.New(dbCfg)
|
||
if err != nil {
|
||
return fmt.Errorf("error opening database: %v", err)
|
||
}
|
||
|
||
// 读取并执行初始化脚本
|
||
var scriptPath string
|
||
if dbConfig.Driver == "sqlite3" {
|
||
scriptPath = "init/sqlite/init.sql"
|
||
} else if dbConfig.Driver == "postgres" {
|
||
scriptPath = "init/postgresql/init.sql"
|
||
} else {
|
||
return fmt.Errorf("unsupported database driver: %s", dbConfig.Driver)
|
||
}
|
||
|
||
script, err := initScripts.ReadFile(scriptPath)
|
||
if err != nil {
|
||
return fmt.Errorf("error reading init script: %v", err)
|
||
}
|
||
|
||
_, err = tokenDB.Exec(string(script))
|
||
if err != nil {
|
||
return fmt.Errorf("error executing init script: %v", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Server represents the API server
|
||
type Server struct {
|
||
config *config.Config
|
||
router *mux.Router
|
||
}
|
||
|
||
// WechatLoginResponse is the decoded JSON reply from WeChat's jscode2session
// endpoint. A non-zero ErrCode signals failure, with ErrMsg describing it.
type WechatLoginResponse struct {
	// OpenID is the user identifier returned by WeChat; the login handler
	// uses it as the subject of the issued tokens.
	OpenID string `json:"openid"`

	// SessionKey is the session key returned alongside the openid.
	SessionKey string `json:"session_key"`

	// UnionID is omitted from JSON output when empty.
	UnionID string `json:"unionid,omitempty"`

	// ErrCode is zero when the code exchange succeeded.
	ErrCode int `json:"errcode,omitempty"`

	// ErrMsg carries WeChat's error description when ErrCode is non-zero.
	ErrMsg string `json:"errmsg,omitempty"`
}
|
||
|
||
// NewServer creates a new instance of Server
|
||
func NewServer(cfg *config.Config) *Server {
|
||
s := &Server{
|
||
config: cfg,
|
||
router: mux.NewRouter(),
|
||
}
|
||
|
||
s.setupRoutes()
|
||
return s
|
||
}
|
||
|
||
// setupRoutes configures all the routes for the server
|
||
func (s *Server) setupRoutes() {
|
||
// 公开路由
|
||
s.router.HandleFunc("/auth/login", s.WechatLoginHandler)
|
||
s.router.HandleFunc("/auth/refresh", s.RefreshTokenHandler)
|
||
|
||
// 受保护路由(需要JWT)
|
||
authRouter := s.router.PathPrefix("").Subrouter()
|
||
authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))
|
||
authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST")
|
||
authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST")
|
||
authRouter.HandleFunc("/otp/delete", DeleteTokenHandler).Methods("POST")
|
||
authRouter.HandleFunc("/otp/clear_all", ClearAllHandler).Methods("POST")
|
||
|
||
// 添加CORS中间件
|
||
s.router.Use(s.corsMiddleware)
|
||
}
|
||
|
||
// corsMiddleware handles CORS
|
||
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
origin := r.Header.Get("Origin")
|
||
referer := r.Header.Get("Referer")
|
||
userAgent := r.Header.Get("User-Agent")
|
||
|
||
// 检查是否是微信小程序请求
|
||
isWechatMiniProgram := false
|
||
if userAgent != "" && (regexp.MustCompile(`MicroMessenger`).MatchString(userAgent) ||
|
||
regexp.MustCompile(`miniProgram`).MatchString(userAgent)) {
|
||
isWechatMiniProgram = true
|
||
}
|
||
|
||
// 从Referer中检查是否是微信小程序请求
|
||
if referer != "" && regexp.MustCompile(`^https://servicewechat\.com/`).MatchString(referer) {
|
||
isWechatMiniProgram = true
|
||
}
|
||
|
||
// 检查Origin是否允许
|
||
allowed := isWechatMiniProgram // 如果是微信小程序请求,默认允许
|
||
if !allowed && origin != "" {
|
||
for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
|
||
if origin == allowedOrigin || allowedOrigin == "*" {
|
||
allowed = true
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if allowed {
|
||
// 如果是微信小程序请求且没有Origin头,使用通配符
|
||
if isWechatMiniProgram && origin == "" {
|
||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
} else {
|
||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||
}
|
||
|
||
w.Header().Set("Access-Control-Allow-Methods",
|
||
joinStrings(s.config.CORS.AllowedMethods))
|
||
w.Header().Set("Access-Control-Allow-Headers",
|
||
joinStrings(s.config.CORS.AllowedHeaders))
|
||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||
}
|
||
|
||
if r.Method == "OPTIONS" {
|
||
w.WriteHeader(http.StatusOK)
|
||
return
|
||
}
|
||
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
|
||
// joinStrings joins the elements of slice with ", " (comma plus space), the
// form used for the Access-Control-Allow-* header values. A nil or empty slice
// yields "".
func joinStrings(slice []string) string {
	// strings.Join sizes the result once instead of reallocating on every
	// concatenation.
	return strings.Join(slice, ", ")
}
|
||
|
||
func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 读取请求体
|
||
body, err := io.ReadAll(r.Body)
|
||
if err != nil {
|
||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
var req struct {
|
||
Code string `json:"code"`
|
||
}
|
||
if err := json.Unmarshal(body, &req); err != nil {
|
||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 向微信服务器请求session_key
|
||
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
|
||
s.config.Wechat.AppID, s.config.Wechat.AppSecret, req.Code)
|
||
|
||
resp, err := http.Get(url)
|
||
if err != nil {
|
||
http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable)
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err = io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
http.Error(w, "Wechat service error", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
var wechatResp WechatLoginResponse
|
||
if err := json.Unmarshal(body, &wechatResp); err != nil {
|
||
http.Error(w, "Wechat response parse error", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
if wechatResp.ErrCode != 0 {
|
||
http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 生成访问令牌和刷新令牌
|
||
accessToken, refreshToken, err := s.generateSessionTokens(wechatResp.OpenID)
|
||
if err != nil {
|
||
http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 返回响应
|
||
response := map[string]interface{}{
|
||
"access_token": accessToken,
|
||
"refresh_token": refreshToken,
|
||
"openid": wechatResp.OpenID,
|
||
}
|
||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||
}
|
||
}
|
||
|
||
func (s *Server) generateSessionTokens(openid string) (accessToken string, refreshToken string, err error) {
|
||
accessToken, err = auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
|
||
if err != nil {
|
||
return "", "", fmt.Errorf("failed to generate access token: %v", err)
|
||
}
|
||
|
||
refreshToken, err = auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.RefreshTokenExpiry)
|
||
if err != nil {
|
||
return "", "", fmt.Errorf("failed to generate refresh token: %v", err)
|
||
}
|
||
|
||
return accessToken, refreshToken, nil
|
||
}
|
||
|
||
func (s *Server) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 读取请求体
|
||
body, err := io.ReadAll(r.Body)
|
||
if err != nil {
|
||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
var req struct {
|
||
RefreshToken string `json:"refresh_token"`
|
||
}
|
||
if err := json.Unmarshal(body, &req); err != nil {
|
||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 验证刷新令牌
|
||
claims, err := auth.ValidateToken(req.RefreshToken, s.config.Security.JWTSigningKey)
|
||
if err != nil {
|
||
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 从刷新令牌中获取 openid
|
||
openid := claims.UserID
|
||
if openid == "" {
|
||
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 生成新的访问令牌和刷新令牌
|
||
accessToken, refreshToken, err := s.generateSessionTokens(openid)
|
||
if err != nil {
|
||
http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 返回响应
|
||
response := map[string]interface{}{
|
||
"access_token": accessToken,
|
||
"refresh_token": refreshToken,
|
||
"openid": openid,
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||
}
|
||
}
|
||
|
||
// Start starts the HTTP server
|
||
func (s *Server) Start() error {
|
||
addr := fmt.Sprintf(":%d", s.config.Server.Port)
|
||
log.Printf("Starting server on %s", addr)
|
||
|
||
srv := &http.Server{
|
||
Handler: s.router,
|
||
Addr: addr,
|
||
WriteTimeout: s.config.Server.Timeout,
|
||
ReadTimeout: s.config.Server.Timeout,
|
||
}
|
||
|
||
return srv.ListenAndServe()
|
||
}
|
||
|
||
func main() {
|
||
// 定义命令行参数
|
||
configPath := flag.String("config", "config", "Path to configuration file (without extension)")
|
||
flag.Parse()
|
||
|
||
// 加载配置
|
||
cfg, err := config.LoadConfig(*configPath)
|
||
if err != nil {
|
||
log.Fatalf("Failed to load config: %v", err)
|
||
}
|
||
|
||
// 初始化数据库连接
|
||
log.Println("Initializing database connection...")
|
||
if err := InitDB(cfg.Database); err != nil {
|
||
log.Fatalf("Failed to initialize database: %v", err)
|
||
}
|
||
log.Println("Database connection established successfully")
|
||
|
||
// 初始化API
|
||
log.Println("Initializing API...")
|
||
if err := InitAPI(cfg); err != nil {
|
||
log.Fatalf("Failed to initialize API: %v", err)
|
||
}
|
||
log.Println("API initialized successfully")
|
||
|
||
// 创建并启动服务器
|
||
server := NewServer(cfg)
|
||
if err := server.Start(); err != nil {
|
||
log.Fatalf("Server failed to start: %v", err)
|
||
}
|
||
}
|