add branch v1
This commit is contained in:
parent
5d370e1077
commit
01b8951dd5
53 changed files with 1079 additions and 6481 deletions
202
api_server.go
Normal file
202
api_server.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"auth"
|
||||
"config"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// Server represents the API server
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
type WechatLoginResponse struct {
|
||||
OpenID string `json:"openid"`
|
||||
SessionKey string `json:"session_key"`
|
||||
UnionID string `json:"unionid,omitempty"`
|
||||
ErrCode int `json:"errcode,omitempty"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
}
|
||||
|
||||
// NewServer creates a new instance of Server
|
||||
func NewServer(cfg *config.Config) *Server {
|
||||
s := &Server{
|
||||
config: cfg,
|
||||
router: mux.NewRouter(),
|
||||
}
|
||||
s.setupRoutes()
|
||||
return s
|
||||
}
|
||||
|
||||
// setupRoutes configures all the routes for the server
|
||||
func (s *Server) setupRoutes() {
|
||||
// 公开路由
|
||||
s.router.HandleFunc("/auth/login", s.WechatLoginHandler)
|
||||
|
||||
// 受保护路由(需要JWT)
|
||||
authRouter := s.router.PathPrefix("").Subrouter()
|
||||
authRouter.Use(auth.NewAuthMiddleware(s.config.Security.JWTSigningKey))
|
||||
authRouter.HandleFunc("/otp/save", SaveHandler).Methods("POST")
|
||||
authRouter.HandleFunc("/otp/recover", RecoverHandler).Methods("POST")
|
||||
|
||||
// 添加CORS中间件
|
||||
s.router.Use(s.corsMiddleware)
|
||||
}
|
||||
|
||||
// corsMiddleware handles CORS
|
||||
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// Check if the origin is allowed
|
||||
allowed := false
|
||||
for _, allowedOrigin := range s.config.CORS.AllowedOrigins {
|
||||
if origin == allowedOrigin {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allowed {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Methods",
|
||||
joinStrings(s.config.CORS.AllowedMethods))
|
||||
w.Header().Set("Access-Control-Allow-Headers",
|
||||
joinStrings(s.config.CORS.AllowedHeaders))
|
||||
}
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// joinStrings joins string slice with commas
|
||||
func joinStrings(slice []string) string {
|
||||
result := ""
|
||||
for i, s := range slice {
|
||||
if i > 0 {
|
||||
result += ", "
|
||||
}
|
||||
result += s
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Server) WechatLoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 向微信服务器请求session_key
|
||||
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
|
||||
s.config.Wechat.AppID, s.config.Wechat.AppSecret, req.Code)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
http.Error(w, "Wechat service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Wechat service error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var wechatResp WechatLoginResponse
|
||||
if err := json.Unmarshal(body, &wechatResp); err != nil {
|
||||
http.Error(w, "Wechat response parse error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if wechatResp.ErrCode != 0 {
|
||||
http.Error(w, wechatResp.ErrMsg, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
token, err := s.generateSessionToken(wechatResp.OpenID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
response := map[string]interface{}{
|
||||
"token": token,
|
||||
"openid": wechatResp.OpenID,
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) generateSessionToken(openid string) (string, error) {
|
||||
return auth.GenerateToken(openid, s.config.Security.JWTSigningKey, s.config.Security.TokenExpiry)
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
func (s *Server) Start() error {
|
||||
addr := fmt.Sprintf(":%d", s.config.Server.Port)
|
||||
log.Printf("Starting server on %s", addr)
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: s.router,
|
||||
Addr: addr,
|
||||
WriteTimeout: s.config.Server.Timeout,
|
||||
ReadTimeout: s.config.Server.Timeout,
|
||||
}
|
||||
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 加载配置
|
||||
cfg, err := config.LoadConfig("config")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库连接
|
||||
log.Println("Initializing database connection...")
|
||||
if err := InitDB(cfg.Database); err != nil {
|
||||
log.Fatalf("Failed to initialize database: %v", err)
|
||||
}
|
||||
log.Println("Database connection established successfully")
|
||||
|
||||
// 创建并启动服务器
|
||||
server := NewServer(cfg)
|
||||
if err := server.Start(); err != nil {
|
||||
log.Fatalf("Server failed to start: %v", err)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue