beta
This commit is contained in:
parent
a45ddf13d5
commit
bcd986e3f7
46 changed files with 6166 additions and 454 deletions
147
handlers/auth_handler.go
Normal file
147
handlers/auth_handler.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"otpm/api"
|
||||
"otpm/services"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication related requests
|
||||
type AuthHandler struct {
|
||||
authService *services.AuthService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRequest represents a login request
|
||||
type LoginRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// LoginResponse represents a login response
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
OpenID string `json:"openid"`
|
||||
}
|
||||
|
||||
// Login handles WeChat login
|
||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Limit request body size to prevent DOS
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
|
||||
|
||||
// Parse request
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
fmt.Sprintf("Invalid request body: %v", err))
|
||||
log.Printf("Login request parse error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if req.Code == "" {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Code is required")
|
||||
log.Printf("Login request validation failed: empty code")
|
||||
return
|
||||
}
|
||||
|
||||
// Login with WeChat code
|
||||
token, err := h.authService.LoginWithWeChatCode(r.Context(), req.Code)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
log.Printf("Login failed for code %s: %v", req.Code, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful login
|
||||
log.Printf("Login successful for code %s (took %v)",
|
||||
req.Code, time.Since(start))
|
||||
|
||||
// Return token
|
||||
api.NewResponseWriter(w).WriteSuccess(LoginResponse{
|
||||
Token: token,
|
||||
})
|
||||
}
|
||||
|
||||
// VerifyToken handles token verification
|
||||
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Get token from Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Authorization header is required")
|
||||
log.Printf("Token verification failed: missing Authorization header")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token format
|
||||
if len(authHeader) < 7 || authHeader[:7] != "Bearer " {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Invalid token format. Expected 'Bearer <token>'")
|
||||
log.Printf("Token verification failed: invalid token format")
|
||||
return
|
||||
}
|
||||
|
||||
token := authHeader[7:]
|
||||
if len(token) < 32 { // Basic length check
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Invalid token length")
|
||||
log.Printf("Token verification failed: token too short")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token
|
||||
claims, err := h.authService.ValidateToken(token)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
log.Printf("Token verification failed for token %s: %v",
|
||||
maskToken(token), err) // Mask token in logs
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful verification
|
||||
userID, ok := claims.Claims.(jwt.MapClaims)["user_id"].(string)
|
||||
if !ok {
|
||||
log.Printf("Token verified but user_id claim is invalid (took %v)", time.Since(start))
|
||||
} else {
|
||||
log.Printf("Token verified for user %s (took %v)", userID, time.Since(start))
|
||||
}
|
||||
|
||||
// Token is valid
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
|
||||
"valid": true,
|
||||
})
|
||||
}
|
||||
|
||||
// maskToken masks sensitive parts of token for logging
|
||||
func maskToken(token string) string {
|
||||
if len(token) < 8 {
|
||||
return "****"
|
||||
}
|
||||
return token[:4] + "****" + token[len(token)-4:]
|
||||
}
|
||||
|
||||
// Routes returns all routes for the auth handler
|
||||
func (h *AuthHandler) Routes() map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/login": h.Login,
|
||||
"/verify-token": h.VerifyToken,
|
||||
}
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func WriteJSON(w http.ResponseWriter, data interface{}, code int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
func WriteError(w http.ResponseWriter, message string, code int) {
|
||||
WriteJSON(w, Response{Code: code, Message: message}, code)
|
||||
}
|
|
@ -1,158 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type LoginRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// 封装code2session接口返回数据
|
||||
type LoginResponse struct {
|
||||
OpenId string `json:"openid"`
|
||||
SessionKey string `json:"session_key"`
|
||||
UnionId string `json:"unionid"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
var wxClient = &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConnsPerHost: 10,
|
||||
},
|
||||
}
|
||||
|
||||
func getLoginResponse(code string) (*LoginResponse, error) {
|
||||
appid := viper.GetString("wechat.appid")
|
||||
secret := viper.GetString("wechat.secret")
|
||||
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", appid, secret, code)
|
||||
resp, err := wxClient.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var loginResponse LoginResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&loginResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if loginResponse.ErrCode != 0 {
|
||||
switch loginResponse.ErrCode {
|
||||
case 40029:
|
||||
return nil, fmt.Errorf("invalid code: %s", loginResponse.ErrMsg)
|
||||
case 45011:
|
||||
return nil, fmt.Errorf("api limit exceeded: %s", loginResponse.ErrMsg)
|
||||
default:
|
||||
return nil, fmt.Errorf("wechat login error: %s", loginResponse.ErrMsg)
|
||||
}
|
||||
}
|
||||
return &loginResponse, nil
|
||||
}
|
||||
|
||||
func generateJWT(openid string) (string, error) {
|
||||
tokenTTL := viper.GetDuration("auth.ttl")
|
||||
if tokenTTL <= 0 {
|
||||
tokenTTL = 24 * time.Hour
|
||||
}
|
||||
|
||||
secret := viper.GetString("auth.secret")
|
||||
if secret == "" {
|
||||
secret = "default_auth_secret_otpm"
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"openid": openid,
|
||||
"exp": time.Now().Add(tokenTTL).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"iss": viper.GetString("server.name"),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var req LoginRequest
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
WriteError(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loginResponse, err := getLoginResponse(req.Code)
|
||||
if err != nil {
|
||||
switch {
|
||||
case err.Error() == "invalid code":
|
||||
WriteError(w, "Invalid code", http.StatusUnauthorized)
|
||||
case err.Error() == "api limit exceeded":
|
||||
WriteError(w, "API rate limit exceeded", http.StatusTooManyRequests)
|
||||
default:
|
||||
WriteError(w, "Failed to get login response", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 插入或更新用户的openid和session_key
|
||||
query := `
|
||||
INSERT INTO users (openid, session_key)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (openid) DO UPDATE SET session_key = $2
|
||||
RETURNING id;
|
||||
`
|
||||
|
||||
var ID int
|
||||
if err := h.DB.QueryRow(query, loginResponse.OpenId, loginResponse.SessionKey).Scan(&ID); err != nil {
|
||||
WriteError(w, "Failed to log in user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := generateJWT(loginResponse.OpenId)
|
||||
if err != nil {
|
||||
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"t": token,
|
||||
"openid": loginResponse.OpenId,
|
||||
}
|
||||
|
||||
WriteJSON(w, Response{Code: 0, Message: "Success", Data: data}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (h *Handler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||
userid := r.Context().Value("openid").(string)
|
||||
|
||||
token, err := generateJWT(userid)
|
||||
if err != nil {
|
||||
WriteError(w, "Failed to generate JWT token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
WriteJSON(w, Response{
|
||||
Code: 0,
|
||||
Message: "Token refreshed successfully",
|
||||
Data: map[string]string{
|
||||
"token": token,
|
||||
},
|
||||
}, http.StatusOK)
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type OtpRequest struct {
|
||||
OpenID string `json:"openid"`
|
||||
Token *[]OTP `json:"token"`
|
||||
}
|
||||
type OTP struct {
|
||||
Issuer string `json:"issuer"`
|
||||
Remark string `json:"remark"`
|
||||
Secret string `json:"secret"`
|
||||
}
|
||||
|
||||
func (h *Handler) UpdateOrCreateOtp(w http.ResponseWriter, r *http.Request) {
|
||||
var req OtpRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
WriteError(w, "Failed to parse request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.OpenID == "" {
|
||||
WriteError(w, "OpenID is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == nil || len(*req.Token) == 0 {
|
||||
WriteError(w, "Token is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Saving OTP for user: %s token count:: %d", req.OpenID, len(*req.Token))
|
||||
|
||||
// 插入或更新 OTP 记录
|
||||
query := `
|
||||
INSERT INTO otp (openid, token)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (openid) DO UPDATE SET token = EXCLUDED.token
|
||||
`
|
||||
|
||||
_, err := h.DB.Exec(query, req.OpenID, req.Token)
|
||||
if err != nil {
|
||||
WriteError(w, "Failed to update or create OTP", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
WriteJSON(w, Response{Code: 0, Message: "Success"}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (h *Handler) GetOtp(w http.ResponseWriter, r *http.Request) {
|
||||
openid := r.URL.Query().Get("openid")
|
||||
if openid == "" {
|
||||
WriteError(w, "OpenID is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var otp OtpRequest
|
||||
|
||||
err := h.DB.Get(&otp, "SELECT token FROM otp WHERE openid=$1", openid)
|
||||
if err != nil {
|
||||
WriteError(w, "Failed to get OTP", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
WriteJSON(w, Response{Code: 0, Message: "Success", Data: otp}, http.StatusOK)
|
||||
}
|
286
handlers/otp_handler.go
Normal file
286
handlers/otp_handler.go
Normal file
|
@ -0,0 +1,286 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"otpm/api"
|
||||
"otpm/middleware"
|
||||
"otpm/models"
|
||||
"otpm/services"
|
||||
)
|
||||
|
||||
// OTPHandler handles OTP related requests
|
||||
type OTPHandler struct {
|
||||
otpService *services.OTPService
|
||||
}
|
||||
|
||||
// NewOTPHandler creates a new OTPHandler
|
||||
func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
|
||||
return &OTPHandler{
|
||||
otpService: otpService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateOTPRequest represents a request to create an OTP
|
||||
type CreateOTPRequest struct {
|
||||
Name string `json:"name"`
|
||||
Issuer string `json:"issuer"`
|
||||
Secret string `json:"secret"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Digits int `json:"digits"`
|
||||
Period int `json:"period"`
|
||||
}
|
||||
|
||||
// CreateOTP handles OTP creation
|
||||
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Limit request body size
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation
|
||||
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
log.Printf("CreateOTP unauthorized attempt")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request
|
||||
var req CreateOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
fmt.Sprintf("Invalid request body: %v", err))
|
||||
log.Printf("CreateOTP request parse error for user %s: %v", userID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate OTP parameters
|
||||
if req.Secret == "" {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Secret is required")
|
||||
log.Printf("CreateOTP validation failed for user %s: empty secret", userID)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate algorithm
|
||||
supportedAlgos := map[string]bool{
|
||||
"SHA1": true,
|
||||
"SHA256": true,
|
||||
"SHA512": true,
|
||||
}
|
||||
if !supportedAlgos[strings.ToUpper(req.Algorithm)] {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Unsupported algorithm. Supported: SHA1, SHA256, SHA512")
|
||||
log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s",
|
||||
userID, req.Algorithm)
|
||||
return
|
||||
}
|
||||
|
||||
// Create OTP
|
||||
otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{
|
||||
Name: req.Name,
|
||||
Issuer: req.Issuer,
|
||||
Secret: req.Secret,
|
||||
Algorithm: req.Algorithm,
|
||||
Digits: req.Digits,
|
||||
Period: req.Period,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
|
||||
log.Printf("CreateOTP failed for user %s: %v", userID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful creation (mask secret in logs)
|
||||
log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d",
|
||||
userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period)
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(otp)
|
||||
}
|
||||
|
||||
// ListOTPs handles listing all OTPs for a user
|
||||
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTPs
|
||||
otps, err := h.otpService.ListOTPs(r.Context(), userID)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(otps)
|
||||
}
|
||||
|
||||
// GetOTPCode handles generating OTP code
|
||||
func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
otpID = strings.TrimSuffix(otpID, "/code")
|
||||
|
||||
// Validate OTP ID format
|
||||
if len(otpID) != 36 { // Assuming UUID format
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Invalid OTP ID format")
|
||||
log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID)
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limiting check could be added here
|
||||
// (would require redis or similar rate limiter)
|
||||
|
||||
// Generate code
|
||||
code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful generation (without actual code)
|
||||
log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)",
|
||||
userID, otpID, time.Since(start), expiresIn)
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{
|
||||
"code": code,
|
||||
"expires_in": expiresIn,
|
||||
})
|
||||
}
|
||||
|
||||
// VerifyOTPRequest represents a request to verify an OTP code
|
||||
type VerifyOTPRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// VerifyOTP handles OTP code verification
|
||||
func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
otpID = strings.TrimSuffix(otpID, "/verify")
|
||||
|
||||
// Parse request
|
||||
var req VerifyOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify code
|
||||
valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
|
||||
"valid": valid,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateOTPRequest represents a request to update an OTP
|
||||
type UpdateOTPRequest struct {
|
||||
Name string `json:"name"`
|
||||
Issuer string `json:"issuer"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Digits int `json:"digits"`
|
||||
Period int `json:"period"`
|
||||
}
|
||||
|
||||
// UpdateOTP handles OTP update
|
||||
func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
|
||||
// Parse request
|
||||
var req UpdateOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Update OTP
|
||||
otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{
|
||||
Name: req.Name,
|
||||
Issuer: req.Issuer,
|
||||
Algorithm: req.Algorithm,
|
||||
Digits: req.Digits,
|
||||
Period: req.Period,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(otp)
|
||||
}
|
||||
|
||||
// DeleteOTP handles OTP deletion
|
||||
func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
|
||||
// Delete OTP
|
||||
if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]string{
|
||||
"message": "OTP deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// Routes returns all routes for the OTP handler
|
||||
func (h *OTPHandler) Routes() map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/otp": h.CreateOTP,
|
||||
"/otp/": h.ListOTPs,
|
||||
"/otp/{id}": h.UpdateOTP,
|
||||
"/otp/{id}/code": h.GetOTPCode,
|
||||
"/otp/{id}/verify": h.VerifyOTP,
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue