package main import ( "encoding/json" "fmt" "io/ioutil" "log" "net/http" "auth" "config" "github.com/gorilla/mux" ) // Server represents the API server type Server struct { config *config.Config router *mux.Router } type WechatLoginResponse struct { OpenID string `json:"openid"` SessionKey string `json:"session_key"` UnionID string `json:"unionid,omitempty"` ErrCode int `json:"errcode,omitempty"` ErrMsg string `json:"errmsg,omitempty"` } // NewServer creates a new instance of Server func NewServer(cfg *config.Config) *Server { s := &Server{ config: cfg, router: mux.NewRouter(), } s.setupRoutes() return s } // setupRoutes configures all the routes for the server func (s *Server) setupRoutes() { // 公开路由 s.router.HandleFunc("/auth/login", s.WechatLoginHandler) // 受保护路由(需要JWT) authRouter := s.router.PathPrefix("").Subrouter() authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey)) authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST") authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST") // 添加CORS中间件 s.router.Use(s.corsMiddleware) } // corsMiddleware handles CORS func (s *Server) corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") // Check if the origin is allowed allowed := false for _, allowedOrigin := range s.config.CORS.AllowedOrigins { if origin == allowedOrigin { allowed = true break } } if allowed { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", joinStrings(s.config.CORS.AllowedMethods)) w.Header().Set("Access-Control-Allow-Headers", joinStrings(s.config.CORS.AllowedHeaders)) } if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) } // joinStrings joins string slice with commas func joinStrings(slice []string) string { result := "" for i, s := range slice { if i > 0 { result += ", " } result += s } return result } func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 读取请求体 body, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } var req struct { Code string `json:"code"` } if err := json.Unmarshal(body, &req); err != nil { http.Error(w, "Invalid request", http.StatusBadRequest) return } // 向微信服务器请求session_key url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", s.config.Wechat.AppID, s.config.Wechat.AppSecret, req.Code) resp, err := http.Get(url) if err != nil { http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable) return } defer resp.Body.Close() body, err = ioutil.ReadAll(resp.Body) if err != nil { http.Error(w, "Wechat service error", http.StatusInternalServerError) return } var wechatResp WechatLoginResponse if err := json.Unmarshal(body, &wechatResp); err != nil { http.Error(w, "Wechat response parse error", http.StatusInternalServerError) return } if wechatResp.ErrCode != 0 { http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized) return } // 生成JWT token token, err := s.generateSessionToken(wechatResp.OpenID) if err != nil { http.Error(w, "Failed to generate token", http.StatusInternalServerError) return } // 返回响应 response := map[string]interface{}{ "token": token, "openid": wechatResp.OpenID, } if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) } } func (s *Server) generateSessionToken(openid string) (string, error) { return auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry) } // Start starts the HTTP server func (s *Server) Start() error { addr := fmt.Sprintf(":%d", s.config.Server.Port) log.Printf("Starting server on %s", addr) srv := &http.Server{ Handler: s.router, Addr: addr, WriteTimeout: s.config.Server.Timeout, ReadTimeout: s.config.Server.Timeout, } return srv.ListenAndServe() } func main() { // 加载配置 cfg, err := config.LoadConfig("config") if err != nil { log.Fatalf("Failed to load config: %v", err) } // 初始化数据库连接 log.Println("Initializing database connection...") if err := InitDB(cfg.Database); err != nil { log.Fatalf("Failed to initialize database: %v", err) } log.Println("Database connection established successfully") // 创建并启动服务器 server := NewServer(cfg) if err := server.Start(); err != nil { log.Fatalf("Server failed to start: %v", err) } }