This commit is contained in:
“xHuPo” 2025-05-23 18:57:11 +08:00
parent a45ddf13d5
commit bcd986e3f7
46 changed files with 6166 additions and 454 deletions

View file

@ -1,64 +0,0 @@
package database
import (
_ "embed"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/spf13/viper"
_ "modernc.org/sqlite"
)
var (
//go:embed init/users.sql
userTable string
//go:embed init/otp.sql
otpTable string
)
func InitDB() (*sqlx.DB, error) {
driver := viper.GetString("database.driver")
dsn := viper.GetString("database.dsn")
db, err := sqlx.Open(driver, dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
log.Println("Connected to database!")
return db, nil
}
func MigrateDB(db *sqlx.DB) error {
// 检查是否需要执行迁移
skipMigration := viper.GetBool("database.skip_migration")
if skipMigration {
log.Println("Skipping database migration as configured")
return nil
}
// 执行用户表迁移
if _, err := db.Exec(userTable); err != nil {
log.Printf("Warning: failed to create user migration: %v", err)
// 继续执行,不返回错误
} else {
log.Println("User table migration completed successfully")
}
// 执行OTP表迁移
if _, err := db.Exec(otpTable); err != nil {
log.Printf("Warning: failed to create otp migration: %v", err)
// 继续执行,不返回错误
} else {
log.Println("OTP table migration completed successfully")
}
return nil
}

196
database/db.go Normal file
View file

@ -0,0 +1,196 @@
package database
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"time"
"otpm/config"
"github.com/jmoiron/sqlx"
)
// DB wraps sqlx.DB to provide additional functionality
type DB struct {
*sqlx.DB
}
// New creates a new database connection
func New(cfg *config.DatabaseConfig) (*DB, error) {
db, err := sqlx.Open(cfg.Driver, cfg.DSN)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Configure connection pool with optimized settings
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn
db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes
// Verify connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
log.Println("Successfully connected to database")
return &DB{db}, nil
}
// WithTx executes a function within a transaction with retry logic
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
const maxRetries = 3
var lastErr error
// Default transaction options
opts := &sql.TxOptions{
Isolation: sql.LevelReadCommitted,
}
for attempt := 1; attempt <= maxRetries; attempt++ {
start := time.Now()
tx, err := db.BeginTxx(ctx, opts)
if err != nil {
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
continue
}
return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
}
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
}
// Log long-running transactions
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
log.Printf("Transaction completed in %v", elapsed)
}
if err := tx.Commit(); err != nil {
if isRetryableError(err) && attempt < maxRetries {
lastErr = err
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
}
return nil
}
return lastErr
}
// isRetryableError checks if an error is likely to succeed on retry
func isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := strings.ToLower(err.Error())
return strings.Contains(errStr, "deadlock") ||
strings.Contains(errStr, "timeout") ||
strings.Contains(errStr, "try again") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "busy") ||
strings.Contains(errStr, "locked")
}
// ExecContext executes a query with adaptive timeout
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
// Set timeout based on query complexity
timeout := 5 * time.Second
if strings.Contains(strings.ToLower(query), "insert") ||
strings.Contains(strings.ToLower(query), "update") ||
strings.Contains(strings.ToLower(query), "delete") {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
_, err := db.DB.ExecContext(ctx, query, args...)
elapsed := time.Since(start)
// Log slow queries
if elapsed > timeout/2 {
log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
}
if err != nil {
return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
}
return nil
}
// QueryRowContext executes a query that returns a single row with timeout
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
return db.DB.QueryRowxContext(ctx, query, args...)
}
// QueryContext executes a query that returns multiple rows with adaptive timeout
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
// Set timeout based on query complexity
timeout := 5 * time.Second
if strings.Contains(strings.ToLower(query), "join") ||
strings.Contains(strings.ToLower(query), "group by") ||
strings.Contains(strings.ToLower(query), "order by") {
timeout = 15 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
rows, err := db.DB.QueryxContext(ctx, query, args...)
elapsed := time.Since(start)
// Log slow queries
if elapsed > timeout/2 {
log.Printf("Slow query detected: %s (took %v)", query, elapsed)
}
if err != nil {
return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
}
return rows, nil
}
// Close closes the database connection
func (db *DB) Close() error {
if err := db.DB.Close(); err != nil {
return fmt.Errorf("failed to close database connection: %w", err)
}
return nil
}

160
database/migration.go Normal file
View file

@ -0,0 +1,160 @@
package database
import (
"context"
"fmt"
"log"
"time"
_ "embed"
"github.com/jmoiron/sqlx"
)
var (
//go:embed init/users.sql
userTable string
//go:embed init/otp.sql
otpTable string
)
// Migration represents a database migration
type Migration struct {
Name string
SQL string
Version int
}
// Migrations is a list of all migrations
var Migrations = []Migration{
{
Name: "Create users table",
SQL: userTable,
Version: 1,
},
{
Name: "Create OTP table",
SQL: otpTable,
Version: 2,
},
}
// MigrationRecord represents a record in the migrations table
type MigrationRecord struct {
ID int `db:"id"`
Version int `db:"version"`
Name string `db:"name"`
AppliedAt time.Time `db:"applied_at"`
}
// ensureMigrationsTable ensures that the migrations table exists
func ensureMigrationsTable(db *sqlx.DB) error {
query := `
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version INTEGER NOT NULL,
name TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
_, err := db.Exec(query)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
return nil
}
// getAppliedMigrations gets all applied migrations
func getAppliedMigrations(db *sqlx.DB) (map[int]MigrationRecord, error) {
var records []MigrationRecord
if err := db.Select(&records, "SELECT * FROM migrations ORDER BY version"); err != nil {
return nil, fmt.Errorf("failed to get applied migrations: %w", err)
}
result := make(map[int]MigrationRecord)
for _, record := range records {
result[record.Version] = record
}
return result, nil
}
// Migrate runs all pending migrations
func Migrate(db *sqlx.DB, skipMigration bool) error {
if skipMigration {
log.Println("Skipping database migration as configured")
return nil
}
// Ensure migrations table exists
if err := ensureMigrationsTable(db); err != nil {
return err
}
// Get applied migrations
appliedMigrations, err := getAppliedMigrations(db)
if err != nil {
return err
}
// Apply pending migrations
for _, migration := range Migrations {
if _, ok := appliedMigrations[migration.Version]; ok {
log.Printf("Migration %d (%s) already applied", migration.Version, migration.Name)
continue
}
log.Printf("Applying migration %d: %s", migration.Version, migration.Name)
// Start a transaction for this migration
tx, err := db.Beginx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
// Execute migration
if _, err := tx.Exec(migration.SQL); err != nil {
tx.Rollback()
return fmt.Errorf("failed to apply migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Record migration
if _, err := tx.Exec(
"INSERT INTO migrations (version, name) VALUES (?, ?)",
migration.Version, migration.Name,
); err != nil {
tx.Rollback()
return fmt.Errorf("failed to record migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Commit transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit migration %d (%s): %w", migration.Version, migration.Name, err)
}
log.Printf("Successfully applied migration %d: %s", migration.Version, migration.Name)
}
return nil
}
// MigrateWithContext runs all pending migrations with context
func MigrateWithContext(ctx context.Context, db *sqlx.DB, skipMigration bool) error {
// Create a channel to signal completion
done := make(chan error, 1)
// Run migration in a goroutine
go func() {
done <- Migrate(db, skipMigration)
}()
// Wait for migration to complete or context to be canceled
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("migration canceled: %w", ctx.Err())
}
}