230 lines
6.6 KiB
Go
230 lines
6.6 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
"github.com/google/uuid"
|
|
|
|
"otpm/config"
|
|
"otpm/models"
|
|
)
|
|
|
|
// WeChatCode2SessionResponse is the JSON document decoded from the reply of
// WeChat's jscode2session endpoint.
type WeChatCode2SessionResponse struct {
	// Session data, decoded from the "openid", "session_key" and "unionid" keys.
	OpenID     string `json:"openid"`
	SessionKey string `json:"session_key"`
	UnionID    string `json:"unionid"`

	// Error report, decoded from "errcode" and "errmsg". getWeChatSession
	// treats any non-zero ErrCode as a failed exchange.
	ErrCode int    `json:"errcode"`
	ErrMsg  string `json:"errmsg"`
}
|
|
|
|
// AuthService handles authentication related operations
|
|
type AuthService struct {
|
|
config *config.Config
|
|
userRepo *models.UserRepository
|
|
httpClient *http.Client
|
|
}
|
|
|
|
// NewAuthService creates a new AuthService
|
|
func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
|
|
return &AuthService{
|
|
config: cfg,
|
|
userRepo: userRepo,
|
|
httpClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
// LoginWithWeChatCode handles WeChat login
|
|
func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
|
|
start := time.Now()
|
|
|
|
// Get OpenID and SessionKey from WeChat
|
|
sessionInfo, err := s.getWeChatSession(code)
|
|
if err != nil {
|
|
log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
|
|
return "", fmt.Errorf("failed to get WeChat session: %w", err)
|
|
}
|
|
log.Printf("WeChat session obtained for code %s (took %v)",
|
|
maskCode(code), time.Since(start))
|
|
|
|
// Find or create user
|
|
user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
|
|
if err != nil {
|
|
log.Printf("User lookup failed for OpenID %s: %v",
|
|
maskOpenID(sessionInfo.OpenID), err)
|
|
return "", fmt.Errorf("failed to find user: %w", err)
|
|
}
|
|
|
|
if user == nil {
|
|
// Create new user
|
|
user = &models.User{
|
|
ID: uuid.New().String(),
|
|
OpenID: sessionInfo.OpenID,
|
|
SessionKey: sessionInfo.SessionKey,
|
|
}
|
|
if err := s.userRepo.Create(ctx, user); err != nil {
|
|
log.Printf("User creation failed for OpenID %s: %v",
|
|
maskOpenID(sessionInfo.OpenID), err)
|
|
return "", fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
log.Printf("New user created with ID %s for OpenID %s",
|
|
user.ID, maskOpenID(sessionInfo.OpenID))
|
|
} else {
|
|
// Update session key
|
|
user.SessionKey = sessionInfo.SessionKey
|
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
|
log.Printf("User update failed for ID %s: %v", user.ID, err)
|
|
return "", fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
log.Printf("User %s session key updated", user.ID)
|
|
}
|
|
|
|
// Generate JWT token
|
|
token, err := s.generateToken(user)
|
|
if err != nil {
|
|
log.Printf("Token generation failed for user %s: %v", user.ID, err)
|
|
return "", fmt.Errorf("failed to generate token: %w", err)
|
|
}
|
|
|
|
log.Printf("WeChat login completed for user %s (total time %v)",
|
|
user.ID, time.Since(start))
|
|
return token, nil
|
|
}
|
|
|
|
// maskCode returns a log-safe rendering of a WeChat login code: the first
// and last two bytes around a "****" filler. Codes shorter than eight bytes
// are replaced by the filler alone so that nothing of them is revealed.
func maskCode(code string) string {
	const (
		filler  = "****"
		minLen  = 8
		visible = 2
	)

	n := len(code)
	if n >= minLen {
		return code[:visible] + filler + code[n-visible:]
	}
	return filler
}
|
|
|
|
// maskOpenID returns a log-safe rendering of a WeChat OpenID. IDs of at
// least eight bytes keep their first and last two bytes around "****";
// anything shorter is rendered as "****" only.
func maskOpenID(openID string) string {
	if size := len(openID); size >= 8 {
		head, tail := openID[:2], openID[size-2:]
		return head + "****" + tail
	}
	return "****"
}
|
|
|
|
// getWeChatSession calls WeChat's code2session API
|
|
func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
|
|
url := fmt.Sprintf(
|
|
"https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
|
|
s.config.WeChat.AppID,
|
|
s.config.WeChat.AppSecret,
|
|
code,
|
|
)
|
|
|
|
resp, err := s.httpClient.Get(url)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to call WeChat API: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var result WeChatCode2SessionResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
|
|
}
|
|
|
|
if result.ErrCode != 0 {
|
|
return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
|
|
}
|
|
|
|
return &result, nil
|
|
}
|
|
|
|
// generateToken generates a JWT token for a user
|
|
func (s *AuthService) generateToken(user *models.User) (string, error) {
|
|
now := time.Now()
|
|
claims := jwt.MapClaims{
|
|
"user_id": user.ID,
|
|
"exp": now.Add(s.config.JWT.ExpireDelta).Unix(),
|
|
"iat": now.Unix(),
|
|
"iss": s.config.JWT.Issuer,
|
|
"aud": s.config.JWT.Audience,
|
|
"token_id": uuid.New().String(), // Unique token ID for tracking
|
|
}
|
|
|
|
// Use stronger signing method
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
|
|
signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
|
}
|
|
|
|
log.Printf("Token generated for user %s (expires at %v)",
|
|
user.ID, now.Add(s.config.JWT.ExpireDelta))
|
|
return signedToken, nil
|
|
}
|
|
|
|
// ValidateToken validates a JWT token with additional checks
|
|
func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
// Verify signing method
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(s.config.JWT.Secret), nil
|
|
})
|
|
|
|
if err != nil {
|
|
if ve, ok := err.(*jwt.ValidationError); ok {
|
|
switch {
|
|
case ve.Errors&jwt.ValidationErrorMalformed != 0:
|
|
return nil, fmt.Errorf("malformed token")
|
|
case ve.Errors&jwt.ValidationErrorExpired != 0:
|
|
return nil, fmt.Errorf("token expired")
|
|
case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
|
|
return nil, fmt.Errorf("token not active yet")
|
|
default:
|
|
return nil, fmt.Errorf("token validation error: %w", err)
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
|
}
|
|
|
|
// Additional claims validation
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
// Check issuer
|
|
if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
|
|
return nil, fmt.Errorf("invalid token issuer")
|
|
}
|
|
// Check audience
|
|
if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
|
|
return nil, fmt.Errorf("invalid token audience")
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("invalid token claims")
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// GetUserFromToken gets user information from a JWT token
|
|
func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid token claims")
|
|
}
|
|
|
|
userID, ok := claims["user_id"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("user_id not found in token")
|
|
}
|
|
|
|
user, err := s.userRepo.FindByID(ctx, userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to find user: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|