beta
This commit is contained in:
parent
a45ddf13d5
commit
bcd986e3f7
46 changed files with 6166 additions and 454 deletions
|
@ -1,64 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/spf13/viper"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed init/users.sql
|
||||
userTable string
|
||||
//go:embed init/otp.sql
|
||||
otpTable string
|
||||
)
|
||||
|
||||
func InitDB() (*sqlx.DB, error) {
|
||||
driver := viper.GetString("database.driver")
|
||||
dsn := viper.GetString("database.dsn")
|
||||
|
||||
db, err := sqlx.Open(driver, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Connected to database!")
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func MigrateDB(db *sqlx.DB) error {
|
||||
// 检查是否需要执行迁移
|
||||
skipMigration := viper.GetBool("database.skip_migration")
|
||||
if skipMigration {
|
||||
log.Println("Skipping database migration as configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行用户表迁移
|
||||
if _, err := db.Exec(userTable); err != nil {
|
||||
log.Printf("Warning: failed to create user migration: %v", err)
|
||||
// 继续执行,不返回错误
|
||||
} else {
|
||||
log.Println("User table migration completed successfully")
|
||||
}
|
||||
|
||||
// 执行OTP表迁移
|
||||
if _, err := db.Exec(otpTable); err != nil {
|
||||
log.Printf("Warning: failed to create otp migration: %v", err)
|
||||
// 继续执行,不返回错误
|
||||
} else {
|
||||
log.Println("OTP table migration completed successfully")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
196
database/db.go
Normal file
196
database/db.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"otpm/config"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// DB wraps sqlx.DB to provide additional functionality
|
||||
type DB struct {
|
||||
*sqlx.DB
|
||||
}
|
||||
|
||||
// New creates a new database connection
|
||||
func New(cfg *config.DatabaseConfig) (*DB, error) {
|
||||
db, err := sqlx.Open(cfg.Driver, cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool with optimized settings
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
|
||||
db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn
|
||||
db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes
|
||||
|
||||
// Verify connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Successfully connected to database")
|
||||
return &DB{db}, nil
|
||||
}
|
||||
|
||||
// WithTx executes a function within a transaction with retry logic
|
||||
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
|
||||
const maxRetries = 3
|
||||
var lastErr error
|
||||
|
||||
// Default transaction options
|
||||
opts := &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
}
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
start := time.Now()
|
||||
|
||||
tx, err := db.BeginTxx(ctx, opts)
|
||||
if err != nil {
|
||||
if isRetryableError(err) && attempt < maxRetries {
|
||||
lastErr = err
|
||||
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
|
||||
}
|
||||
|
||||
if isRetryableError(err) && attempt < maxRetries {
|
||||
lastErr = err
|
||||
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
|
||||
}
|
||||
|
||||
// Log long-running transactions
|
||||
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
|
||||
log.Printf("Transaction completed in %v", elapsed)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
if isRetryableError(err) && attempt < maxRetries {
|
||||
lastErr = err
|
||||
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// isRetryableError checks if an error is likely to succeed on retry
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := strings.ToLower(err.Error())
|
||||
return strings.Contains(errStr, "deadlock") ||
|
||||
strings.Contains(errStr, "timeout") ||
|
||||
strings.Contains(errStr, "try again") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "busy") ||
|
||||
strings.Contains(errStr, "locked")
|
||||
}
|
||||
|
||||
// ExecContext executes a query with adaptive timeout
|
||||
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
|
||||
// Set timeout based on query complexity
|
||||
timeout := 5 * time.Second
|
||||
if strings.Contains(strings.ToLower(query), "insert") ||
|
||||
strings.Contains(strings.ToLower(query), "update") ||
|
||||
strings.Contains(strings.ToLower(query), "delete") {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err := db.DB.ExecContext(ctx, query, args...)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Log slow queries
|
||||
if elapsed > timeout/2 {
|
||||
log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryRowContext executes a query that returns a single row with timeout
|
||||
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return db.DB.QueryRowxContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// QueryContext executes a query that returns multiple rows with adaptive timeout
|
||||
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
|
||||
// Set timeout based on query complexity
|
||||
timeout := 5 * time.Second
|
||||
if strings.Contains(strings.ToLower(query), "join") ||
|
||||
strings.Contains(strings.ToLower(query), "group by") ||
|
||||
strings.Contains(strings.ToLower(query), "order by") {
|
||||
timeout = 15 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
rows, err := db.DB.QueryxContext(ctx, query, args...)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Log slow queries
|
||||
if elapsed > timeout/2 {
|
||||
log.Printf("Slow query detected: %s (took %v)", query, elapsed)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (db *DB) Close() error {
|
||||
if err := db.DB.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close database connection: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
160
database/migration.go
Normal file
160
database/migration.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed init/users.sql
|
||||
userTable string
|
||||
//go:embed init/otp.sql
|
||||
otpTable string
|
||||
)
|
||||
|
||||
// Migration represents a database migration
|
||||
type Migration struct {
|
||||
Name string
|
||||
SQL string
|
||||
Version int
|
||||
}
|
||||
|
||||
// Migrations is a list of all migrations
|
||||
var Migrations = []Migration{
|
||||
{
|
||||
Name: "Create users table",
|
||||
SQL: userTable,
|
||||
Version: 1,
|
||||
},
|
||||
{
|
||||
Name: "Create OTP table",
|
||||
SQL: otpTable,
|
||||
Version: 2,
|
||||
},
|
||||
}
|
||||
|
||||
// MigrationRecord represents a record in the migrations table
|
||||
type MigrationRecord struct {
|
||||
ID int `db:"id"`
|
||||
Version int `db:"version"`
|
||||
Name string `db:"name"`
|
||||
AppliedAt time.Time `db:"applied_at"`
|
||||
}
|
||||
|
||||
// ensureMigrationsTable ensures that the migrations table exists
|
||||
func ensureMigrationsTable(db *sqlx.DB) error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
version INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
`
|
||||
|
||||
_, err := db.Exec(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrations table: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAppliedMigrations gets all applied migrations
|
||||
func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
|
||||
var records []MigrationRecord
|
||||
if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
|
||||
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
|
||||
}
|
||||
|
||||
result := make(map[int]MigrationRecord)
|
||||
for _, record := range records {
|
||||
result[record.Version] = record
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Migrate runs all pending migrations
|
||||
func Migrate(db *sqlx.DB, skipMigration bool) error {
|
||||
if skipMigration {
|
||||
log.Println("Skipping database migration as configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure migrations table exists
|
||||
if err := ensureMigrationsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get applied migrations
|
||||
appliedMigrations, err := getAppliedMigrations(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply pending migrations
|
||||
for _, migration := range Migrations {
|
||||
if _, ok := appliedMigrations[migration.Version]; ok {
|
||||
log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
|
||||
|
||||
// Start a transaction for this migration
|
||||
tx, err := db.Beginx()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
// Execute migration
|
||||
if _, err := tx.Exec(migration.SQL); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
// Record migration
|
||||
if _, err := tx.Exec(
|
||||
"INSERT INTO migrations (version, name) VALUES (?, ?)",
|
||||
migration.Version, migration.Name,
|
||||
); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateWithContext runs all pending migrations with context
|
||||
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
|
||||
// Create a channel to signal completion
|
||||
done := make(chan error, 1)
|
||||
|
||||
// Run migration in a goroutine
|
||||
go func() {
|
||||
done <- Migrate(db, skipMigration)
|
||||
}()
|
||||
|
||||
// Wait for migration to complete or context to be canceled
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("migration canceled: %w", ctx.Err())
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue