fix api
This commit is contained in:
parent
01b8951dd5
commit
10ebc59ffb
17 changed files with 1087 additions and 238 deletions
224
db/token.go
Normal file
224
db/token.go
Normal file
|
@ -0,0 +1,224 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 令牌类型常量
|
||||
const (
|
||||
TokenTypeHOTP = "hotp"
|
||||
TokenTypeTOTP = "totp"
|
||||
)
|
||||
|
||||
// Token 表示一个 OTP 令牌
|
||||
type Token struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Issuer string `json:"issuer"`
|
||||
Account string `json:"account"`
|
||||
Secret string `json:"secret"`
|
||||
Type string `json:"type"`
|
||||
Counter *int `json:"counter,omitempty"`
|
||||
Period int `json:"period"`
|
||||
Digits int `json:"digits"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
// SaveTokens 保存用户的令牌列表
|
||||
func (db *DB) SaveTokens(userID string, tokens []Token) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 准备插入语句
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO tokens (
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
|
||||
) ON CONFLICT (id, user_id) DO UPDATE SET
|
||||
issuer = EXCLUDED.issuer,
|
||||
account = EXCLUDED.account,
|
||||
secret = EXCLUDED.secret,
|
||||
type = EXCLUDED.type,
|
||||
counter = EXCLUDED.counter,
|
||||
period = EXCLUDED.period,
|
||||
digits = EXCLUDED.digits,
|
||||
algorithm = EXCLUDED.algorithm,
|
||||
timestamp = EXCLUDED.timestamp,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
now := time.Now()
|
||||
for _, token := range tokens {
|
||||
// 确保algorithm字段有默认值
|
||||
algorithm := token.Algorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA1"
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(
|
||||
token.ID,
|
||||
userID,
|
||||
token.Issuer,
|
||||
token.Account,
|
||||
token.Secret,
|
||||
token.Type,
|
||||
token.Counter,
|
||||
token.Period,
|
||||
token.Digits,
|
||||
algorithm,
|
||||
token.Timestamp,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetTokensByUserID 获取用户的所有令牌
|
||||
func (db *DB) GetTokensByUserID(userID string) ([]Token, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
FROM tokens
|
||||
WHERE user_id = $1
|
||||
ORDER BY timestamp DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []Token
|
||||
for rows.Next() {
|
||||
var token Token
|
||||
err := rows.Scan(
|
||||
&token.ID,
|
||||
&token.UserID,
|
||||
&token.Issuer,
|
||||
&token.Account,
|
||||
&token.Secret,
|
||||
&token.Type,
|
||||
&token.Counter,
|
||||
&token.Period,
|
||||
&token.Digits,
|
||||
&token.Algorithm,
|
||||
&token.Timestamp,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// DeleteToken 删除指定的令牌
|
||||
func (db *DB) DeleteToken(userID, tokenID string) error {
|
||||
result, err := db.Exec(`
|
||||
DELETE FROM tokens
|
||||
WHERE user_id = $1 AND id = $2
|
||||
`, userID, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTokenCounter 更新 HOTP 令牌的计数器
|
||||
func (db *DB) UpdateTokenCounter(userID, tokenID string, counter int) error {
|
||||
result, err := db.Exec(`
|
||||
UPDATE tokens
|
||||
SET counter = $1, updated_at = $2
|
||||
WHERE user_id = $3 AND id = $4 AND type = $5
|
||||
`, counter, time.Now(), userID, tokenID, TokenTypeHOTP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTokenByID 获取指定的令牌
|
||||
func (db *DB) GetTokenByID(userID, tokenID string) (*Token, error) {
|
||||
var token Token
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
id, user_id, issuer, account, secret, type, counter,
|
||||
period, digits, algorithm, timestamp, created_at, updated_at
|
||||
FROM tokens
|
||||
WHERE user_id = $1 AND id = $2
|
||||
`, userID, tokenID).Scan(
|
||||
&token.ID,
|
||||
&token.UserID,
|
||||
&token.Issuer,
|
||||
&token.Account,
|
||||
&token.Secret,
|
||||
&token.Type,
|
||||
&token.Counter,
|
||||
&token.Period,
|
||||
&token.Digits,
|
||||
&token.Algorithm,
|
||||
&token.Timestamp,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue