196 lines
5.3 KiB
Go
196 lines
5.3 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"strings"
|
|
"time"
|
|
|
|
"otpm/config"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
// DB wraps sqlx.DB to provide additional functionality such as transaction
// retries and timeout-bounded query helpers.
//
// Because *sqlx.DB is embedded, its methods are promoted onto DB, except
// where DB declares a method with the same name: that method shadows the
// promoted one. Callers must go through db.DB to reach the original
// (e.g. db.DB.ExecContext).
type DB struct {
	*sqlx.DB
}
|
|
|
|
// New creates a new database connection
|
|
func New(cfg *config.DatabaseConfig) (*DB, error) {
|
|
db, err := sqlx.Open(cfg.Driver, cfg.DSN)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
|
|
// Configure connection pool with optimized settings
|
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
|
|
db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn
|
|
db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes
|
|
|
|
// Verify connection with timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
log.Println("Successfully connected to database")
|
|
return &DB{db}, nil
|
|
}
|
|
|
|
// WithTx executes a function within a transaction with retry logic
|
|
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
|
|
const maxRetries = 3
|
|
var lastErr error
|
|
|
|
// Default transaction options
|
|
opts := &sql.TxOptions{
|
|
Isolation: sql.LevelReadCommitted,
|
|
}
|
|
|
|
for attempt := 1; attempt <= maxRetries; attempt++ {
|
|
start := time.Now()
|
|
|
|
tx, err := db.BeginTxx(ctx, opts)
|
|
if err != nil {
|
|
if isRetryableError(err) && attempt < maxRetries {
|
|
lastErr = err
|
|
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) // exponential backoff
|
|
continue
|
|
}
|
|
return fmt.Errorf("failed to begin transaction (attempt %d/%d): %w", attempt, maxRetries, err)
|
|
}
|
|
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
tx.Rollback()
|
|
panic(p)
|
|
}
|
|
}()
|
|
|
|
if err := fn(tx); err != nil {
|
|
if rbErr := tx.Rollback(); rbErr != nil {
|
|
log.Printf("Transaction rollback error: %v (original error: %v)", rbErr, err)
|
|
}
|
|
|
|
if isRetryableError(err) && attempt < maxRetries {
|
|
lastErr = err
|
|
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
|
|
continue
|
|
}
|
|
return fmt.Errorf("transaction failed (attempt %d/%d): %w", attempt, maxRetries, err)
|
|
}
|
|
|
|
// Log long-running transactions
|
|
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
|
|
log.Printf("Transaction completed in %v", elapsed)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
if isRetryableError(err) && attempt < maxRetries {
|
|
lastErr = err
|
|
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
|
|
continue
|
|
}
|
|
return fmt.Errorf("failed to commit transaction (attempt %d/%d): %w", attempt, maxRetries, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return lastErr
|
|
}
|
|
|
|
// isRetryableError reports whether err looks transient, meaning the same
// operation might succeed if attempted again. The check is a case-insensitive
// substring match of err's message against a fixed set of markers
// (deadlocks, timeouts, resets, busy/locked resources). A nil error is never
// retryable.
func isRetryableError(err error) bool {
	if err == nil {
		return false
	}

	message := strings.ToLower(err.Error())
	transientMarkers := [...]string{
		"deadlock",
		"timeout",
		"try again",
		"connection reset",
		"busy",
		"locked",
	}
	for _, marker := range transientMarkers {
		if strings.Contains(message, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
// ExecContext executes a query with adaptive timeout
|
|
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) error {
|
|
// Set timeout based on query complexity
|
|
timeout := 5 * time.Second
|
|
if strings.Contains(strings.ToLower(query), "insert") ||
|
|
strings.Contains(strings.ToLower(query), "update") ||
|
|
strings.Contains(strings.ToLower(query), "delete") {
|
|
timeout = 10 * time.Second
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
|
|
start := time.Now()
|
|
_, err := db.DB.ExecContext(ctx, query, args...)
|
|
elapsed := time.Since(start)
|
|
|
|
// Log slow queries
|
|
if elapsed > timeout/2 {
|
|
log.Printf("Slow query execution detected: %s (took %v)", query, elapsed)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// QueryRowContext executes a query that returns a single row with timeout
|
|
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
|
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
|
|
return db.DB.QueryRowxContext(ctx, query, args...)
|
|
}
|
|
|
|
// QueryContext executes a query that returns multiple rows with adaptive timeout
|
|
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
|
|
// Set timeout based on query complexity
|
|
timeout := 5 * time.Second
|
|
if strings.Contains(strings.ToLower(query), "join") ||
|
|
strings.Contains(strings.ToLower(query), "group by") ||
|
|
strings.Contains(strings.ToLower(query), "order by") {
|
|
timeout = 15 * time.Second
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
|
|
start := time.Now()
|
|
rows, err := db.DB.QueryxContext(ctx, query, args...)
|
|
elapsed := time.Since(start)
|
|
|
|
// Log slow queries
|
|
if elapsed > timeout/2 {
|
|
log.Printf("Slow query detected: %s (took %v)", query, elapsed)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to execute query [%s] with args %v: %w", query, args, err)
|
|
}
|
|
|
|
return rows, nil
|
|
}
|
|
|
|
// Close closes the database connection
|
|
func (db *DB) Close() error {
|
|
if err := db.DB.Close(); err != nil {
|
|
return fmt.Errorf("failed to close database connection: %w", err)
|
|
}
|
|
return nil
|
|
}
|