// otpm/services/otp.go
// “xHuPo” bcd986e3f7 beta
// 2025-05-23 18:57:11 +08:00
//
// 358 lines
// 9.1 KiB
// Go

package services
import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base32"
"encoding/binary"
"fmt"
"hash"
"log"
"strings"
"time"
"otpm/models"
"github.com/google/uuid"
)
// OTPService implements the OTP use cases: creating, reading, updating and
// deleting OTP entries, and generating and verifying TOTP codes for them.
// All persistence goes through the wrapped models.OTPRepository.
type OTPService struct {
	otpRepo *models.OTPRepository // storage backend for OTP entries
}

// NewOTPService returns an OTPService that reads and writes OTP entries
// through otpRepo.
func NewOTPService(otpRepo *models.OTPRepository) *OTPService {
	svc := new(OTPService)
	svc.otpRepo = otpRepo
	return svc
}
// CreateOTP validates input, fills in defaults for optional fields and stores
// a new OTP entry owned by userID, returning the stored entry.
//
// The secret is normalised with cleanSecret *before* validation, so the value
// that is validated is exactly the value that is stored and later decoded.
// Defaults: Algorithm "SHA1", Digits 6, Period 30. Validation and storage
// failures are logged and returned; the secret itself is never logged.
func (s *OTPService) CreateOTP(ctx context.Context, userID string, input models.OTPParams) (*models.OTP, error) {
	start := time.Now()

	// Normalise first: validating the raw input would reject secrets that
	// cleanSecret turns into valid ones (e.g. lower-case base32 or
	// surrounding whitespace).
	secret := cleanSecret(input.Secret)
	input.Secret = secret // input is a by-value copy; the caller's struct is untouched

	if err := s.validateOTPInput(input); err != nil {
		log.Printf("OTP validation failed for user %s: %v", userID, err)
		return nil, err
	}

	// Apply defaults for optional fields (validation allows zero values).
	algorithm := strings.ToUpper(input.Algorithm)
	if algorithm == "" {
		algorithm = "SHA1"
	}
	digits := input.Digits
	if digits == 0 {
		digits = 6
	}
	period := input.Period
	if period == 0 {
		period = 30
	}

	otp := &models.OTP{
		ID:        uuid.New().String(),
		UserID:    userID,
		Name:      input.Name,
		Issuer:    input.Issuer,
		Secret:    secret,
		Algorithm: algorithm,
		Digits:    digits,
		Period:    period,
	}

	if err := s.otpRepo.Create(ctx, otp); err != nil {
		log.Printf("Failed to create OTP for user %s: %v", userID, err)
		return nil, fmt.Errorf("failed to create OTP: %w", err)
	}

	// Log successful creation (without exposing secret)
	log.Printf("Created OTP %s for user %s in %v (name=%s, issuer=%s, algo=%s, digits=%d, period=%d)",
		otp.ID, userID, time.Since(start), otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period)

	return otp, nil
}
// GetOTP returns the OTP entry with the given id, looked up together with
// userID so that only entries owned by that user are returned. Repository
// errors are wrapped with the prefix "failed to get OTP".
func (s *OTPService) GetOTP(ctx context.Context, id, userID string) (*models.OTP, error) {
	record, err := s.otpRepo.FindByID(ctx, id, userID)
	if err == nil {
		return record, nil
	}
	return nil, fmt.Errorf("failed to get OTP: %w", err)
}
// ListOTPs returns every OTP entry owned by userID. Repository errors are
// wrapped with the prefix "failed to list OTPs".
func (s *OTPService) ListOTPs(ctx context.Context, userID string) ([]*models.OTP, error) {
	entries, err := s.otpRepo.FindAllByUserID(ctx, userID)
	if err == nil {
		return entries, nil
	}
	return nil, fmt.Errorf("failed to list OTPs: %w", err)
}
// UpdateOTP applies a partial update to the OTP identified by id and owned by
// userID, then persists and returns it.
//
// Only non-empty Name, Issuer and Algorithm values and positive Digits and
// Period values from input are applied (Algorithm is upper-cased); every
// other field keeps its stored value. input.Secret is not consulted, so the
// stored secret is left unchanged. The merged entry is re-validated with
// validateOTPInput before it is written.
func (s *OTPService) UpdateOTP(ctx context.Context, id, userID string, input models.OTPParams) (*models.OTP, error) {
	current, err := s.otpRepo.FindByID(ctx, id, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get OTP: %w", err)
	}

	// Overlay the supplied fields; zero values mean "leave unchanged".
	if name := input.Name; name != "" {
		current.Name = name
	}
	if issuer := input.Issuer; issuer != "" {
		current.Issuer = issuer
	}
	if algo := input.Algorithm; algo != "" {
		current.Algorithm = strings.ToUpper(algo)
	}
	if digits := input.Digits; digits > 0 {
		current.Digits = digits
	}
	if period := input.Period; period > 0 {
		current.Period = period
	}

	// Validate the entry exactly as it is about to be stored.
	merged := models.OTPParams{
		Name:      current.Name,
		Issuer:    current.Issuer,
		Secret:    current.Secret,
		Algorithm: current.Algorithm,
		Digits:    current.Digits,
		Period:    current.Period,
	}
	if verr := s.validateOTPInput(merged); verr != nil {
		return nil, verr
	}

	if uerr := s.otpRepo.Update(ctx, current); uerr != nil {
		return nil, fmt.Errorf("failed to update OTP: %w", uerr)
	}
	return current, nil
}
// DeleteOTP removes the OTP entry with the given id owned by userID.
// Repository errors are wrapped with the prefix "failed to delete OTP".
func (s *OTPService) DeleteOTP(ctx context.Context, id, userID string) error {
	err := s.otpRepo.Delete(ctx, id, userID)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to delete OTP: %w", err)
}
// GenerateCode returns the current TOTP code for the OTP identified by id and
// owned by userID, together with the number of seconds left until the end of
// the current period (when the code rolls over). Timing is logged; the code
// itself never is.
//
// An error is returned when the OTP cannot be loaded, when its stored period
// is not positive, or when code generation fails.
func (s *OTPService) GenerateCode(ctx context.Context, id, userID string) (string, int, error) {
	start := time.Now()

	otp, err := s.otpRepo.FindByID(ctx, id, userID)
	if err != nil {
		log.Printf("Failed to find OTP %s for user %s: %v", id, userID, err)
		return "", 0, fmt.Errorf("failed to get OTP: %w", err)
	}

	// Guard against corrupt records: a zero period would panic with an
	// integer divide-by-zero below, and a negative one yields a meaningless
	// time step and remaining time.
	if otp.Period <= 0 {
		log.Printf("Invalid period %d for OTP %s (user %s)", otp.Period, id, userID)
		return "", 0, fmt.Errorf("invalid OTP period: %d", otp.Period)
	}
	period := int64(otp.Period)

	// Time step = number of whole periods since the Unix epoch.
	now := time.Now().Unix()
	timeStep := now / period

	code, err := generateTOTP(otp.Secret, timeStep, otp.Algorithm, otp.Digits)
	if err != nil {
		log.Printf("Failed to generate code for OTP %s (user %s): %v", id, userID, err)
		return "", 0, fmt.Errorf("failed to generate code: %w", err)
	}

	// Seconds left in the current period (1..period).
	remainingSeconds := otp.Period - int(now%period)

	// Log successful generation (without actual code)
	log.Printf("Generated code for OTP %s (user %s) in %v (expires in %ds)",
		id, userID, time.Since(start), remainingSeconds)

	return code, remainingSeconds, nil
}
// VerifyCode reports whether code is a valid TOTP code for the OTP identified
// by id and owned by userID.
//
// To tolerate clock drift, codes for the previous, current and next time step
// are all accepted. Candidates are compared in constant time (hmac.Equal), so
// response timing does not reveal how much of a guessed code was correct.
//
// It returns (false, error) when code is empty, when the OTP cannot be loaded,
// or when its stored period is not positive. A well-formed but wrong code
// yields (false, nil).
//
// NOTE(review): nothing here prevents replay of an accepted code within its
// window or limits attempts — confirm that is handled by the caller.
func (s *OTPService) VerifyCode(ctx context.Context, id, userID, code string) (bool, error) {
	start := time.Now()

	// Basic input validation
	if len(code) == 0 {
		log.Printf("Empty code verification attempt for OTP %s (user %s)", id, userID)
		return false, fmt.Errorf("code is required")
	}

	otp, err := s.otpRepo.FindByID(ctx, id, userID)
	if err != nil {
		log.Printf("Failed to find OTP %s for user %s during verification: %v",
			id, userID, err)
		return false, fmt.Errorf("failed to get OTP: %w", err)
	}

	// A zero period would panic with an integer divide-by-zero below.
	if otp.Period <= 0 {
		log.Printf("Invalid period %d for OTP %s (user %s) during verification", otp.Period, id, userID)
		return false, fmt.Errorf("invalid OTP period: %d", otp.Period)
	}
	period := int64(otp.Period)

	// Current time step and its immediate neighbours.
	current := time.Now().Unix() / period
	given := []byte(code)
	for _, ts := range []int64{current - 1, current, current + 1} {
		expectedCode, err := generateTOTP(otp.Secret, ts, otp.Algorithm, otp.Digits)
		if err != nil {
			log.Printf("Code generation failed for time step %d: %v", ts, err)
			continue
		}
		// Constant-time comparison; a plain == can short-circuit on the
		// first differing byte and leak information through timing.
		if hmac.Equal([]byte(expectedCode), given) {
			log.Printf("Code verified successfully for OTP %s (user %s) in %v",
				id, userID, time.Since(start))
			return true, nil
		}
	}

	// Log failed verification attempt
	log.Printf("Invalid code provided for OTP %s (user %s) in %v",
		id, userID, time.Since(start))
	return false, nil
}
// validateOTPInput checks that input describes a usable OTP entry and returns
// a descriptive error for the first problem found.
//
// Rules enforced:
//   - Name is required and len(Name) must not exceed 100.
//   - Secret is required. It must be base32 in the RFC 4648 standard
//     (upper-case) alphabet, with trailing '=' padding optional, and must
//     decode to at least 10 bytes.
//   - Algorithm, Digits and Period are optional (zero value = caller's
//     default). When set they must be SHA1/SHA256/SHA512 (case-insensitive),
//     6-8, and 30-60 seconds respectively.
func (s *OTPService) validateOTPInput(input models.OTPParams) error {
	if input.Name == "" {
		return fmt.Errorf("name is required")
	}
	// NOTE(review): len counts bytes, not characters as the message says, so
	// multi-byte names hit the limit early — confirm the intended unit
	// against the storage schema before changing it.
	if len(input.Name) > 100 {
		return fmt.Errorf("name is too long (maximum 100 characters)")
	}

	if input.Secret == "" {
		return fmt.Errorf("secret is required")
	}
	// Padding is optional: strip any '=' and decode with an encoding that
	// does not expect it. base32.StdEncoding alone rejects unpadded input
	// whose length is not a multiple of 8 (e.g. any 16-byte secret), and
	// secrets are stored with their padding removed.
	secretBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.TrimRight(input.Secret, "="))
	if err != nil {
		return fmt.Errorf("invalid secret format: must be a valid base32 string")
	}
	if len(secretBytes) < 10 {
		return fmt.Errorf("secret is too short (minimum 10 bytes after decoding)")
	}

	if input.Algorithm != "" && !isValidAlgorithm(input.Algorithm) {
		return fmt.Errorf("invalid algorithm: %s (supported: SHA1, SHA256, SHA512)", input.Algorithm)
	}
	if input.Digits != 0 && (input.Digits < 6 || input.Digits > 8) {
		return fmt.Errorf("digits must be between 6 and 8 (got %d)", input.Digits)
	}
	if input.Period != 0 && (input.Period < 30 || input.Period > 60) {
		return fmt.Errorf("period must be between 30 and 60 seconds (got %d)", input.Period)
	}

	return nil
}
// Helper functions

