package services

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"time"

	"github.com/golang-jwt/jwt"
	"github.com/google/uuid"

	"otpm/config"
	"otpm/models"
)

// WeChatCode2SessionResponse represents the response from WeChat code2session API
type WeChatCode2SessionResponse struct {
	OpenID     string `json:"openid"`
	SessionKey string `json:"session_key"`
	UnionID    string `json:"unionid"`
	ErrCode    int    `json:"errcode"`
	ErrMsg     string `json:"errmsg"`
}

// AuthService handles authentication related operations: exchanging WeChat
// login codes for users, and issuing and validating HS512-signed JWTs.
type AuthService struct {
	config     *config.Config
	userRepo   *models.UserRepository
	httpClient *http.Client
}

// NewAuthService creates a new AuthService whose outbound HTTP client uses a
// 10-second overall request timeout.
func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
	return &AuthService{
		config:   cfg,
		userRepo: userRepo,
		httpClient: &http.Client{
			Timeout: 10 * time.Second,
		},
	}
}

// LoginWithWeChatCode exchanges a WeChat login code (js_code) for the user's
// OpenID and session key and signs a JWT for that user. On first login it
// creates the user; otherwise it stores the freshly issued session key.
//
// ctx is passed to the WeChat API request and to every repository call.
// It returns the signed token, or an error describing which step failed.
func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
	start := time.Now()

	// An empty code can never succeed; reject it without a WeChat round trip.
	if code == "" {
		return "", errors.New("WeChat login code is empty")
	}

	// Get OpenID and SessionKey from WeChat
	sessionInfo, err := s.getWeChatSessionContext(ctx, code)
	if err != nil {
		log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
		return "", fmt.Errorf("failed to get WeChat session: %w", err)
	}
	log.Printf("WeChat session obtained for code %s (took %v)", maskCode(code), time.Since(start))

	// Find or create user
	user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
	if err != nil {
		log.Printf("User lookup failed for OpenID %s: %v", maskOpenID(sessionInfo.OpenID), err)
		return "", fmt.Errorf("failed to find user: %w", err)
	}

	if user == nil {
		// First login for this OpenID: create the user record.
		user = &models.User{
			ID:         uuid.New().String(),
			OpenID:     sessionInfo.OpenID,
			SessionKey: sessionInfo.SessionKey,
		}
		if err := s.userRepo.Create(ctx, user); err != nil {
			log.Printf("User creation failed for OpenID %s: %v", maskOpenID(sessionInfo.OpenID), err)
			return "", fmt.Errorf("failed to create user: %w", err)
		}
		log.Printf("New user created with ID %s for OpenID %s", user.ID, maskOpenID(sessionInfo.OpenID))
	} else {
		// Returning user: persist the session key WeChat just issued.
		user.SessionKey = sessionInfo.SessionKey
		if err := s.userRepo.Update(ctx, user); err != nil {
			log.Printf("User update failed for ID %s: %v", user.ID, err)
			return "", fmt.Errorf("failed to update user: %w", err)
		}
		log.Printf("User %s session key updated", user.ID)
	}

	// Generate JWT token
	token, err := s.generateToken(user)
	if err != nil {
		log.Printf("Token generation failed for user %s: %v", user.ID, err)
		return "", fmt.Errorf("failed to generate token: %w", err)
	}

	log.Printf("WeChat login completed for user %s (total time %v)", user.ID, time.Since(start))
	return token, nil
}

// maskSecret returns a log-safe rendering of value. Values shorter than 8
// bytes are hidden entirely. Longer values keep only their first and last
// two bytes.
func maskSecret(value string) string {
	if len(value) < 8 {
		return "****"
	}
	return value[:2] + "****" + value[len(value)-2:]
}

// maskCode masks sensitive parts of WeChat code for logging
func maskCode(code string) string {
	return maskSecret(code)
}

// maskOpenID masks sensitive parts of OpenID for logging
func maskOpenID(openID string) string {
	return maskSecret(openID)
}

// redactURLError strips the request URL from a *url.Error. The code2session
// request URL carries the app secret in its query string, and *url.Error
// prints the full URL in its message. Without this, the secret would end up
// in logs. Errors of other types are returned unchanged. The underlying
// cause stays reachable through errors.Is/As.
func redactURLError(err error) error {
	var uerr *url.Error
	if errors.As(err, &uerr) {
		return fmt.Errorf("%s request: %w", uerr.Op, uerr.Err)
	}
	return err
}

// getWeChatSession calls WeChat's code2session API without a caller-supplied
// context. It is kept for existing callers. Prefer getWeChatSessionContext
// so the request can be cancelled.
func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
	return s.getWeChatSessionContext(context.Background(), code)
}

// getWeChatSessionContext calls WeChat's code2session API for code.
//
// The request is bound to ctx as well as the client timeout. All query
// parameters are URL-encoded, so a crafted code cannot inject extra
// parameters. It returns an error on transport failure, on a non-200 status,
// on an undecodable body, on a non-zero WeChat errcode, or when no openid
// is returned.
func (s *AuthService) getWeChatSessionContext(ctx context.Context, code string) (*WeChatCode2SessionResponse, error) {
	const (
		endpoint = "https://api.weixin.qq.com/sns/jscode2session"
		// Upper bound on the response body we are willing to decode.
		maxResponseBytes = 1 << 20
	)

	query := url.Values{}
	query.Set("appid", s.config.WeChat.AppID)
	query.Set("secret", s.config.WeChat.AppSecret)
	query.Set("js_code", code)
	query.Set("grant_type", "authorization_code")

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint+"?"+query.Encode(), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to build WeChat request: %w", redactURLError(err))
	}

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to call WeChat API: %w", redactURLError(err))
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("WeChat API returned HTTP status %d", resp.StatusCode)
	}

	var result WeChatCode2SessionResponse
	if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBytes)).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
	}

	if result.ErrCode != 0 {
		return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
	}
	// Without an OpenID the user lookup below would key on "", so treat it as a failure.
	if result.OpenID == "" {
		return nil, errors.New("WeChat API returned an empty openid")
	}

	return &result, nil
}

// generateToken signs an HS512 JWT for user. The token carries user_id, exp
// (now + JWT.ExpireDelta), iat, iss, aud and a random token_id. It refuses
// to sign with an empty secret, because such tokens could be forged by
// anyone.
func (s *AuthService) generateToken(user *models.User) (string, error) {
	if len(s.config.JWT.Secret) == 0 {
		return "", errors.New("JWT secret is not configured")
	}

	now := time.Now()
	expiresAt := now.Add(s.config.JWT.ExpireDelta)
	claims := jwt.MapClaims{
		"user_id":  user.ID,
		"exp":      expiresAt.Unix(),
		"iat":      now.Unix(),
		"iss":      s.config.JWT.Issuer,
		"aud":      s.config.JWT.Audience,
		"token_id": uuid.New().String(), // Unique token ID for tracking
	}

	// Use stronger signing method
	token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
	signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
	if err != nil {
		return "", fmt.Errorf("failed to sign token: %w", err)
	}

	log.Printf("Token generated for user %s (expires at %v)", user.ID, expiresAt)
	return signedToken, nil
}

// ValidateToken parses tokenString and verifies it.
//
// It checks that the signature is HMAC with the configured secret, that the
// library's time-based validation passes, and that the iss and aud claims
// match the configured values. Malformed, expired and not-yet-valid tokens
// are reported with distinct messages. An empty configured secret is
// rejected outright, because it would accept forged tokens.
func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
	if len(s.config.JWT.Secret) == 0 {
		return nil, errors.New("JWT secret is not configured")
	}

	token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
		// Verify signing method; prevents algorithm-substitution (e.g. "none").
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return []byte(s.config.JWT.Secret), nil
	})
	if err != nil {
		var ve *jwt.ValidationError
		if errors.As(err, &ve) {
			switch {
			case ve.Errors&jwt.ValidationErrorMalformed != 0:
				return nil, fmt.Errorf("malformed token")
			case ve.Errors&jwt.ValidationErrorExpired != 0:
				return nil, fmt.Errorf("token expired")
			case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
				return nil, fmt.Errorf("token not active yet")
			default:
				return nil, fmt.Errorf("token validation error: %w", err)
			}
		}
		return nil, fmt.Errorf("failed to parse token: %w", err)
	}

	// Additional claims validation
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok || !token.Valid {
		return nil, fmt.Errorf("invalid token claims")
	}
	if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
		return nil, fmt.Errorf("invalid token issuer")
	}
	if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
		return nil, fmt.Errorf("invalid token audience")
	}

	return token, nil
}

// GetUserFromToken loads the user named by the user_id claim of a token
// that was typically returned by ValidateToken.
//
// It returns an error if token is nil, if its claims lack a non-empty
// string user_id, if the lookup fails, or if the repository returns no user.
// A nil user is therefore never returned together with a nil error.
func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
	if token == nil {
		return nil, errors.New("token is nil")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("invalid token claims")
	}

	userID, ok := claims["user_id"].(string)
	if !ok || userID == "" {
		return nil, fmt.Errorf("user_id not found in token")
	}

	user, err := s.userRepo.FindByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to find user: %w", err)
	}
	// Guard against a (nil, nil) "not found" result from the repository so
	// callers cannot dereference a nil user.
	if user == nil {
		return nil, fmt.Errorf("user %s not found", userID)
	}

	return user, nil
}