otpm/api_server.go
2025-06-09 11:20:07 +08:00

202 lines
4.8 KiB
Go

package main
import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"auth"
	"config"

	"github.com/gorilla/mux"
)
// Server bundles the loaded application configuration with the HTTP router
// that dispatches incoming requests to handlers.
type Server struct {
	config *config.Config // settings passed to NewServer
	router *mux.Router    // root router; populated by setupRoutes
}
// WechatLoginResponse is the decoded JSON body of WeChat's jscode2session
// reply. WechatLoginHandler treats a non-zero ErrCode as a failed login and
// forwards ErrMsg to the client in that case.
type WechatLoginResponse struct {
	// OpenID is the user identifier returned by WeChat; the handler passes it
	// to generateSessionToken and echoes it back to the client.
	OpenID string `json:"openid"`
	// SessionKey is the session key returned by WeChat (not used in this file).
	SessionKey string `json:"session_key"`
	// UnionID is optional and omitted from JSON output when empty.
	UnionID string `json:"unionid,omitempty"`
	// ErrCode is non-zero when WeChat rejected the code exchange.
	ErrCode int `json:"errcode,omitempty"`
	// ErrMsg is WeChat's description of the failure when ErrCode is non-zero.
	ErrMsg string `json:"errmsg,omitempty"`
}
// NewServer returns a Server that holds cfg and a fresh gorilla/mux router
// with every route already registered (see setupRoutes).
func NewServer(cfg *config.Config) *Server {
	srv := new(Server)
	srv.config = cfg
	srv.router = mux.NewRouter()
	srv.setupRoutes()
	return srv
}
// setupRoutes registers all routes on s.router and installs the CORS
// middleware.
//
// Public:
//
//	/auth/login   -> s.WechatLoginHandler (the handler rejects non-POST itself)
//
// Protected (wrapped by auth.NewAuthMiddleware with the JWT signing key):
//
//	POST /otp/save    -> SaveHandler
//	POST /otp/recover -> RecoverHandler
//
// The protected routes also accept OPTIONS. gorilla/mux only runs middleware
// once a route has matched, so with POST alone a CORS preflight would get a
// bare 405 from the router and corsMiddleware would never add its headers.
// corsMiddleware sits on the root router and so wraps the subrouter's auth
// middleware; it answers OPTIONS before the JWT check runs.
func (s *Server) setupRoutes() {
	// Public routes.
	s.router.HandleFunc("/auth/login", s.WechatLoginHandler)

	// Protected routes (JWT required). The empty prefix does not restrict
	// paths; the subrouter only groups routes under the auth middleware.
	authRouter := s.router.PathPrefix("").Subrouter()
	authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))
	authRouter.HandleFunc("/otp/save", SaveHandler).Methods(http.MethodPost, http.MethodOptions)
	authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods(http.MethodPost, http.MethodOptions)

	// CORS middleware for every matched route.
	s.router.Use(s.corsMiddleware)
}
// corsMiddleware wraps next with CORS handling.
//
// When the request's Origin header exactly matches an entry in
// s.config.CORS.AllowedOrigins, that origin is echoed back in
// Access-Control-Allow-Origin along with the configured allowed methods and
// headers. Requests without an Origin header are not CORS requests and get
// no CORS headers.
//
// Every OPTIONS request is answered with 200 without calling next, so
// preflights never reach downstream handlers or middleware.
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")
		headers := w.Header()

		// The allowed origin is echoed from the request, so the response
		// varies by Origin. Without this, a shared cache could serve one
		// origin's CORS headers to another.
		headers.Add("Vary", "Origin")

		allowed := false
		if origin != "" {
			for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
				if origin == allowedOrigin {
					allowed = true
					break
				}
			}
		}

		if allowed {
			headers.Set("Access-Control-Allow-Origin", origin)
			headers.Set("Access-Control-Allow-Methods",
				joinStrings(s.config.CORS.AllowedMethods))
			headers.Set("Access-Control-Allow-Headers",
				joinStrings(s.config.CORS.AllowedHeaders))
		}

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// joinStrings joins slice with ", " between elements, e.g.
// []string{"GET", "POST"} -> "GET, POST". A nil or empty slice yields "".
func joinStrings(slice []string) string {
	return strings.Join(slice, ", ")
}
// WechatLoginHandler exchanges a WeChat mini-program login code for a session
// token.
//
// It accepts POST requests with a JSON body {"code": "..."}. The code is sent
// to WeChat's jscode2session endpoint with the configured AppID and
// AppSecret. On success the handler responds with JSON
// {"token": <generateSessionToken(openid)>, "openid": <openid>}.
//
// Error responses are plain text written with http.Error:
//   - 405 for non-POST methods
//   - 400 for an unreadable or oversized body, invalid JSON, or a missing code
//   - 503 when WeChat cannot be reached
//   - 502 when WeChat answers with a non-200 status or without an openid
//   - 401 with WeChat's errmsg when WeChat reports a non-zero errcode
//   - 500 when WeChat's reply cannot be read or parsed, or token generation fails
func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
	const (
		// The login body only carries a short code; cap it to avoid
		// unbounded reads from untrusted clients.
		maxLoginBodyBytes = 64 << 10
		// Upper bound on the whole WeChat round trip (connect, send, read).
		wechatTimeout    = 10 * time.Second
		wechatSessionURL = "https://api.weixin.qq.com/sns/jscode2session"
	)

	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Read the request body (bounded).
	body, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, maxLoginBodyBytes))
	if err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}
	var req struct {
		Code string `json:"code"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		http.Error(w, "Invalid request", http.StatusBadRequest)
		return
	}
	if req.Code == "" {
		http.Error(w, "Missing code", http.StatusBadRequest)
		return
	}

	// Ask WeChat for the session. The values are query-escaped so that a
	// crafted code (e.g. "x&appid=...") cannot inject extra parameters.
	query := url.Values{}
	query.Set("appid", s.config.Wechat.AppID)
	query.Set("secret", s.config.Wechat.AppSecret)
	query.Set("js_code", req.Code)
	query.Set("grant_type", "authorization_code")

	// Tie the upstream call to the client's request so it is abandoned if the
	// client goes away. The client timeout bounds it regardless.
	wechatReq, err := http.NewRequestWithContext(r.Context(), http.MethodGet,
		wechatSessionURL+"?"+query.Encode(), nil)
	if err != nil {
		http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable)
		return
	}
	client := &http.Client{Timeout: wechatTimeout}
	resp, err := client.Do(wechatReq)
	if err != nil {
		http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		http.Error(w, "Wechat service error", http.StatusBadGateway)
		return
	}
	body, err = ioutil.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, "Wechat service error", http.StatusInternalServerError)
		return
	}
	var wechatResp WechatLoginResponse
	if err := json.Unmarshal(body, &wechatResp); err != nil {
		http.Error(w, "Wechat response parse error", http.StatusInternalServerError)
		return
	}
	if wechatResp.ErrCode != 0 {
		http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized)
		return
	}
	if wechatResp.OpenID == "" {
		// Never mint a token for an empty identity.
		http.Error(w, "Wechat response missing openid", http.StatusBadGateway)
		return
	}

	// Issue the JWT.
	token, err := s.generateSessionToken(wechatResp.OpenID)
	if err != nil {
		http.Error(w, "Failed to generate token", http.StatusInternalServerError)
		return
	}

	// Send the response.
	w.Header().Set("Content-Type", "application/json")
	response := map[string]string{
		"token":  token,
		"openid": wechatResp.OpenID,
	}
	if err := json.NewEncoder(w).Encode(response); err != nil {
		// Encoding a map of strings cannot fail, so this is a write error and
		// the 200 status has already gone out; http.Error cannot change it.
		log.Printf("WechatLoginHandler: writing response: %v", err)
	}
}
// generateSessionToken returns the session token for openid, delegating to
// auth.GenerateToken with the JWT signing key and token expiry taken from
// the server's security configuration.
func (s *Server) generateSessionToken(openid string) (string, error) {
	signingKey, expiry := s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry
	return auth.GenerateToken(openid, signingKey, expiry)
}
// Start serves s.router on the configured port with an empty host, so it
// listens on all interfaces. s.config.Server.Timeout is used as both the
// read and the write timeout. It blocks until ListenAndServe returns and
// hands back that error.
func (s *Server) Start() error {
	httpServer := &http.Server{
		Addr:         fmt.Sprintf(":%d", s.config.Server.Port),
		Handler:      s.router,
		ReadTimeout:  s.config.Server.Timeout,
		WriteTimeout: s.config.Server.Timeout,
	}
	log.Printf("Starting server on %s", httpServer.Addr)
	return httpServer.ListenAndServe()
}
// main loads the configuration, initializes the database, then builds and
// runs the API server. Any failure along the way is fatal.
func main() {
	// Load configuration.
	cfg, err := config.LoadConfig("config")
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Connect to the database before accepting any traffic.
	log.Println("Initializing database connection...")
	if dbErr := InitDB(cfg.Database); dbErr != nil {
		log.Fatalf("Failed to initialize database: %v", dbErr)
	}
	log.Println("Database connection established successfully")

	// Create the server and block serving requests.
	if startErr := NewServer(cfg).Start(); startErr != nil {
		log.Fatalf("Server failed to start: %v", startErr)
	}
}