package db

import (
	"database/sql"
	"errors"
	"time"
)

// Token type constants, stored in the type column of the tokens table.
const (
	TokenTypeHOTP = "hotp"
	TokenTypeTOTP = "totp"
)

// defaultTokenAlgorithm is stored by SaveTokens for tokens whose Algorithm
// field is empty.
const defaultTokenAlgorithm = "SHA1"

// tokenSelectColumns lists the columns read by every token query in this
// file. The order must match the Scan destinations in scanTokenRow.
const tokenSelectColumns = `id, user_id, issuer, account, secret, type, counter,
		period, digits, algorithm, timestamp, created_at, updated_at`

// Token is a single OTP token as stored in the tokens table.
type Token struct {
	ID        string    `json:"id"`
	UserID    string    `json:"user_id"`
	Issuer    string    `json:"issuer"`
	Account   string    `json:"account"`
	Secret    string    `json:"secret"`
	Type      string    `json:"type"`              // expected to be one of the TokenType* constants
	Counter   *int      `json:"counter,omitempty"` // nil when the stored counter is NULL
	Period    int       `json:"period"`
	Digits    int       `json:"digits"`
	Algorithm string    `json:"algorithm"`
	Timestamp int64     `json:"timestamp"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}

var (
	// ErrTokenNotFound is returned when no token matches the given user and
	// token ID (for UpdateTokenCounter: no matching HOTP token).
	ErrTokenNotFound = errors.New("token not found")
	// ErrInvalidToken is the sentinel error for an invalid token.
	ErrInvalidToken = errors.New("invalid token")
)

// tokenRowScanner is the Scan method shared by *sql.Row and *sql.Rows.
type tokenRowScanner interface {
	Scan(dest ...interface{}) error
}

// scanTokenRow reads one row selected with tokenSelectColumns into a Token.
// On error the returned Token is the zero value.
func scanTokenRow(s tokenRowScanner) (Token, error) {
	var t Token
	if err := s.Scan(
		&t.ID, &t.UserID, &t.Issuer, &t.Account, &t.Secret, &t.Type, &t.Counter,
		&t.Period, &t.Digits, &t.Algorithm, &t.Timestamp, &t.CreatedAt, &t.UpdatedAt,
	); err != nil {
		return Token{}, err
	}
	return t, nil
}

// requireTokenRowAffected turns a statement that touched no rows into
// ErrTokenNotFound and passes through any error from RowsAffected.
func requireTokenRowAffected(result sql.Result) error {
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return ErrTokenNotFound
	}
	return nil
}

// SaveTokens upserts tokens for userID inside a single transaction.
// If any write fails, the transaction is rolled back and none of the tokens
// are saved.
//
// Ownership comes from the userID argument; each token's UserID field is
// ignored. Rows are matched on (id, user_id). On insert, created_at and
// updated_at are both set to the current time. On conflict, every column
// except id, user_id and created_at is overwritten, so the original
// created_at is kept. Tokens already stored for the user but absent from
// tokens are left untouched. An empty Algorithm is stored as "SHA1".
func (db *DB) SaveTokens(userID string, tokens []Token) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback just returns sql.ErrTxDone (ignored);
	// it only undoes work when an earlier step returned an error.
	defer tx.Rollback()

	stmt, err := tx.Prepare(`
		INSERT INTO tokens (
			id, user_id, issuer, account, secret, type, counter,
			period, digits, algorithm, timestamp, created_at, updated_at
		) VALUES (
			$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
		)
		ON CONFLICT (id, user_id) DO UPDATE SET
			issuer = EXCLUDED.issuer,
			account = EXCLUDED.account,
			secret = EXCLUDED.secret,
			type = EXCLUDED.type,
			counter = EXCLUDED.counter,
			period = EXCLUDED.period,
			digits = EXCLUDED.digits,
			algorithm = EXCLUDED.algorithm,
			timestamp = EXCLUDED.timestamp,
			updated_at = EXCLUDED.updated_at
	`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	// One timestamp for the whole batch so every row in it agrees.
	now := time.Now()
	for _, token := range tokens {
		algorithm := token.Algorithm
		if algorithm == "" {
			algorithm = defaultTokenAlgorithm
		}
		if _, err := stmt.Exec(
			token.ID, userID, token.Issuer, token.Account, token.Secret,
			token.Type, token.Counter, token.Period, token.Digits, algorithm,
			token.Timestamp, now, now,
		); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// GetTokensByUserID returns every token owned by userID. The results are
// ordered by Timestamp descending, with ties broken by ascending id so the
// order is deterministic. When the user has no tokens it returns a nil
// slice and a nil error.
func (db *DB) GetTokensByUserID(userID string) ([]Token, error) {
	rows, err := db.Query(`
		SELECT `+tokenSelectColumns+`
		FROM tokens
		WHERE user_id = $1
		ORDER BY timestamp DESC, id
	`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tokens []Token
	for rows.Next() {
		token, err := scanTokenRow(rows)
		if err != nil {
			return nil, err
		}
		tokens = append(tokens, token)
	}
	// rows.Next also returns false on an iteration error; surface it here.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return tokens, nil
}

// DeleteToken deletes the token tokenID owned by userID. It returns
// ErrTokenNotFound when no such token exists.
func (db *DB) DeleteToken(userID, tokenID string) error {
	result, err := db.Exec(`
		DELETE FROM tokens
		WHERE user_id = $1 AND id = $2
	`, userID, tokenID)
	if err != nil {
		return err
	}
	return requireTokenRowAffected(result)
}

// UpdateTokenCounter sets the counter of the HOTP token tokenID owned by
// userID and bumps its updated_at. Only rows whose type is TokenTypeHOTP
// are updated. A token that exists but is not HOTP therefore also yields
// ErrTokenNotFound.
func (db *DB) UpdateTokenCounter(userID, tokenID string, counter int) error {
	result, err := db.Exec(`
		UPDATE tokens
		SET counter = $1, updated_at = $2
		WHERE user_id = $3 AND id = $4 AND type = $5
	`, counter, time.Now(), userID, tokenID, TokenTypeHOTP)
	if err != nil {
		return err
	}
	return requireTokenRowAffected(result)
}

// GetTokenByID returns the token tokenID owned by userID. It returns
// ErrTokenNotFound when no such token exists.
func (db *DB) GetTokenByID(userID, tokenID string) (*Token, error) {
	row := db.QueryRow(`
		SELECT `+tokenSelectColumns+`
		FROM tokens
		WHERE user_id = $1 AND id = $2
	`, userID, tokenID)
	token, err := scanTokenRow(row)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrTokenNotFound
	}
	if err != nil {
		return nil, err
	}
	return &token, nil
}