64 lines
1.4 KiB
Go
64 lines
1.4 KiB
Go
package database
|
|
|
|
import (
|
|
_ "embed"
|
|
"fmt"
|
|
"log"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
"github.com/jmoiron/sqlx"
|
|
_ "github.com/lib/pq"
|
|
"github.com/spf13/viper"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
var (
|
|
//go:embed init/users.sql
|
|
userTable string
|
|
//go:embed init/otp.sql
|
|
otpTable string
|
|
)
|
|
|
|
func InitDB() (*sqlx.DB, error) {
|
|
driver := viper.GetString("database.driver")
|
|
dsn := viper.GetString("database.dsn")
|
|
|
|
db, err := sqlx.Open(driver, dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
log.Println("Connected to database!")
|
|
return db, nil
|
|
}
|
|
|
|
func MigrateDB(db *sqlx.DB) error {
|
|
// 检查是否需要执行迁移
|
|
skipMigration := viper.GetBool("database.skip_migration")
|
|
if skipMigration {
|
|
log.Println("Skipping database migration as configured")
|
|
return nil
|
|
}
|
|
|
|
// 执行用户表迁移
|
|
if _, err := db.Exec(userTable); err != nil {
|
|
log.Printf("Warning: failed to create user migration: %v", err)
|
|
// 继续执行,不返回错误
|
|
} else {
|
|
log.Println("User table migration completed successfully")
|
|
}
|
|
|
|
// 执行OTP表迁移
|
|
if _, err := db.Exec(otpTable); err != nil {
|
|
log.Printf("Warning: failed to create otp migration: %v", err)
|
|
// 继续执行,不返回错误
|
|
} else {
|
|
log.Println("OTP table migration completed successfully")
|
|
}
|
|
|
|
return nil
|
|
}
|