// cleanSecret normalises a user-supplied base32 secret. It removes all
// whitespace, including the spaces commonly used to show secrets in groups of
// four (e.g. "JBSW Y3DP EHPK 3PXP"). It then converts the result to upper
// case and strips any trailing '=' padding.
func cleanSecret(secret string) string {
	// strings.Fields splits on any Unicode whitespace, so joining the fields
	// drops leading, trailing and interior spaces, tabs and newlines alike.
	secret = strings.ToUpper(strings.Join(strings.Fields(secret), ""))
	return strings.TrimRight(secret, "=")
}
// isValidBase32 reports whether s is valid base32 in the RFC 4648 standard
// (upper-case) alphabet. Trailing '=' padding is optional, matching the
// padding-free form in which secrets are stored.
func isValidBase32(s string) bool {
	// base32.StdEncoding insists on padding when the length is not a
	// multiple of 8, so decode the stripped string with padding disabled.
	_, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.TrimRight(s, "="))
	return err == nil
}
// isValidAlgorithm reports whether algorithm names one of the supported HMAC
// hash functions (SHA1, SHA256 or SHA512), ignoring letter case.
func isValidAlgorithm(algorithm string) bool {
	normalized := strings.ToUpper(algorithm)
	for _, supported := range [...]string{"SHA1", "SHA256", "SHA512"} {
		if normalized == supported {
			return true
		}
	}
	return false
}
// getHasher returns an HMAC keyed with key, using the hash function named by
// algorithm. Names are matched case-insensitively against SHA1, SHA256 and
// SHA512; any other name yields an error.
func getHasher(algorithm string, key []byte) (hash.Hash, error) {
	switch strings.ToUpper(algorithm) {
	case "SHA1":
		return hmac.New(sha1.New, key), nil
	case "SHA256":
		return hmac.New(sha256.New, key), nil
	case "SHA512":
		return hmac.New(sha512.New, key), nil
	default:
		return nil, fmt.Errorf("unsupported algorithm: %s", algorithm)
	}
}

