otpm/database/migration.go
“xHuPo” bcd986e3f7 beta
2025-05-23 18:57:11 +08:00

160 lines
3.7 KiB
Go

package database
import (
"context"
"fmt"
"log"
"time"
_ "embed"
"github.com/jmoiron/sqlx"
)
var (
//go:embed init/users.sql
userTable string
//go:embed init/otp.sql
otpTable string
)
// Migration describes one schema change: a human-readable label, the SQL
// executed to perform it, and the version number under which it is recorded
// once applied.
type Migration struct {
	Name, SQL string // label used in logs / statement(s) to execute
	Version   int    // value stored in the migrations table after success
}
// Migrations lists every schema migration, in the order they are applied.
// Version is the key used to detect whether a migration has already run, so
// each entry should use a distinct value.
var Migrations = []Migration{
	{Version: 1, Name: "Create users table", SQL: userTable},
	{Version: 2, Name: "Create OTP table", SQL: otpTable},
}
// MigrationRecord mirrors one row of the migrations bookkeeping table; the db
// struct tags name the column each field is scanned from.
type MigrationRecord struct {
	ID        int       `db:"id"`         // row identifier
	Version   int       `db:"version"`    // Migration.Version that was applied
	Name      string    `db:"name"`       // Migration.Name recorded alongside it
	AppliedAt time.Time `db:"applied_at"` // timestamp stored with the row
}
// ensureMigrationsTable creates the migrations bookkeeping table if it does
// not already exist. It is ensureMigrationsTableContext with
// context.Background(), kept for existing callers.
func ensureMigrationsTable(db *sqlx.DB) error {
	return ensureMigrationsTableContext(context.Background(), db)
}

// ensureMigrationsTableContext creates the migrations bookkeeping table if it
// does not already exist. ctx is passed to the driver with the statement.
//
// NOTE(review): INTEGER PRIMARY KEY AUTOINCREMENT is SQLite syntax, so this
// assumes a SQLite database — confirm if other drivers must be supported.
func ensureMigrationsTableContext(ctx context.Context, db *sqlx.DB) error {
	const query = `
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version INTEGER NOT NULL,
name TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
	if _, err := db.ExecContext(ctx, query); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}
	return nil
}

// getAppliedMigrations returns the recorded migrations keyed by version. It is
// getAppliedMigrationsContext with context.Background(), kept for existing
// callers.
func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
	return getAppliedMigrationsContext(context.Background(), db)
}

// getAppliedMigrationsContext loads every row of the migrations table and
// indexes it by version. If a version was recorded more than once, the row
// scanned last wins.
func getAppliedMigrationsContext(ctx context.Context, db *sqlx.DB) (map[int]MigrationRecord, error) {
	// Columns are listed explicitly: with SELECT *, sqlx rejects the scan as
	// soon as the table gains a column MigrationRecord does not map.
	const query = "SELECT id, version, name, applied_at FROM migrations ORDER BY version"

	var records []MigrationRecord
	if err := db.SelectContext(ctx, &records, query); err != nil {
		return nil, fmt.Errorf("failed to get applied migrations: %w", err)
	}

	applied := make(map[int]MigrationRecord, len(records))
	for _, r := range records {
		applied[r.Version] = r
	}
	return applied, nil
}

// Migrate applies, in slice order, every entry of Migrations whose Version is
// not yet recorded in the migrations table. Each migration runs in its own
// transaction together with the INSERT that records it; on the first failure
// that transaction is rolled back and the error is returned, while migrations
// committed earlier stay applied. When skipMigration is true it only logs and
// returns nil.
func Migrate(db *sqlx.DB, skipMigration bool) error {
	return runMigrations(context.Background(), db, skipMigration)
}

// MigrateWithContext is Migrate with cancellation. ctx is checked before each
// migration and passed to every database call, and the sql package rolls back
// an open transaction when its context is canceled. Migrations committed
// before cancellation stay applied.
//
// It returns only after all database work it started has finished, so the
// caller may safely close db afterwards. If ctx ends while an error occurs,
// the returned error is "migration canceled: ..." and wraps ctx.Err(), so
// errors.Is(err, context.Canceled) (or DeadlineExceeded) holds.
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
	err := runMigrations(ctx, db, skipMigration)
	ctxErr := ctx.Err()
	switch {
	case err == nil || ctxErr == nil:
		return err
	case err == ctxErr:
		// runMigrations stopped on its own ctx check; nothing else to report.
		return fmt.Errorf("migration canceled: %w", ctxErr)
	default:
		// Keep the underlying database error visible for diagnosis.
		return fmt.Errorf("migration canceled: %w (%v)", ctxErr, err)
	}
}

// runMigrations is the shared implementation of Migrate and MigrateWithContext.
func runMigrations(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
	if skipMigration {
		log.Println("Skipping database migration as configured")
		return nil
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	if err := ensureMigrationsTableContext(ctx, db); err != nil {
		return err
	}
	applied, err := getAppliedMigrationsContext(ctx, db)
	if err != nil {
		return err
	}

	for _, m := range Migrations {
		if _, ok := applied[m.Version]; ok {
			log.Printf("Migration %d (%s) already applied", m.Version, m.Name)
			continue
		}
		// Do not start another migration once the caller has given up.
		if err := ctx.Err(); err != nil {
			return err
		}
		log.Printf("Applying migration %d: %s", m.Version, m.Name)
		if err := applyMigration(ctx, db, m); err != nil {
			return err
		}
		log.Printf("Successfully applied migration %d: %s", m.Version, m.Name)
	}
	return nil
}

// applyMigration executes m.SQL and inserts m's bookkeeping row inside one
// transaction, so both are committed together or both rolled back (for the
// schema change itself this assumes transactional DDL, as in SQLite).
func applyMigration(ctx context.Context, db *sqlx.DB, m Migration) error {
	tx, err := db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction for migration %d (%s): %w", m.Version, m.Name, err)
	}
	// Rollback after a successful Commit is a no-op that returns
	// sql.ErrTxDone, so deferring it unconditionally covers every early return.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, m.SQL); err != nil {
		return fmt.Errorf("failed to apply migration %d (%s): %w", m.Version, m.Name, err)
	}
	if _, err := tx.ExecContext(ctx,
		"INSERT INTO migrations (version, name) VALUES (?, ?)",
		m.Version, m.Name,
	); err != nil {
		return fmt.Errorf("failed to record migration %d (%s): %w", m.Version, m.Name, err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit migration %d (%s): %w", m.Version, m.Name, err)
	}
	return nil
}