beta
This commit is contained in:
parent
a45ddf13d5
commit
bcd986e3f7
46 changed files with 6166 additions and 454 deletions
230
services/auth.go
Normal file
230
services/auth.go
Normal file
|
@ -0,0 +1,230 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"otpm/config"
|
||||
"otpm/models"
|
||||
)
|
||||
|
||||
// WeChatCode2SessionResponse represents the response from WeChat code2session API
|
||||
type WeChatCode2SessionResponse struct {
|
||||
OpenID string `json:"openid"`
|
||||
SessionKey string `json:"session_key"`
|
||||
UnionID string `json:"unionid"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
// AuthService handles authentication related operations
|
||||
type AuthService struct {
|
||||
config *config.Config
|
||||
userRepo *models.UserRepository
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAuthService creates a new AuthService
|
||||
func NewAuthService(cfg *config.Config, userRepo *models.UserRepository) *AuthService {
|
||||
return &AuthService{
|
||||
config: cfg,
|
||||
userRepo: userRepo,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LoginWithWeChatCode handles WeChat login
|
||||
func (s *AuthService) LoginWithWeChatCode(ctx context.Context, code string) (string, error) {
|
||||
start := time.Now()
|
||||
|
||||
// Get OpenID and SessionKey from WeChat
|
||||
sessionInfo, err := s.getWeChatSession(code)
|
||||
if err != nil {
|
||||
log.Printf("WeChat login failed for code %s: %v", maskCode(code), err)
|
||||
return "", fmt.Errorf("failed to get WeChat session: %w", err)
|
||||
}
|
||||
log.Printf("WeChat session obtained for code %s (took %v)",
|
||||
maskCode(code), time.Since(start))
|
||||
|
||||
// Find or create user
|
||||
user, err := s.userRepo.FindByOpenID(ctx, sessionInfo.OpenID)
|
||||
if err != nil {
|
||||
log.Printf("User lookup failed for OpenID %s: %v",
|
||||
maskOpenID(sessionInfo.OpenID), err)
|
||||
return "", fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
// Create new user
|
||||
user = &models.User{
|
||||
ID: uuid.New().String(),
|
||||
OpenID: sessionInfo.OpenID,
|
||||
SessionKey: sessionInfo.SessionKey,
|
||||
}
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
log.Printf("User creation failed for OpenID %s: %v",
|
||||
maskOpenID(sessionInfo.OpenID), err)
|
||||
return "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
log.Printf("New user created with ID %s for OpenID %s",
|
||||
user.ID, maskOpenID(sessionInfo.OpenID))
|
||||
} else {
|
||||
// Update session key
|
||||
user.SessionKey = sessionInfo.SessionKey
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("User update failed for ID %s: %v", user.ID, err)
|
||||
return "", fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
log.Printf("User %s session key updated", user.ID)
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := s.generateToken(user)
|
||||
if err != nil {
|
||||
log.Printf("Token generation failed for user %s: %v", user.ID, err)
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("WeChat login completed for user %s (total time %v)",
|
||||
user.ID, time.Since(start))
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// maskCode masks sensitive parts of WeChat code for logging
|
||||
func maskCode(code string) string {
|
||||
if len(code) < 8 {
|
||||
return "****"
|
||||
}
|
||||
return code[:2] + "****" + code[len(code)-2:]
|
||||
}
|
||||
|
||||
// maskOpenID masks sensitive parts of OpenID for logging
|
||||
func maskOpenID(openID string) string {
|
||||
if len(openID) < 8 {
|
||||
return "****"
|
||||
}
|
||||
return openID[:2] + "****" + openID[len(openID)-2:]
|
||||
}
|
||||
|
||||
// getWeChatSession calls WeChat's code2session API
|
||||
func (s *AuthService) getWeChatSession(code string) (*WeChatCode2SessionResponse, error) {
|
||||
url := fmt.Sprintf(
|
||||
"https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
|
||||
s.config.WeChat.AppID,
|
||||
s.config.WeChat.AppSecret,
|
||||
code,
|
||||
)
|
||||
|
||||
resp, err := s.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call WeChat API: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result WeChatCode2SessionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode WeChat response: %w", err)
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("WeChat API error: %d - %s", result.ErrCode, result.ErrMsg)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// generateToken generates a JWT token for a user
|
||||
func (s *AuthService) generateToken(user *models.User) (string, error) {
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"exp": now.Add(s.config.JWT.ExpireDelta).Unix(),
|
||||
"iat": now.Unix(),
|
||||
"iss": s.config.JWT.Issuer,
|
||||
"aud": s.config.JWT.Audience,
|
||||
"token_id": uuid.New().String(), // Unique token ID for tracking
|
||||
}
|
||||
|
||||
// Use stronger signing method
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
|
||||
signedToken, err := token.SignedString([]byte(s.config.JWT.Secret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Token generated for user %s (expires at %v)",
|
||||
user.ID, now.Add(s.config.JWT.ExpireDelta))
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT token with additional checks
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*jwt.Token, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(s.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
switch {
|
||||
case ve.Errors&jwt.ValidationErrorMalformed != 0:
|
||||
return nil, fmt.Errorf("malformed token")
|
||||
case ve.Errors&jwt.ValidationErrorExpired != 0:
|
||||
return nil, fmt.Errorf("token expired")
|
||||
case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
|
||||
return nil, fmt.Errorf("token not active yet")
|
||||
default:
|
||||
return nil, fmt.Errorf("token validation error: %w", err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
// Additional claims validation
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
// Check issuer
|
||||
if iss, ok := claims["iss"].(string); !ok || iss != s.config.JWT.Issuer {
|
||||
return nil, fmt.Errorf("invalid token issuer")
|
||||
}
|
||||
// Check audience
|
||||
if aud, ok := claims["aud"].(string); !ok || aud != s.config.JWT.Audience {
|
||||
return nil, fmt.Errorf("invalid token audience")
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetUserFromToken gets user information from a JWT token
|
||||
func (s *AuthService) GetUserFromToken(ctx context.Context, token *jwt.Token) (*models.User, error) {
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("user_id not found in token")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.FindByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
358
services/otp.go
Normal file
358
services/otp.go
Normal file
|
@ -0,0 +1,358 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"otpm/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// OTPService handles OTP related operations
|
||||
type OTPService struct {
|
||||
otpRepo *models.OTPRepository
|
||||
}
|
||||
|
||||
// NewOTPService creates a new OTPService
|
||||
func NewOTPService(otpRepo *models.OTPRepository) *OTPService {
|
||||
return &OTPService{
|
||||
otpRepo: otpRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateOTP creates a new OTP with performance monitoring and logging
|
||||
func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) {
|
||||
start := time.Now()
|
||||
|
||||
// Validate input
|
||||
if err := s.validateOTPInput(input); err != nil {
|
||||
log.Printf("OTP validation failed for user %s: %v", userID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Clean and standardize secret
|
||||
secret := cleanSecret(input.Secret)
|
||||
|
||||
// Set defaults for optional fields
|
||||
algorithm := strings.ToUpper(input.Algorithm)
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA1"
|
||||
}
|
||||
|
||||
digits := input.Digits
|
||||
if digits == 0 {
|
||||
digits = 6
|
||||
}
|
||||
|
||||
period := input.Period
|
||||
if period == 0 {
|
||||
period = 30
|
||||
}
|
||||
|
||||
// Create OTP
|
||||
otp := &models.OTP{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
Name: input.Name,
|
||||
Issuer: input.Issuer,
|
||||
Secret: secret,
|
||||
Algorithm: algorithm,
|
||||
Digits: digits,
|
||||
Period: period,
|
||||
}
|
||||
|
||||
if err := s.otpRepo.Create(ctx, otp); err != nil {
|
||||
log.Printf("Failed to create OTP for user %s: %v", userID, err)
|
||||
return nil, fmt.Errorf("failed to create OTP: %w", err)
|
||||
}
|
||||
|
||||
// Log successful creation (without exposing secret)
|
||||
log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)",
|
||||
otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period)
|
||||
|
||||
return otp, nil
|
||||
}
|
||||
|
||||
// GetOTP gets an OTP by ID
|
||||
func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) {
|
||||
otp, err := s.otpRepo.FindByID(ctx, id, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OTP: %w", err)
|
||||
}
|
||||
return otp, nil
|
||||
}
|
||||
|
||||
// ListOTPs lists all OTPs for a user
|
||||
func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) {
|
||||
otps, err := s.otpRepo.FindAllByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list OTPs: %w", err)
|
||||
}
|
||||
return otps, nil
|
||||
}
|
||||
|
||||
// UpdateOTP updates an OTP
|
||||
func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) {
|
||||
// Get existing OTP
|
||||
otp, err := s.otpRepo.FindByID(ctx, id, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OTP: %w", err)
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if input.Name != "" {
|
||||
otp.Name = input.Name
|
||||
}
|
||||
if input.Issuer != "" {
|
||||
otp.Issuer = input.Issuer
|
||||
}
|
||||
if input.Algorithm != "" {
|
||||
otp.Algorithm = strings.ToUpper(input.Algorithm)
|
||||
}
|
||||
if input.Digits > 0 {
|
||||
otp.Digits = input.Digits
|
||||
}
|
||||
if input.Period > 0 {
|
||||
otp.Period = input.Period
|
||||
}
|
||||
|
||||
// Validate updated OTP
|
||||
if err := s.validateOTPInput(models.OTPParams{
|
||||
Name: otp.Name,
|
||||
Issuer: otp.Issuer,
|
||||
Secret: otp.Secret,
|
||||
Algorithm: otp.Algorithm,
|
||||
Digits: otp.Digits,
|
||||
Period: otp.Period,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.otpRepo.Update(ctx, otp); err != nil {
|
||||
return nil, fmt.Errorf("failed to update OTP: %w", err)
|
||||
}
|
||||
|
||||
return otp, nil
|
||||
}
|
||||
|
||||
// DeleteOTP deletes an OTP
|
||||
func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error {
|
||||
if err := s.otpRepo.Delete(ctx, id, userID); err != nil {
|
||||
return fmt.Errorf("failed to delete OTP: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode generates a TOTP code with enhanced logging and error handling
|
||||
func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) {
|
||||
start := time.Now()
|
||||
|
||||
otp, err := s.otpRepo.FindByID(ctx, id, userID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err)
|
||||
return "", 0, fmt.Errorf("failed to get OTP: %w", err)
|
||||
}
|
||||
|
||||
// Get current time step
|
||||
now := time.Now().Unix()
|
||||
timeStep := now / int64(otp.Period)
|
||||
|
||||
// Generate code
|
||||
code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits)
|
||||
if err != nil {
|
||||
log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err)
|
||||
return "", 0, fmt.Errorf("failed to generate code: %w", err)
|
||||
}
|
||||
|
||||
// Calculate remaining seconds
|
||||
remainingSeconds := otp.Period - int(now%int64(otp.Period))
|
||||
|
||||
// Log successful generation (without actual code)
|
||||
log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)",
|
||||
id, userID, time.Since(start), remainingSeconds)
|
||||
|
||||
return code, remainingSeconds, nil
|
||||
}
|
||||
|
||||
// VerifyCode verifies a TOTP code with enhanced security and logging
|
||||
func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) {
|
||||
start := time.Now()
|
||||
|
||||
// Basic input validation
|
||||
if len(code) == 0 {
|
||||
log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID)
|
||||
return false, fmt.Errorf("code is required")
|
||||
}
|
||||
|
||||
otp, err := s.otpRepo.FindByID(ctx, id, userID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to find OTP %s for user %s during verification: %v",
|
||||
id, userID, err)
|
||||
return false, fmt.Errorf("failed to get OTP: %w", err)
|
||||
}
|
||||
|
||||
// Get current and adjacent time steps
|
||||
now := time.Now().Unix()
|
||||
timeSteps := []int64{
|
||||
(now - int64(otp.Period)) / int64(otp.Period),
|
||||
now / int64(otp.Period),
|
||||
(now + int64(otp.Period)) / int64(otp.Period),
|
||||
}
|
||||
|
||||
// Check code against all time steps
|
||||
for _, ts := range timeSteps {
|
||||
expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits)
|
||||
if err != nil {
|
||||
log.Printf("Code generation failed for time step %d: %v", ts, err)
|
||||
continue
|
||||
}
|
||||
if expectedCode == code {
|
||||
// Log successful verification
|
||||
log.Printf("Code verified successfully for OTP %s (user %s) in %v",
|
||||
id, userID, time.Since(start))
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Log failed verification attempt
|
||||
log.Printf("Invalid code provided for OTP %s (user %s) in %v",
|
||||
id, userID, time.Since(start))
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// validateOTPInput validates OTP input with detailed error messages
|
||||
func (s *OTPService) validateOTPInput(input models.OTPParams) error {
|
||||
if input.Name == "" {
|
||||
return fmt.Errorf("name is required")
|
||||
}
|
||||
|
||||
if len(input.Name) > 100 {
|
||||
return fmt.Errorf("name is too long (maximum 100 characters)")
|
||||
}
|
||||
|
||||
if input.Secret == "" {
|
||||
return fmt.Errorf("secret is required")
|
||||
}
|
||||
|
||||
if !isValidBase32(input.Secret) {
|
||||
return fmt.Errorf("invalid secret format: must be a valid base32 string")
|
||||
}
|
||||
|
||||
// Secret length check (after base32 decoding)
|
||||
secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "="))
|
||||
if len(secretBytes) < 10 {
|
||||
return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)")
|
||||
}
|
||||
|
||||
if input.Algorithm != "" {
|
||||
if !isValidAlgorithm(input.Algorithm) {
|
||||
return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
if input.Digits != 0 {
|
||||
if input.Digits < 6 || input.Digits > 8 {
|
||||
return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits)
|
||||
}
|
||||
}
|
||||
|
||||
if input.Period != 0 {
|
||||
if input.Period < 30 || input.Period > 60 {
|
||||
return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func cleanSecret(secret string) string {
|
||||
// Remove spaces and convert to upper case
|
||||
secret = strings.TrimSpace(strings.ToUpper(secret))
|
||||
// Remove any padding characters
|
||||
return strings.TrimRight(secret, "=")
|
||||
}
|
||||
|
||||
func isValidBase32(s string) bool {
|
||||
// Try to decode the secret
|
||||
_, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "="))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func isValidAlgorithm(algorithm string) bool {
|
||||
switch strings.ToUpper(algorithm) {
|
||||
case "SHA1", "SHA256", "SHA512":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func getHasher(algorithm string, key []byte) (hash.Hash, error) {
|
||||
switch strings.ToUpper(algorithm) {
|
||||
case "SHA1":
|
||||
return hmac.New(sha1.New, key), nil
|
||||
case "SHA256":
|
||||
return hmac.New(sha256.New, key), nil
|
||||
case "SHA512":
|
||||
return hmac.New(sha512.New, key), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported algorithm: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) {
|
||||
// Decode secret
|
||||
secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "="))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid secret: %w", err)
|
||||
}
|
||||
|
||||
// Get initialized HMAC hasher with secret
|
||||
hasher, err := getHasher(algorithm, secretBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Convert time step to bytes
|
||||
timeBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(timeBytes, uint64(timeStep))
|
||||
|
||||
// Calculate HMAC
|
||||
hasher.Write(timeBytes)
|
||||
hash := hasher.Sum(nil)
|
||||
|
||||
// Get offset
|
||||
offset := hash[len(hash)-1] & 0xf
|
||||
|
||||
// Generate 4-byte code
|
||||
code := binary.BigEndian.Uint32(hash[offset : offset+4])
|
||||
code = code & 0x7fffffff
|
||||
|
||||
// Get the specified number of digits
|
||||
code = code % uint32(pow10(digits))
|
||||
|
||||
// Format code with leading zeros
|
||||
return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil
|
||||
}
|
||||
|
||||
func pow10(n int) uint32 {
|
||||
result := uint32(1)
|
||||
for i := 0; i < n; i++ {
|
||||
result *= 10
|
||||
}
|
||||
return result
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue