52 lines
1,000 B
Go
52 lines
1,000 B
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
_ "github.com/lib/pq"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// Config holds the settings New uses to open a database connection.
type Config struct {
	// Driver is the database/sql driver name handed to sql.Open.
	// New accepts "postgres" and "sqlite3".
	Driver string

	// Host, Port, User and Password are PostgreSQL connection
	// parameters; New ignores them for "sqlite3".
	Host           string
	Port           int
	User, Password string

	// DBName is the PostgreSQL database name. For "sqlite3", New passes
	// it to the driver verbatim as the data source name.
	DBName string

	// SSLMode is the PostgreSQL sslmode connection parameter; New
	// ignores it for "sqlite3".
	SSLMode string
}
|
|
|
|
// DB 封装数据库操作
|
|
type DB struct {
|
|
*sql.DB
|
|
}
|
|
|
|
// New 创建数据库连接
|
|
func New(cfg Config) (*DB, error) {
|
|
var dsn string
|
|
switch cfg.Driver {
|
|
case "postgres":
|
|
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
|
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode)
|
|
case "sqlite3":
|
|
dsn = cfg.DBName
|
|
default:
|
|
return nil, fmt.Errorf("unsupported database driver: %s", cfg.Driver)
|
|
}
|
|
|
|
db, err := sql.Open(cfg.Driver, dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to database: %v", err)
|
|
}
|
|
|
|
// 测试连接
|
|
if err := db.Ping(); err != nil {
|
|
db.Close()
|
|
return nil, fmt.Errorf("failed to ping database: %v", err)
|
|
}
|
|
|
|
return &DB{db}, nil
|
|
}
|