160 lines
3.7 KiB
Go
160 lines
3.7 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"time"
|
|
|
|
_ "embed"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
var (
|
|
//go:embed init/users.sql
|
|
userTable string
|
|
//go:embed init/otp.sql
|
|
otpTable string
|
|
)
|
|
|
|
// Migration describes one versioned schema change.
//
// Name is a human-readable label that appears in log output and is stored in
// the migrations table. SQL is executed as-is inside the migration's
// transaction. Version is the number recorded in the migrations table once the
// change has been committed; it is what decides whether the migration is
// still pending.
type Migration struct {
	Name, SQL string
	Version   int
}
|
|
|
|
// Migrations is a list of all migrations
|
|
var Migrations = []Migration{
|
|
{
|
|
Name: "Create users table",
|
|
SQL: userTable,
|
|
Version: 1,
|
|
},
|
|
{
|
|
Name: "Create OTP table",
|
|
SQL: otpTable,
|
|
Version: 2,
|
|
},
|
|
}
|
|
|
|
// MigrationRecord is one row of the migrations table. Columns are mapped onto
// fields by sqlx through the db struct tags.
type MigrationRecord struct {
	ID        int       `db:"id"`         // auto-incremented row id
	Version   int       `db:"version"`    // Migration.Version that was applied
	Name      string    `db:"name"`       // Migration.Name recorded alongside it
	AppliedAt time.Time `db:"applied_at"` // filled by the column default CURRENT_TIMESTAMP
}
|
|
|
|
// ensureMigrationsTable ensures that the migrations table exists
|
|
func ensureMigrationsTable(db *sqlx.DB) error {
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS migrations (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
version INTEGER NOT NULL,
|
|
name TEXT NOT NULL,
|
|
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
);
|
|
`
|
|
|
|
_, err := db.Exec(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create migrations table: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getAppliedMigrations gets all applied migrations
|
|
func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
|
|
var records []MigrationRecord
|
|
if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
|
|
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
|
|
}
|
|
|
|
result := make(map[int]MigrationRecord)
|
|
for _, record := range records {
|
|
result[record.Version] = record
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Migrate runs all pending migrations
|
|
func Migrate(db *sqlx.DB, skipMigration bool) error {
|
|
if skipMigration {
|
|
log.Println("Skipping database migration as configured")
|
|
return nil
|
|
}
|
|
|
|
// Ensure migrations table exists
|
|
if err := ensureMigrationsTable(db); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Get applied migrations
|
|
appliedMigrations, err := getAppliedMigrations(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Apply pending migrations
|
|
for _, migration := range Migrations {
|
|
if _, ok := appliedMigrations[migration.Version]; ok {
|
|
log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
|
|
continue
|
|
}
|
|
|
|
log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
|
|
|
|
// Start a transaction for this migration
|
|
tx, err := db.Beginx()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
// Execute migration
|
|
if _, err := tx.Exec(migration.SQL); err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
|
|
}
|
|
|
|
// Record migration
|
|
if _, err := tx.Exec(
|
|
"INSERT INTO migrations (version, name) VALUES (?, ?)",
|
|
migration.Version, migration.Name,
|
|
); err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
|
|
}
|
|
|
|
// Commit transaction
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
|
|
}
|
|
|
|
log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MigrateWithContext runs all pending migrations with context
|
|
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
|
|
// Create a channel to signal completion
|
|
done := make(chan error, 1)
|
|
|
|
// Run migration in a goroutine
|
|
go func() {
|
|
done <- Migrate(db, skipMigration)
|
|
}()
|
|
|
|
// Wait for migration to complete or context to be canceled
|
|
select {
|
|
case err := <-done:
|
|
return err
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("migration canceled: %w", ctx.Err())
|
|
}
|
|
}
|