// generateTOTP computes the one-time code for the given time step using the
// HOTP construction of RFC 4226 (as used by TOTP, RFC 6238): HMAC over the
// big-endian 8-byte counter, then dynamic truncation to a 31-bit value, then
// reduction to digits decimal digits with leading zeros.
//
// secret is base32 (RFC 4648 standard alphabet); trailing '=' padding is
// optional. digits must be in [1, 9]: the modulus is a uint32 power of ten,
// which would overflow for 10 or more digits.
func generateTOTP(secret string, timeStep int64, algorithm string, digits int) (string, error) {
	if digits < 1 || digits > 9 {
		return "", fmt.Errorf("invalid digits: %d (must be between 1 and 9)", digits)
	}

	// Strip any padding and decode without requiring it. base32.StdEncoding
	// alone rejects unpadded input whose length is not a multiple of 8.
	secretBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.TrimRight(secret, "="))
	if err != nil {
		return "", fmt.Errorf("invalid secret: %w", err)
	}

	mac, err := getHasher(algorithm, secretBytes)
	if err != nil {
		return "", err
	}

	var counter [8]byte
	binary.BigEndian.PutUint64(counter[:], uint64(timeStep))
	mac.Write(counter[:]) // hash.Hash.Write never returns an error
	sum := mac.Sum(nil)

	// Dynamic truncation (RFC 4226 §5.3). The low nibble of the last byte
	// picks a 4-byte window. offset <= 15 and every supported digest is at
	// least 20 bytes, so sum[offset:offset+4] is always in range.
	offset := sum[len(sum)-1] & 0x0f
	code := binary.BigEndian.Uint32(sum[offset:offset+4]) & 0x7fffffff
	code %= pow10(digits)

	// Zero-pad to exactly digits characters.
	return fmt.Sprintf("%0*d", digits, code), nil
}

// pow10 returns 10 raised to the power n as a uint32. It is only exact for
// n <= 9; larger n overflows uint32.
func pow10(n int) uint32 {
	result := uint32(1)
	for i := 0; i < n; i++ {
		result *= 10
	}
	return result
}