package services import ( "context" "crypto/hmac" "crypto/sha1" "crypto/sha256" "crypto/sha512" "encoding/base32" "encoding/binary" "fmt" "hash" "log" "strings" "time" "otpm/models" "github.com/google/uuid" ) // OTPService handles OTP related operations type OTPService struct { otpRepo *models.OTPRepository } // NewOTPService creates a new OTPService func NewOTPService(otpRepo *models.OTPRepository) *OTPService { return &OTPService{ otpRepo: otpRepo, } } // CreateOTP creates a new OTP with performance monitoring and logging func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) { start := time.Now() // Validate input if err := s.validateOTPInput(input); err != nil { log.Printf("OTP validation failed for user %s: %v", userID, err) return nil, err } // Clean and standardize secret secret := cleanSecret(input.Secret) // Set defaults for optional fields algorithm := strings.ToUpper(input.Algorithm) if algorithm == "" { algorithm = "SHA1" } digits := input.Digits if digits == 0 { digits = 6 } period := input.Period if period == 0 { period = 30 } // Create OTP otp := &models.OTP{ ID: uuid.New().String(), UserID: userID, Name: input.Name, Issuer: input.Issuer, Secret: secret, Algorithm: algorithm, Digits: digits, Period: period, } if err := s.otpRepo.Create(ctx, otp); err != nil { log.Printf("Failed to create OTP for user %s: %v", userID, err) return nil, fmt.Errorf("failed to create OTP: %w", err) } // Log successful creation (without exposing secret) log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)", otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period) return otp, nil } // GetOTP gets an OTP by ID func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) { otp, err := s.otpRepo.FindByID(ctx, id, userID) if err != nil { return nil, fmt.Errorf("failed to get OTP: %w", err) } return otp, nil } // ListOTPs lists all OTPs for a user func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) { otps, err := s.otpRepo.FindAllByUserID(ctx, userID) if err != nil { return nil, fmt.Errorf("failed to list OTPs: %w", err) } return otps, nil } // UpdateOTP updates an OTP func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) { // Get existing OTP otp, err := s.otpRepo.FindByID(ctx, id, userID) if err != nil { return nil, fmt.Errorf("failed to get OTP: %w", err) } // Update fields if input.Name != "" { otp.Name = input.Name } if input.Issuer != "" { otp.Issuer = input.Issuer } if input.Algorithm != "" { otp.Algorithm = strings.ToUpper(input.Algorithm) } if input.Digits > 0 { otp.Digits = input.Digits } if input.Period > 0 { otp.Period = input.Period } // Validate updated OTP if err := s.validateOTPInput(models.OTPParams{ Name: otp.Name, Issuer: otp.Issuer, Secret: otp.Secret, Algorithm: otp.Algorithm, Digits: otp.Digits, Period: otp.Period, }); err != nil { return nil, err } if err := s.otpRepo.Update(ctx, otp); err != nil { return nil, fmt.Errorf("failed to update OTP: %w", err) } return otp, nil } // DeleteOTP deletes an OTP func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error { if err := s.otpRepo.Delete(ctx, id, userID); err != nil { return fmt.Errorf("failed to delete OTP: %w", err) } return nil } // GenerateCode generates a TOTP code with enhanced logging and error handling func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) { start := time.Now() otp, err := s.otpRepo.FindByID(ctx, id, userID) if err != nil { log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err) return "", 0, fmt.Errorf("failed to get OTP: %w", err) } // Get current time step now := time.Now().Unix() timeStep := now / int64(otp.Period) // Generate code code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits) if err != nil { log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err) return "", 0, fmt.Errorf("failed to generate code: %w", err) } // Calculate remaining seconds remainingSeconds := otp.Period - int(now%int64(otp.Period)) // Log successful generation (without actual code) log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)", id, userID, time.Since(start), remainingSeconds) return code, remainingSeconds, nil } // VerifyCode verifies a TOTP code with enhanced security and logging func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) { start := time.Now() // Basic input validation if len(code) == 0 { log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID) return false, fmt.Errorf("code is required") } otp, err := s.otpRepo.FindByID(ctx, id, userID) if err != nil { log.Printf("Failed to find OTP %s for user %s during verification: %v", id, userID, err) return false, fmt.Errorf("failed to get OTP: %w", err) } // Get current and adjacent time steps now := time.Now().Unix() timeSteps := []int64{ (now - int64(otp.Period)) / int64(otp.Period), now / int64(otp.Period), (now + int64(otp.Period)) / int64(otp.Period), } // Check code against all time steps for _, ts := range timeSteps { expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits) if err != nil { log.Printf("Code generation failed for time step %d: %v", ts, err) continue } if expectedCode == code { // Log successful verification log.Printf("Code verified successfully for OTP %s (user %s) in %v", id, userID, time.Since(start)) return true, nil } } // Log failed verification attempt log.Printf("Invalid code provided for OTP %s (user %s) in %v", id, userID, time.Since(start)) return false, nil } // validateOTPInput validates OTP input with detailed error messages func (s *OTPService) validateOTPInput(input models.OTPParams) error { if input.Name == "" { return fmt.Errorf("name is required") } if len(input.Name) > 100 { return fmt.Errorf("name is too long (maximum 100 characters)") } if input.Secret == "" { return fmt.Errorf("secret is required") } if !isValidBase32(input.Secret) { return fmt.Errorf("invalid secret format: must be a valid base32 string") } // Secret length check (after base32 decoding) secretBytes, _ := base32.StdEncoding.DecodeString(strings.TrimRight(input.Secret, "=")) if len(secretBytes) < 10 { return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)") } if input.Algorithm != "" { if !isValidAlgorithm(input.Algorithm) { return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm) } } if input.Digits != 0 { if input.Digits < 6 || input.Digits > 8 { return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits) } } if input.Period != 0 { if input.Period < 30 || input.Period > 60 { return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period) } } return nil } // Helper functions func cleanSecret(secret string) string { // Remove spaces and convert to upper case secret = strings.TrimSpace(strings.ToUpper(secret)) // Remove any padding characters return strings.TrimRight(secret, "=") } func isValidBase32(s string) bool { // Try to decode the secret _, err := base32.StdEncoding.DecodeString(strings.TrimRight(s, "=")) return err == nil } func isValidAlgorithm(algorithm string) bool { switch strings.ToUpper(algorithm) { case "SHA1", "SHA256", "SHA512": return true default: return false } } func getHasher(algorithm string, key []byte) (hash.Hash, error) { switch strings.ToUpper(algorithm) { case "SHA1": return hmac.New(sha1.New, key), nil case "SHA256": return hmac.New(sha256.New, key), nil case "SHA512": return hmac.New(sha512.New, key), nil default: return nil, fmt.Errorf("unsupported algorithm: %s", algorithm) } } func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) { // Decode secret secretBytes, err := base32.StdEncoding.DecodeString(strings.TrimRight(secret, "=")) if err != nil { return "", fmt.Errorf("invalid secret: %w", err) } // Get initialized HMAC hasher with secret hasher, err := getHasher(algorithm, secretBytes) if err != nil { return "", err } // Convert time step to bytes timeBytes := make([]byte, 8) binary.BigEndian.PutUint64(timeBytes, uint64(timeStep)) // Calculate HMAC hasher.Write(timeBytes) hash := hasher.Sum(nil) // Get offset offset := hash[len(hash)-1] & 0xf // Generate 4-byte code code := binary.BigEndian.Uint32(hash[offset : offset+4]) code = code & 0x7fffffff // Get the specified number of digits code = code % uint32(pow10(digits)) // Format code with leading zeros return fmt.Sprintf(fmt.Sprintf("%%0%dd", digits), code), nil } func pow10(n int) uint32 { result := uint32(1) for i := 0; i < n; i++ { result *= 10 } return result }