fix api
This commit is contained in:
parent
01b8951dd5
commit
10ebc59ffb
17 changed files with 1087 additions and 238 deletions
52
db/db.go
Normal file
52
db/db.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Config 数据库配置
|
||||
type Config struct {
|
||||
Driver string
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
SSLMode string
|
||||
}
|
||||
|
||||
// DB 封装数据库操作
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
// New 创建数据库连接
|
||||
func New(cfg Config) (*DB, error) {
|
||||
var dsn string
|
||||
switch cfg.Driver {
|
||||
case "postgres":
|
||||
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode)
|
||||
case "sqlite3":
|
||||
dsn = cfg.DBName
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database driver: %s", cfg.Driver)
|
||||
}
|
||||
|
||||
db, err := sql.Open(cfg.Driver, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping database: %v", err)
|
||||
}
|
||||
|
||||
return &DB{db}, nil
|
||||
}
|
224
db/token.go
Normal file
224
db/token.go
Normal file
|
@ -0,0 +1,224 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 令牌类型常量
|
||||
const (
|
||||
TokenTypeHOTP = "hotp"
|
||||
TokenTypeTOTP = "totp"
|
||||
)
|
||||
|
||||
// Token 表示一个 OTP 令牌
|
||||
type Token struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Issuer string `json:"issuer"`
|
||||
Account string `json:"account"`
|
||||
Secret string `json:"secret"`
|
||||
Type string `json:"type"`
|
||||
Counter *int `json:"counter,omitempty"`
|
||||
Period int `json:"period"`
|
||||
Digits int `json:"digits"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
// SaveTokens 保存用户的令牌列表
|
||||
func (db *DB) SaveTokens(userID string, tokens []Token) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 准备插入语句
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO tokens (
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
|
||||
) ON CONFLICT (id, user_id) DO UPDATE SET
|
||||
issuer = EXCLUDED.issuer,
|
||||
account = EXCLUDED.account,
|
||||
secret = EXCLUDED.secret,
|
||||
type = EXCLUDED.type,
|
||||
counter = EXCLUDED.counter,
|
||||
period = EXCLUDED.period,
|
||||
digits = EXCLUDED.digits,
|
||||
algorithm = EXCLUDED.algorithm,
|
||||
timestamp = EXCLUDED.timestamp,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
now := time.Now()
|
||||
for _, token := range tokens {
|
||||
// 确保algorithm字段有默认值
|
||||
algorithm := token.Algorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA1"
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(
|
||||
token.ID,
|
||||
userID,
|
||||
token.Issuer,
|
||||
token.Account,
|
||||
token.Secret,
|
||||
token.Type,
|
||||
token.Counter,
|
||||
token.Period,
|
||||
token.Digits,
|
||||
algorithm,
|
||||
token.Timestamp,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetTokensByUserID 获取用户的所有令牌
|
||||
func (db *DB) GetTokensByUserID(userID string) ([]Token, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
FROM tokens
|
||||
WHERE user_id = $1
|
||||
ORDER BY timestamp DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []Token
|
||||
for rows.Next() {
|
||||
var token Token
|
||||
err := rows.Scan(
|
||||
&token.ID,
|
||||
&token.UserID,
|
||||
&token.Issuer,
|
||||
&token.Account,
|
||||
&token.Secret,
|
||||
&token.Type,
|
||||
&token.Counter,
|
||||
&token.Period,
|
||||
&token.Digits,
|
||||
&token.Algorithm,
|
||||
&token.Timestamp,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// DeleteToken 删除指定的令牌
|
||||
func (db *DB) DeleteToken(userID, tokenID string) error {
|
||||
result, err := db.Exec(`
|
||||
DELETE FROM tokens
|
||||
WHERE user_id = $1 AND id = $2
|
||||
`, userID, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTokenCounter 更新 HOTP 令牌的计数器
|
||||
func (db *DB) UpdateTokenCounter(userID, tokenID string, counter int) error {
|
||||
result, err := db.Exec(`
|
||||
UPDATE tokens
|
||||
SET counter = $1, updated_at = $2
|
||||
WHERE user_id = $3 AND id = $4 AND type = $5
|
||||
`, counter, time.Now(), userID, tokenID, TokenTypeHOTP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTokenByID 获取指定的令牌
|
||||
func (db *DB) GetTokenByID(userID, tokenID string) (*Token, error) {
|
||||
var token Token
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
FROM tokens
|
||||
WHERE user_id = $1 AND id = $2
|
||||
`, userID, tokenID).Scan(
|
||||
&token.ID,
|
||||
&token.UserID,
|
||||
&token.Issuer,
|
||||
&token.Account,
|
||||
&token.Secret,
|
||||
&token.Type,
|
||||
&token.Counter,
|
||||
&token.Period,
|
||||
&token.Digits,
|
||||
&token.Algorithm,
|
||||
&token.Timestamp,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue