This commit is contained in:
“xHuPo” 2025-05-23 18:57:11 +08:00
parent a45ddf13d5
commit bcd986e3f7
46 changed files with 6166 additions and 454 deletions

147
handlers/auth_handler.go Normal file
View file

@ -0,0 +1,147 @@
package handlers
import (
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"otpm/api"
"otpm/services"
"github.com/golang-jwt/jwt"
)
// AuthHandler handles authentication related requests
type AuthHandler struct {
authService *services.AuthService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code"`
}
// LoginResponse represents a login response
type LoginResponse struct {
Token string `json:"token"`
OpenID string `json:"openid"`
}
// Login handles WeChat login
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Limit request body size to prevent DOS
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
// Parse request
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request body: %v", err))
log.Printf("Login request parse error: %v", err)
return
}
// Validate request
if req.Code == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Code is required")
log.Printf("Login request validation failed: empty code")
return
}
// Login with WeChat code
token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
log.Printf("Login failed for code %s: %v", req.Code, err)
return
}
// Log successful login
log.Printf("Login successful for code %s (took %v)",
req.Code, time.Since(start))
// Return token
api.NewResponseWriter(w).WriteSuccess(LoginResponse{
Token: token,
})
}
// VerifyToken handles token verification
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Get token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Authorization header is required")
log.Printf("Token verification failed: missing Authorization header")
return
}
// Validate token format
if len(authHeader) < 7 || authHeader[:7] != "Bearer " {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token format. Expected 'Bearer <token>'")
log.Printf("Token verification failed: invalid token format")
return
}
token := authHeader[7:]
if len(token) < 32 { // Basic length check
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token length")
log.Printf("Token verification failed: token too short")
return
}
// Validate token
claims, err := h.authService.ValidateToken(token)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("Token verification failed for token %s: %v",
maskToken(token), err) // Mask token in logs
return
}
// Log successful verification
userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string)
if !ok {
log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
} else {
log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
}
// Token is valid
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
"valid": true,
})
}
// maskToken masks sensitive parts of token for logging
func maskToken(token string) string {
if len(token) < 8 {
return "****"
}
return token[:4] + "****" + token[len(token)-4:]
}
// Routes returns all routes for the auth handler
func (h *AuthHandler) Routes() map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/login": h.Login,
"/verify-token": h.VerifyToken,
}
}

View file

@ -1,27 +0,0 @@
package handlers
import (
"encoding/json"
"net/http"
"github.com/jmoiron/sqlx"
)
type Handler struct {
DB *sqlx.DB
}
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
func WriteJSON(w http.ResponseWriter, data interface{}, code int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(data)
}
func WriteError(w http.ResponseWriter, message string, code int) {
WriteJSON(w, Response{Code: code, Message: message}, code)
}

View file

@ -1,158 +0,0 @@
package handlers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/golang-jwt/jwt"
"github.com/spf13/viper"
)
type LoginRequest struct {
Code string `json:"code"`
}
// 封装code2session接口返回数据
type LoginResponse struct {
OpenId string `json:"openid"`
SessionKey string `json:"session_key"`
UnionId string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
var wxClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConnsPerHost: 10,
},
}
func getLoginResponse(code string) (*LoginResponse, error) {
appid := viper.GetString("wechat.appid")
secret := viper.GetString("wechat.secret")
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", appid, secret, code)
resp, err := wxClient.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var loginResponse LoginResponse
if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
return nil, err
}
if loginResponse.ErrCode != 0 {
switch loginResponse.ErrCode {
case 40029:
return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg)
case 45011:
return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg)
default:
return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg)
}
}
return &loginResponse, nil
}
func generateJWT(openid string) (string, error) {
tokenTTL := viper.GetDuration("auth.ttl")
if tokenTTL <= 0 {
tokenTTL = 24 * time.Hour
}
secret := viper.GetString("auth.secret")
if secret == "" {
secret = "default_auth_secret_otpm"
}
claims := jwt.MapClaims{
"openid": openid,
"exp": time.Now().Add(tokenTTL).Unix(),
"iat": time.Now().Unix(),
"iss": viper.GetString("server.name"),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
return "", err
}
return signedToken, nil
}
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
var req LoginRequest
body, err := io.ReadAll(r.Body)
if err != nil {
WriteError(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()
if err := json.Unmarshal(body, &req); err != nil {
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
return
}
loginResponse, err := getLoginResponse(req.Code)
if err != nil {
switch {
case err.Error() == "invalid code":
WriteError(w, "Invalid code", http.StatusUnauthorized)
case err.Error() == "api limit exceeded":
WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests)
default:
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
}
return
}
// 插入或更新用户的openid和session_key
query := `
INSERT INTO users (openid, session_key)
VALUES ($1, $2)
ON CONFLICT (openid) DO UPDATE SET session_key = $2
RETURNING id;
`
var ID int
if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil {
WriteError(w, "Failed to log in user", http.StatusInternalServerError)
return
}
token, err := generateJWT(loginResponse.OpenId)
if err != nil {
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
return
}
data := map[string]interface{}{
"t": token,
"openid": loginResponse.OpenId,
}
WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
}
func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) {
userid := r.Context().Value("openid").(string)
token, err := generateJWT(userid)
if err != nil {
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
return
}
WriteJSON(w, Response{
Code: 0,
Message: "Token refreshed successfully",
Data: map[string]string{
"token": token,
},
}, http.StatusOK)
}

View file

@ -1,70 +0,0 @@
package handlers
import (
"encoding/json"
"log"
"net/http"
)
type OtpRequest struct {
OpenID string `json:"openid"`
Token *[]OTP `json:"token"`
}
type OTP struct {
Issuer string `json:"issuer"`
Remark string `json:"remark"`
Secret string `json:"secret"`
}
func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) {
var req OtpRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
return
}
if req.OpenID == "" {
WriteError(w, "OpenID is required", http.StatusBadRequest)
return
}
if req.Token == nil || len(*req.Token) == 0 {
WriteError(w, "Token is required", http.StatusBadRequest)
return
}
log.Printf("Saving OTP for user: %s token count:: %d", req.OpenID, len(*req.Token))
// 插入或更新 OTP 记录
query := `
INSERT INTO otp (openid, token)
VALUES ($1, $2)
ON CONFLICT (openid) DO UPDATE SET token = EXCLUDED.token
`
_, err := h.DB.Exec(query, req.OpenID, req.Token)
if err != nil {
WriteError(w, "Failed to update or create OTP", http.StatusInternalServerError)
return
}
WriteJSON(w, Response{Code: 0, Message: "Success"}, http.StatusOK)
}
func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
openid := r.URL.Query().Get("openid")
if openid == "" {
WriteError(w, "OpenID is required", http.StatusBadRequest)
return
}
var otp OtpRequest
err := h.DB.Get(&otp, "SELECT token FROM otp WHERE openid=$1", openid)
if err != nil {
WriteError(w, "Failed to get OTP", http.StatusInternalServerError)
return
}
WriteJSON(w, Response{Code: 0, Message: "Success", Data: otp}, http.StatusOK)
}

286
handlers/otp_handler.go Normal file
View file

@ -0,0 +1,286 @@
package handlers
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
"otpm/api"
"otpm/middleware"
"otpm/models"
"otpm/services"
)
// OTPHandler handles OTP related requests
type OTPHandler struct {
otpService *services.OTPService
}
// NewOTPHandler creates a new OTPHandler
func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
return &OTPHandler{
otpService: otpService,
}
}
// CreateOTPRequest represents a request to create an OTP
type CreateOTPRequest struct {
Name string `json:"name"`
Issuer string `json:"issuer"`
Secret string `json:"secret"`
Algorithm string `json:"algorithm"`
Digits int `json:"digits"`
Period int `json:"period"`
}
// CreateOTP handles OTP creation
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Limit request body size
r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("CreateOTP unauthorized attempt")
return
}
// Parse request
var req CreateOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request body: %v", err))
log.Printf("CreateOTP request parse error for user %s: %v", userID, err)
return
}
// Validate OTP parameters
if req.Secret == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Secret is required")
log.Printf("CreateOTP validation failed for user %s: empty secret", userID)
return
}
// Validate algorithm
supportedAlgos := map[string]bool{
"SHA1": true,
"SHA256": true,
"SHA512": true,
}
if !supportedAlgos[strings.ToUpper(req.Algorithm)] {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Unsupported algorithm. Supported: SHA1, SHA256, SHA512")
log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s",
userID, req.Algorithm)
return
}
// Create OTP
otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{
Name: req.Name,
Issuer: req.Issuer,
Secret: req.Secret,
Algorithm: req.Algorithm,
Digits: req.Digits,
Period: req.Period,
})
if err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
log.Printf("CreateOTP failed for user %s: %v", userID, err)
return
}
// Log successful creation (mask secret in logs)
log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d",
userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period)
api.NewResponseWriter(w).WriteSuccess(otp)
}
// ListOTPs handles listing all OTPs for a user
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTPs
otps, err := h.otpService.ListOTPs(r.Context(), userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(otps)
}
// GetOTPCode handles generating OTP code
func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
otpID = strings.TrimSuffix(otpID, "/code")
// Validate OTP ID format
if len(otpID) != 36 { // Assuming UUID format
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid OTP ID format")
log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID)
return
}
// Rate limiting check could be added here
// (would require redis or similar rate limiter)
// Generate code
code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err)
return
}
// Log successful generation (without actual code)
log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)",
userID, otpID, time.Since(start), expiresIn)
api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{
"code": code,
"expires_in": expiresIn,
})
}
// VerifyOTPRequest represents a request to verify an OTP code
type VerifyOTPRequest struct {
Code string `json:"code"`
}
// VerifyOTP handles OTP code verification
func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
otpID = strings.TrimSuffix(otpID, "/verify")
// Parse request
var req VerifyOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
return
}
// Verify code
valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
"valid": valid,
})
}
// UpdateOTPRequest represents a request to update an OTP
type UpdateOTPRequest struct {
Name string `json:"name"`
Issuer string `json:"issuer"`
Algorithm string `json:"algorithm"`
Digits int `json:"digits"`
Period int `json:"period"`
}
// UpdateOTP handles OTP update
func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
// Parse request
var req UpdateOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
return
}
// Update OTP
otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{
Name: req.Name,
Issuer: req.Issuer,
Algorithm: req.Algorithm,
Digits: req.Digits,
Period: req.Period,
})
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(otp)
}
// DeleteOTP handles OTP deletion
func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
// Delete OTP
if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(map[string]string{
"message": "OTP deleted successfully",
})
}
// Routes returns all routes for the OTP handler
func (h *OTPHandler) Routes() map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/otp": h.CreateOTP,
"/otp/": h.ListOTPs,
"/otp/{id}": h.UpdateOTP,
"/otp/{id}/code": h.GetOTPCode,
"/otp/{id}/verify": h.VerifyOTP,
}
}