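// This program runs the HTTP API server: WeChat login (jscode2session) at
// /auth/login and JWT-protected OTP endpoints at /otp/save and /otp/recover.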
package main
import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"strings"

	"auth"
	"config"

	"github.com/gorilla/mux"
)
// Server represents the API server
|
|
type Server struct {
|
|
config *config.Config
|
|
router *mux.Router
|
|
}
// WechatLoginResponse is the JSON reply of WeChat's sns/jscode2session
// endpoint as decoded by WechatLoginHandler. A non-zero ErrCode marks a
// failed exchange, with ErrMsg carrying WeChat's explanation.
type WechatLoginResponse struct {
	OpenID     string `json:"openid"`            // becomes the subject of the issued session token
	SessionKey string `json:"session_key"`       // decoded but never sent back to the client
	UnionID    string `json:"unionid,omitempty"` // optional; absent from many replies
	ErrCode    int    `json:"errcode,omitempty"` // 0 (or absent) on success
	ErrMsg     string `json:"errmsg,omitempty"`  // human-readable reason when ErrCode != 0
}
// NewServer creates a new instance of Server
|
|
func NewServer(cfg *config.Config) *Server {
|
|
s := &Server{
|
|
config: cfg,
|
|
router: mux.NewRouter(),
|
|
}
|
|
s.setupRoutes()
|
|
return s
|
|
}
// setupRoutes registers every route on s.router.
//
// Public:    /auth/login -> WechatLoginHandler (the handler enforces POST).
// Protected: POST /otp/save and POST /otp/recover, wrapped by the JWT
// middleware returned by auth.NewAuthMiddleware.
// A catch-all OPTIONS route lets CORS preflights reach corsMiddleware, which
// is installed on the whole router.
func (s *Server) setupRoutes() {
	// CORS preflight. Browsers send OPTIONS before a cross-origin POST that
	// carries a JSON body or an Authorization header. gorilla/mux runs router
	// middleware only for a route that fully matches, and the OTP routes
	// accept POST only, so without this route a preflight ended in a bare 405
	// with no CORS headers. Registered first so it wins for every path;
	// corsMiddleware answers OPTIONS before this handler would run.
	s.router.Methods(http.MethodOptions).HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// Public routes.
	s.router.HandleFunc("/auth/login", s.WechatLoginHandler)

	// Protected routes (JWT required). The empty path prefix matches every
	// path; the subrouter exists only to scope the auth middleware.
	authRouter := s.router.PathPrefix("").Subrouter()
	authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))
	authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST")
	authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST")

	// CORS middleware for every route matched by s.router.
	s.router.Use(s.corsMiddleware)
}
// corsMiddleware applies the configured CORS policy.
//
// When the request's Origin exactly matches an entry of
// s.config.CORS.AllowedOrigins, that origin is echoed in
// Access-Control-Allow-Origin together with the configured allowed methods
// and headers. Every OPTIONS request is answered here with 200 OK and is not
// passed on; all other requests continue to next.
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")
		header := w.Header()

		// The CORS headers depend on the Origin request header. Shared caches
		// must therefore key on it; otherwise one origin's cached response could
		// be served to another.
		header.Add("Vary", "Origin")

		// A missing Origin is never treated as allowed, even if the config
		// accidentally contains an empty entry.
		if origin != "" {
			for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
				if origin != allowedOrigin {
					continue
				}
				header.Set("Access-Control-Allow-Origin", origin)
				header.Set("Access-Control-Allow-Methods", joinStrings(s.config.CORS.AllowedMethods))
				header.Set("Access-Control-Allow-Headers", joinStrings(s.config.CORS.AllowedHeaders))
				break
			}
		}

		// Preflights are fully answered by the headers above.
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// joinStrings joins slice with ", " separators, as used for the
// Access-Control-Allow-* header values. A nil or empty slice yields "".
// strings.Join sizes the result once instead of re-copying it on every +=.
func joinStrings(slice []string) string {
	return strings.Join(slice, ", ")
}
// WechatLoginHandler exchanges a WeChat login code for the caller's openid
// and responds with a session token for that openid.
//
// Request: POST with JSON body {"code": "..."}; any other method gets 405.
// Success: 200 with JSON {"token": "...", "openid": "..."}.
// Failures:
//   - 400 for an unreadable, oversized or invalid body, or an empty code.
//   - 503 if WeChat cannot be reached.
//   - 502 for a non-200 upstream status or a reply without an openid.
//   - 500 if the reply cannot be read or parsed, or the token cannot be generated.
//   - 401 with WeChat's errmsg when it reports a non-zero errcode.
func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
	// The login body is one short JSON object; cap it so a client cannot make
	// the server buffer an arbitrarily large payload.
	const maxLoginBodyBytes = 4 << 10

	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Read the (size-limited) request body.
	body, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, maxLoginBodyBytes))
	if err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	var req struct {
		Code string `json:"code"`
	}
	if err := json.Unmarshal(body, &req); err != nil || req.Code == "" {
		http.Error(w, "Invalid request", http.StatusBadRequest)
		return
	}

	// Ask WeChat for the session. url.Values escapes every value, so a crafted
	// code cannot smuggle extra query parameters (e.g. "x&appid=...").
	query := url.Values{}
	query.Set("appid", s.config.Wechat.AppID)
	query.Set("secret", s.config.Wechat.AppSecret)
	query.Set("js_code", req.Code)
	query.Set("grant_type", "authorization_code")
	endpoint := "https://api.weixin.qq.com/sns/jscode2session?" + query.Encode()

	upstreamReq, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		http.Error(w, "Wechat service error", http.StatusInternalServerError)
		return
	}
	// Without a timeout a stalled WeChat API could hold this handler and its
	// goroutine indefinitely. Tie the call to the client's request context and
	// bound it by the configured server timeout (zero means no limit).
	client := &http.Client{Timeout: s.config.Server.Timeout}
	resp, err := client.Do(upstreamReq.WithContext(r.Context()))
	if err != nil {
		http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Printf("wechat jscode2session: unexpected HTTP status %d", resp.StatusCode)
		http.Error(w, "Wechat service error", http.StatusBadGateway)
		return
	}

	body, err = ioutil.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, "Wechat service error", http.StatusInternalServerError)
		return
	}

	var wechatResp WechatLoginResponse
	if err := json.Unmarshal(body, &wechatResp); err != nil {
		http.Error(w, "Wechat response parse error", http.StatusInternalServerError)
		return
	}

	if wechatResp.ErrCode != 0 {
		log.Printf("wechat jscode2session: errcode=%d errmsg=%q", wechatResp.ErrCode, wechatResp.ErrMsg)
		http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized)
		return
	}
	// Never mint a token for an empty identity.
	if wechatResp.OpenID == "" {
		http.Error(w, "Wechat response missing openid", http.StatusBadGateway)
		return
	}

	// Generate the session token.
	token, err := s.generateSessionToken(wechatResp.OpenID)
	if err != nil {
		http.Error(w, "Failed to generate token", http.StatusInternalServerError)
		return
	}

	// Send the response. The session_key stays server-side.
	w.Header().Set("Content-Type", "application/json")
	response := map[string]interface{}{
		"token":  token,
		"openid": wechatResp.OpenID,
	}
	if err := json.NewEncoder(w).Encode(response); err != nil {
		// The status line (and possibly part of the body) may already be sent,
		// so http.Error can no longer change the response; just record it.
		log.Printf("WechatLoginHandler: encoding response: %v", err)
	}
}
func (s *Server) generateSessionToken(openid string) (string, error) {
|
|
return auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
|
|
}
// Start serves s.router on the configured port, using the configured timeout
// for both reads and writes. It blocks until the server stops and returns the
// error from ListenAndServe (which is always non-nil).
func (s *Server) Start() error {
	timeout := s.config.Server.Timeout
	srv := &http.Server{
		Addr:         fmt.Sprintf(":%d", s.config.Server.Port),
		Handler:      s.router,
		ReadTimeout:  timeout,
		WriteTimeout: timeout,
	}

	log.Printf("Starting server on %s", srv.Addr)
	return srv.ListenAndServe()
}
// main loads the configuration, initializes the database and then runs the
// HTTP server; a failure at any stage is fatal.
func main() {
	cfg, err := config.LoadConfig("config")
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Initialize the database before the server starts accepting requests.
	log.Println("Initializing database connection...")
	if dbErr := InitDB(cfg.Database); dbErr != nil {
		log.Fatalf("Failed to initialize database: %v", dbErr)
	}
	log.Println("Database connection established successfully")

	// Start blocks for the lifetime of the server.
	if err = NewServer(cfg).Start(); err != nil {
		log.Fatalf("Server failed to start: %v", err)
	}
}