package models

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/jmoiron/sqlx"
)

// ErrOTPNotFound is returned by Update and Delete when no row matches the
// given OTP ID and user ID, i.e. the OTP does not exist or belongs to a
// different user. Callers can detect it with errors.Is.
var ErrOTPNotFound = errors.New("otp not found or not owned by user")

// OTP represents a TOTP configuration as stored in the otps table.
type OTP struct {
	ID        string    `db:"id" json:"id"`
	UserID    string    `db:"user_id" json:"user_id"`
	Name      string    `db:"name" json:"name"`
	Issuer    string    `db:"issuer" json:"issuer"`
	Secret    string    `db:"secret" json:"-"` // Never expose secret in JSON
	Algorithm string    `db:"algorithm" json:"algorithm"`
	Digits    int       `db:"digits" json:"digits"`
	Period    int       `db:"period" json:"period"`
	CreatedAt time.Time `db:"created_at" json:"created_at"`
	UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}

// OTPParams represents common OTP parameters used in creation and update.
type OTPParams struct {
	Name      string
	Issuer    string
	Secret    string
	Algorithm string
	Digits    int
	Period    int
}

// OTPRepository handles OTP data operations against the otps table.
type OTPRepository struct {
	db *sqlx.DB
}

// NewOTPRepository creates a new OTPRepository backed by db.
func NewOTPRepository(db *sqlx.DB) *OTPRepository {
	return &OTPRepository{db: db}
}

// FindByID returns the OTP with the given id that belongs to userID.
//
// If no such row exists, the returned error wraps sql.ErrNoRows, so callers
// can test for it with errors.Is(err, sql.ErrNoRows).
func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
	// Columns are listed explicitly: with SELECT *, sqlx fails the scan as
	// soon as the table gains a column that OTP has no field for.
	const query = `
		SELECT id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at
		FROM otps
		WHERE id = ? AND user_id = ?`

	var otp OTP
	if err := r.db.GetContext(ctx, &otp, query, id, userID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("otp not found: %w", err)
		}
		return nil, fmt.Errorf("failed to find otp: %w", err)
	}
	return &otp, nil
}

// FindAllByUserID returns every OTP owned by userID, newest first
// (ordered by created_at descending). A user with no OTPs yields an empty,
// non-nil slice.
func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
	const query = `
		SELECT id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at
		FROM otps
		WHERE user_id = ?
		ORDER BY created_at DESC`

	// Start from an empty slice rather than nil so an empty result encodes
	// as [] instead of null if it is serialized to JSON.
	otps := []*OTP{}
	if err := r.db.SelectContext(ctx, &otps, query, userID); err != nil {
		return nil, fmt.Errorf("failed to find otps: %w", err)
	}
	return otps, nil
}

// Create inserts otp into the otps table.
//
// CreatedAt and UpdatedAt are both stamped with the current time. They are
// written back to otp only when the insert succeeds, so a failed call leaves
// the caller's struct untouched.
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
	const query = `
		INSERT INTO otps (id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	now := time.Now()
	_, err := r.db.ExecContext(ctx, query,
		otp.ID, otp.UserID, otp.Name, otp.Issuer, otp.Secret,
		otp.Algorithm, otp.Digits, otp.Period, now, now,
	)
	if err != nil {
		return fmt.Errorf("failed to create otp: %w", err)
	}

	otp.CreatedAt = now
	otp.UpdatedAt = now
	return nil
}

// Update overwrites name, issuer, algorithm, digits, period, and updated_at
// of the OTP identified by otp.ID and owned by otp.UserID. The secret and
// created_at columns are not modified.
//
// It returns ErrOTPNotFound when no row matches. otp.UpdatedAt is set to
// the new timestamp only on success.
func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
	const query = `
		UPDATE otps
		SET name = ?, issuer = ?, algorithm = ?, digits = ?, period = ?, updated_at = ?
		WHERE id = ? AND user_id = ?`

	now := time.Now()
	result, err := r.db.ExecContext(ctx, query,
		otp.Name, otp.Issuer, otp.Algorithm, otp.Digits, otp.Period, now,
		otp.ID, otp.UserID,
	)
	if err != nil {
		return fmt.Errorf("failed to update otp: %w", err)
	}

	// NOTE(review): some drivers (e.g. go-sql-driver/mysql without
	// clientFoundRows=true) report only rows whose values actually changed.
	// If updated_at is stored with second precision, an identical update in
	// the same second could report 0 here — confirm against the driver in use.
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get affected rows: %w", err)
	}
	if rows == 0 {
		return ErrOTPNotFound
	}

	otp.UpdatedAt = now
	return nil
}

// Delete removes the OTP with the given id owned by userID.
// It returns ErrOTPNotFound when no row matches.
func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
	const query = `DELETE FROM otps WHERE id = ? AND user_id = ?`

	result, err := r.db.ExecContext(ctx, query, id, userID)
	if err != nil {
		return fmt.Errorf("failed to delete otp: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get affected rows: %w", err)
	}
	if rows == 0 {
		return ErrOTPNotFound
	}
	return nil
}

// CountByUserID returns the number of OTPs owned by userID.
func (r *OTPRepository) CountByUserID(ctx context.Context, userID string) (int, error) {
	const query = `SELECT COUNT(*) FROM otps WHERE user_id = ?`

	var count int
	if err := r.db.GetContext(ctx, &count, query, userID); err != nil {
		return 0, fmt.Errorf("failed to count otps: %w", err)
	}
	return count, nil
}

// Transaction runs fn inside a database transaction started with ctx.
//
// If fn returns an error, the transaction is rolled back and that error is
// returned. If the rollback also fails, both are reported, and fn's error
// stays reachable via errors.Is/As. If fn panics, the transaction is rolled
// back and the panic is re-raised. Otherwise the transaction is committed.
func (r *OTPRepository) Transaction(ctx context.Context, fn func(*sqlx.Tx) error) error {
	tx, err := r.db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}

	defer func() {
		if p := recover(); p != nil {
			// Best-effort rollback: the panic is re-raised regardless, so a
			// rollback error has nowhere useful to go.
			_ = tx.Rollback()
			panic(p)
		}
	}()

	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			// %w (not %v) keeps fn's error inspectable by callers.
			return fmt.Errorf("tx failed: %w, rollback failed: %v", err, rbErr)
		}
		return err
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}