224 lines
4.5 KiB
Go
224 lines
4.5 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"time"
|
|
)
|
|
|
|
// Token type constants. These are the values compared against the "type"
// column of the tokens table (see UpdateTokenCounter).
const (
	// TokenTypeHOTP marks a token as HOTP; only tokens of this type can have
	// their counter changed through UpdateTokenCounter.
	TokenTypeHOTP = "hotp"

	// TokenTypeTOTP marks a token as TOTP.
	TokenTypeTOTP = "totp"
)
|
|
|
|
// Token represents one OTP token. Its fields map one-to-one onto the columns
// of the tokens table and onto the JSON keys given in the struct tags.
type Token struct {
	// ID identifies the token; together with UserID it forms the conflict
	// key used by SaveTokens.
	ID string `json:"id"`
	// UserID is the owner of the token.
	UserID string `json:"user_id"`
	// Issuer is the issuer label (presumably the service name; confirm with callers).
	Issuer string `json:"issuer"`
	// Account is the account label (presumably the login shown to the user; confirm with callers).
	Account string `json:"account"`
	// Secret is the OTP secret. NOTE(review): its encoding is not visible in this file.
	Secret string `json:"secret"`
	// Type is the token type, e.g. TokenTypeHOTP or TokenTypeTOTP.
	Type string `json:"type"`
	// Counter is optional: nil maps to a NULL column and is omitted from JSON.
	// UpdateTokenCounter writes it for HOTP tokens.
	Counter *int `json:"counter,omitempty"`
	// Period is presumably the TOTP time step in seconds (TODO confirm).
	Period int `json:"period"`
	// Digits is presumably the length of the generated code (TODO confirm).
	Digits int `json:"digits"`
	// Algorithm names the hash algorithm; SaveTokens stores "SHA1" when it is empty.
	Algorithm string `json:"algorithm"`
	// Timestamp is the value GetTokensByUserID sorts by, newest first.
	// NOTE(review): its unit (seconds vs. milliseconds) is not visible here.
	Timestamp int64 `json:"timestamp"`
	// CreatedAt is written by SaveTokens when the row is first inserted.
	CreatedAt time.Time `json:"created_at"`
	// UpdatedAt is refreshed by SaveTokens and UpdateTokenCounter.
	UpdatedAt time.Time `json:"updated_at"`
}
|
|
|
|
// Sentinel errors for token operations.
var (
	// ErrTokenNotFound is returned when no row matches the requested user and
	// token (for UpdateTokenCounter this includes tokens that are not HOTP).
	ErrTokenNotFound = errors.New("token not found")

	// ErrInvalidToken signals an invalid token.
	ErrInvalidToken = errors.New("invalid token")
)
|
|
|
|
// SaveTokens 保存用户的令牌列表
|
|
func (db *DB) SaveTokens(userID string, tokens []Token) error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// 准备插入语句
|
|
stmt, err := tx.Prepare(`
|
|
INSERT INTO tokens (
|
|
id, user_id, issuer, account, secret, type, counter,
|
|
period, digits, algorithm, timestamp, created_at, updated_at
|
|
) VALUES (
|
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
|
|
) ON CONFLICT (id, user_id) DO UPDATE SET
|
|
issuer = EXCLUDED.issuer,
|
|
account = EXCLUDED.account,
|
|
secret = EXCLUDED.secret,
|
|
type = EXCLUDED.type,
|
|
counter = EXCLUDED.counter,
|
|
period = EXCLUDED.period,
|
|
digits = EXCLUDED.digits,
|
|
algorithm = EXCLUDED.algorithm,
|
|
timestamp = EXCLUDED.timestamp,
|
|
updated_at = EXCLUDED.updated_at
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
now := time.Now()
|
|
for _, token := range tokens {
|
|
// 确保algorithm字段有默认值
|
|
algorithm := token.Algorithm
|
|
if algorithm == "" {
|
|
algorithm = "SHA1"
|
|
}
|
|
|
|
_, err = stmt.Exec(
|
|
token.ID,
|
|
userID,
|
|
token.Issuer,
|
|
token.Account,
|
|
token.Secret,
|
|
token.Type,
|
|
token.Counter,
|
|
token.Period,
|
|
token.Digits,
|
|
algorithm,
|
|
token.Timestamp,
|
|
now,
|
|
now,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// GetTokensByUserID 获取用户的所有令牌
|
|
func (db *DB) GetTokensByUserID(userID string) ([]Token, error) {
|
|
rows, err := db.Query(`
|
|
SELECT
|
|
id, user_id, issuer, account, secret, type, counter,
|
|
period, digits, algorithm, timestamp, created_at, updated_at
|
|
FROM tokens
|
|
WHERE user_id = $1
|
|
ORDER BY timestamp DESC
|
|
`, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tokens []Token
|
|
for rows.Next() {
|
|
var token Token
|
|
err := rows.Scan(
|
|
&token.ID,
|
|
&token.UserID,
|
|
&token.Issuer,
|
|
&token.Account,
|
|
&token.Secret,
|
|
&token.Type,
|
|
&token.Counter,
|
|
&token.Period,
|
|
&token.Digits,
|
|
&token.Algorithm,
|
|
&token.Timestamp,
|
|
&token.CreatedAt,
|
|
&token.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tokens = append(tokens, token)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
// DeleteToken 删除指定的令牌
|
|
func (db *DB) DeleteToken(userID, tokenID string) error {
|
|
result, err := db.Exec(`
|
|
DELETE FROM tokens
|
|
WHERE user_id = $1 AND id = $2
|
|
`, userID, tokenID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if affected == 0 {
|
|
return ErrTokenNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateTokenCounter 更新 HOTP 令牌的计数器
|
|
func (db *DB) UpdateTokenCounter(userID, tokenID string, counter int) error {
|
|
result, err := db.Exec(`
|
|
UPDATE tokens
|
|
SET counter = $1, updated_at = $2
|
|
WHERE user_id = $3 AND id = $4 AND type = $5
|
|
`, counter, time.Now(), userID, tokenID, TokenTypeHOTP)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if affected == 0 {
|
|
return ErrTokenNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetTokenByID 获取指定的令牌
|
|
func (db *DB) GetTokenByID(userID, tokenID string) (*Token, error) {
|
|
var token Token
|
|
err := db.QueryRow(`
|
|
SELECT
|
|
id, user_id, issuer, account, secret, type, counter,
|
|
period, digits, algorithm, timestamp, created_at, updated_at
|
|
FROM tokens
|
|
WHERE user_id = $1 AND id = $2
|
|
`, userID, tokenID).Scan(
|
|
&token.ID,
|
|
&token.UserID,
|
|
&token.Issuer,
|
|
&token.Account,
|
|
&token.Secret,
|
|
&token.Type,
|
|
&token.Counter,
|
|
&token.Period,
|
|
&token.Digits,
|
|
&token.Algorithm,
|
|
&token.Timestamp,
|
|
&token.CreatedAt,
|
|
&token.UpdatedAt,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrTokenNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &token, nil
|
|
}
|