package database

import (
	"context"
	_ "embed"
	"fmt"
	"log"
	"time"

	"github.com/jmoiron/sqlx"
)

// Schema scripts embedded at build time, one per migration in Migrations.
var (
	//go:embed init/users.sql
	userTable string

	//go:embed init/otp.sql
	otpTable string
)

// Migration represents a single database migration.
//
// Version is the key recorded in the migrations table; a migration whose
// Version is already recorded is skipped. SQL is executed as-is inside a
// transaction.
type Migration struct {
	Name    string
	SQL     string
	Version int
}

// Migrations is the list of all migrations. Pending entries are applied in
// slice order, so keep it ordered by Version.
var Migrations = []Migration{
	{
		Name:    "Create users table",
		SQL:     userTable,
		Version: 1,
	},
	{
		Name:    "Create OTP table",
		SQL:     otpTable,
		Version: 2,
	},
}

// MigrationRecord represents a row in the migrations bookkeeping table.
type MigrationRecord struct {
	ID        int       `db:"id"`
	Version   int       `db:"version"`
	Name      string    `db:"name"`
	AppliedAt time.Time `db:"applied_at"`
}

// ensureMigrationsTable creates the migrations bookkeeping table if it does
// not already exist. The DDL (INTEGER PRIMARY KEY AUTOINCREMENT) is written
// in SQLite syntax.
func ensureMigrationsTable(ctx context.Context, db *sqlx.DB) error {
	const query = `
		CREATE TABLE IF NOT EXISTS migrations (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			version INTEGER NOT NULL,
			name TEXT NOT NULL,
			applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
		);
	`
	if _, err := db.ExecContext(ctx, query); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}
	return nil
}

// getAppliedMigrations returns every recorded migration keyed by Version.
func getAppliedMigrations(ctx context.Context, db *sqlx.DB) (map[int]MigrationRecord, error) {
	// Explicit column list (instead of SELECT *) keeps the scan aligned with
	// MigrationRecord even if columns are added to the table later.
	const query = "SELECT id, version, name, applied_at FROM migrations ORDER BY version"

	var records []MigrationRecord
	if err := db.SelectContext(ctx, &records, query); err != nil {
		return nil, fmt.Errorf("failed to get applied migrations: %w", err)
	}

	applied := make(map[int]MigrationRecord, len(records))
	for _, r := range records {
		applied[r.Version] = r
	}
	return applied, nil
}

// Migrate runs all pending migrations. It is equivalent to calling
// MigrateWithContext with context.Background().
//
// When skipMigration is true it logs and returns nil without touching the
// database.
func Migrate(db *sqlx.DB, skipMigration bool) error {
	return MigrateWithContext(context.Background(), db, skipMigration)
}

// MigrateWithContext runs all pending migrations, honoring ctx.
//
// ctx is passed to every database call, and it is checked before starting and
// between migrations. Each migration runs in its own transaction. If ctx is
// cancelled mid-run, migrations committed before the cancellation stay
// applied, and the in-flight one is rolled back. The background goroutine of
// earlier versions, which kept running after the caller gave up, is gone.
//
// When skipMigration is true it logs and returns nil without touching the
// database.
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
	if skipMigration {
		log.Println("Skipping database migration as configured")
		return nil
	}
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("migration canceled: %w", err)
	}

	if err := ensureMigrationsTable(ctx, db); err != nil {
		return err
	}

	applied, err := getAppliedMigrations(ctx, db)
	if err != nil {
		return err
	}

	for _, m := range Migrations {
		if _, ok := applied[m.Version]; ok {
			log.Printf("Migration %d (%s) already applied", m.Version, m.Name)
			continue
		}
		// Stop cleanly between migrations if the caller gave up; everything
		// applied so far is already committed.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("migration canceled: %w", err)
		}
		if err := applyMigration(ctx, db, m); err != nil {
			return err
		}
	}
	return nil
}

// applyMigration executes one migration and records it in the migrations
// table. Both steps run in a single transaction, so they commit together or
// not at all.
func applyMigration(ctx context.Context, db *sqlx.DB, m Migration) error {
	log.Printf("Applying migration %d: %s", m.Version, m.Name)

	tx, err := db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Guarantees rollback on every error path. After a successful Commit,
	// Rollback just returns sql.ErrTxDone, so the error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, m.SQL); err != nil {
		return fmt.Errorf("failed to apply migration %d (%s): %w", m.Version, m.Name, err)
	}

	if _, err := tx.ExecContext(ctx,
		"INSERT INTO migrations (version, name) VALUES (?, ?)",
		m.Version, m.Name,
	); err != nil {
		return fmt.Errorf("failed to record migration %d (%s): %w", m.Version, m.Name, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit migration %d (%s): %w", m.Version, m.Name, err)
	}

	log.Printf("Successfully applied migration %d: %s", m.Version, m.Name)
	return nil
}